diff options
Diffstat (limited to 'runtime/onert/core/src')
381 files changed, 40641 insertions, 12463 deletions
diff --git a/runtime/onert/core/src/backend/BackendContext.cc b/runtime/onert/core/src/backend/BackendContext.cc index bafa36d28..7b36f106d 100644 --- a/runtime/onert/core/src/backend/BackendContext.cc +++ b/runtime/onert/core/src/backend/BackendContext.cc @@ -16,40 +16,10 @@ #include "backend/BackendContext.h" -#include "ir/Operation.h" -#include "backend/IConstantInitializer.h" - namespace onert { namespace backend { -void BackendContext::initialize(const std::vector<OperationInfo> &operation_list, - const std::vector<ir::OperandIndex> &operand_list) -{ - _operation_list = operation_list; - _operand_list = operand_list; -} - -void BackendContext::initConsts() -{ - for (auto &op : _operation_list) - { - constant_initializer->setLayout(op.layout); - _graph->operations().at(op.index).accept(*constant_initializer); - } - - for (auto ind : _operand_list) - { - const auto &obj = _graph->operands().at(ind); - if (obj.isConstant() && !constant_initializer->exist(ind)) - { - constant_initializer->registerDefaultInitializer(ind, obj); - } - } - - constant_initializer->run(); -} - } // namespace backend } // namespace onert diff --git a/runtime/onert/core/src/backend/IConstantInitializer.cc b/runtime/onert/core/src/backend/IConstantInitializer.cc deleted file mode 100644 index 934a42753..000000000 --- a/runtime/onert/core/src/backend/IConstantInitializer.cc +++ /dev/null @@ -1,112 +0,0 @@ -/* - * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "backend/IConstantInitializer.h" - -#include <Half.h> - -using float16 = Half; - -namespace onert -{ -namespace backend -{ - -void IConstantInitializer::registerCopyInitializer(const ir::OperandIndex &index, - const ir::Operand &obj) -{ - // For only CONSTANTS - // TODO Add to check if tensor has been allocated - if (!obj.isConstant()) - return; - - const auto type = obj.typeInfo().type(); - using ir::DataType; - - switch (type) - { - case DataType::FLOAT32: - _init_map[index] = copyInit<float>; - break; - case DataType::INT32: - _init_map[index] = copyInit<int32_t>; - break; - case DataType::UINT32: - _init_map[index] = copyInit<uint32_t>; - break; - case DataType::BOOL8: - case DataType::QUANT_UINT8_ASYMM: - _init_map[index] = copyInit<uint8_t>; - break; - case DataType::QUANT_INT8_SYMM: - _init_map[index] = copyInit<int8_t>; - break; - case DataType::FLOAT16: - _init_map[index] = copyInit<float16>; - break; - case DataType::INT64: - _init_map[index] = copyInit<int64_t>; - break; - default: - throw std::runtime_error("Not supported, yet"); - break; - } -} - -void IConstantInitializer::registerPermuteInitializer(const ir::OperandIndex &index, - const ir::Operand &obj) -{ - // For only CONSTANTS - // TODO Add to check if tensor has been allocated - if (!obj.isConstant()) - return; - - const auto type = obj.typeInfo().type(); - using ir::DataType; - using namespace std::placeholders; - - switch (type) - { - case DataType::FLOAT32: - _init_map[index] = std::bind(permuteInit<float>, _1, _2, _current_op_seq_layout); - break; - case DataType::INT32: - _init_map[index] = std::bind(permuteInit<int32_t>, _1, _2, _current_op_seq_layout); - break; - case DataType::UINT32: - _init_map[index] = std::bind(permuteInit<uint32_t>, _1, _2, _current_op_seq_layout); - break; - case DataType::BOOL8: - case DataType::QUANT_UINT8_ASYMM: - _init_map[index] = 
std::bind(permuteInit<uint8_t>, _1, _2, _current_op_seq_layout); - break; - case DataType::QUANT_INT8_SYMM: - _init_map[index] = std::bind(permuteInit<int8_t>, _1, _2, _current_op_seq_layout); - break; - case DataType::FLOAT16: - _init_map[index] = std::bind(permuteInit<float16>, _1, _2, _current_op_seq_layout); - break; - case DataType::INT64: - _init_map[index] = std::bind(permuteInit<int64_t>, _1, _2, _current_op_seq_layout); - break; - default: - throw std::runtime_error("Not supported, yet"); - break; - } -} - -} // namespace backend -} // namespace onert diff --git a/runtime/onert/core/src/backend/cpu_common/Tensor.cc b/runtime/onert/core/src/backend/IPortableTensor.cc index f34564dd9..066ba0004 100644 --- a/runtime/onert/core/src/backend/cpu_common/Tensor.cc +++ b/runtime/onert/core/src/backend/IPortableTensor.cc @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved + * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,30 +14,31 @@ * limitations under the License. */ -#include "backend/cpu_common/Tensor.h" +#include "backend/IPortableTensor.h" namespace onert { namespace backend { -namespace cpu_common -{ -size_t Tensor::calcOffset(const ir::Coordinates &coords) const +// `dynamic_cast` not working across library boundaries on NDK +// With this as a key function, `dynamic_cast` works across dl +IPortableTensor::~IPortableTensor() {} + +size_t IPortableTensor::calcOffset(const ir::Coordinates &coords) const { - size_t rank = num_dimensions(); + auto shape = _info.shape(); + size_t rank = shape.rank(); rank = rank == 0 ? 1 : rank; size_t offset = 0; for (size_t i = 0; i < rank; ++i) { - offset = offset * dimension(i) + coords[i]; + auto dim = shape.rank() == 0 ? 
1 : shape.dim(i); + offset = offset * dim + coords[i]; } offset *= sizeOfDataType(data_type()); return offset; } -void Tensor::setShape(const ir::Shape &new_shape) { _info.shape(new_shape); } - -} // namespace cpu_common } // namespace backend } // namespace onert diff --git a/runtime/onert/core/src/backend/ITensor.cc b/runtime/onert/core/src/backend/ITensor.cc index 7127ed93d..1339cb409 100644 --- a/runtime/onert/core/src/backend/ITensor.cc +++ b/runtime/onert/core/src/backend/ITensor.cc @@ -21,14 +21,9 @@ namespace onert namespace backend { -ir::Shape ITensor::getShape() const -{ - onert::ir::Shape shape(num_dimensions()); - for (uint32_t d = 0; d < num_dimensions(); d++) - shape.dim(d) = dimension(d); - - return shape; -} +// `dynamic_cast` not working across library boundaries on NDK +// With this as a key function, `dynamic_cast` works across dl +ITensor::~ITensor() {} } // namespace backend } // namespace onert diff --git a/runtime/onert/core/src/backend/cpu_common/Allocator.cc b/runtime/onert/core/src/backend/basic/Allocator.cc index 0ba444ee6..61214dfad 100644 --- a/runtime/onert/core/src/backend/cpu_common/Allocator.cc +++ b/runtime/onert/core/src/backend/basic/Allocator.cc @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -#include "backend/cpu_common/Allocator.h" +#include "backend/basic/Allocator.h" #include "util/logging.h" @@ -22,7 +22,7 @@ namespace onert { namespace backend { -namespace cpu_common +namespace basic { Allocator::Allocator(uint32_t capacity) @@ -33,6 +33,6 @@ Allocator::Allocator(uint32_t capacity) VERBOSE(ALLOC) << "base pointer: " << static_cast<void *>(_base.get()) << std::endl; } -} // namespace cpu_common +} // namespace basic } // namespace backend } // namespace onert diff --git a/runtime/onert/core/src/backend/basic/BackendContextHelpers.cc b/runtime/onert/core/src/backend/basic/BackendContextHelpers.cc new file mode 100644 index 000000000..c02cc0cf2 --- /dev/null +++ b/runtime/onert/core/src/backend/basic/BackendContextHelpers.cc @@ -0,0 +1,17 @@ +/* + * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/basic/BackendContextHelpers.h" diff --git a/runtime/onert/core/src/backend/basic/DynamicTensorManager.cc b/runtime/onert/core/src/backend/basic/DynamicTensorManager.cc new file mode 100644 index 000000000..07bcb09ee --- /dev/null +++ b/runtime/onert/core/src/backend/basic/DynamicTensorManager.cc @@ -0,0 +1,53 @@ +/* + * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/basic/DynamicTensorManager.h" + +#include "util/logging.h" +#include "misc/polymorphic_downcast.h" + +namespace onert +{ +namespace backend +{ +namespace basic +{ + +DynamicTensorManager::DynamicTensorManager(const std::shared_ptr<TensorRegistry> ®) + : _dynamic_mem_mgr{new DynamicMemoryManager()}, _tensors{reg} +{ + // DO NOTHING +} + +void DynamicTensorManager::buildTensor(const ir::OperandIndex &ind, + const ir::OperandInfo &tensor_info, + ir::Layout backend_layout) +{ + assert(_tensors->getNativeTensor(ind) == nullptr); + auto tensor = std::make_unique<Tensor>(tensor_info, backend_layout, _dynamic_mem_mgr.get()); + _tensors->setNativeTensor(ind, std::move(tensor)); +} + +const ITensor *DynamicTensorManager::getRawITensor(ir::OperandIndex ind) +{ + auto ptr = _tensors->getITensor(ind); + assert(ptr); + return ptr; +} + +} // namespace basic +} // namespace backend +} // namespace onert diff --git a/runtime/onert/core/src/backend/cpu_common/MemoryManager.cc b/runtime/onert/core/src/backend/basic/MemoryManager.cc index 8cb9c22ca..48144561b 100644 --- a/runtime/onert/core/src/backend/cpu_common/MemoryManager.cc +++ b/runtime/onert/core/src/backend/basic/MemoryManager.cc @@ -14,18 +14,19 @@ * limitations under the License. 
*/ -#include <backend/cpu_common/MemoryManager.h> +#include <backend/basic/MemoryManager.h> #include <cassert> #include "MemoryPlannerFactory.h" #include "util/ConfigSource.h" +#include "util/logging.h" namespace onert { namespace backend { -namespace cpu_common +namespace basic { MemoryManager::MemoryManager() : _mem_planner{createMemoryPlanner()} @@ -34,20 +35,21 @@ MemoryManager::MemoryManager() : _mem_planner{createMemoryPlanner()} } MemoryManager::MemoryManager(const std::string planner_id) - : _mem_planner{createMemoryPlanner(planner_id)} + : _mem_planner{createMemoryPlanner(planner_id)} { // DO NOTHING } -cpu_common::IMemoryPlanner *MemoryManager::createMemoryPlanner() +basic::IMemoryPlanner<ir::OperandIndex> *MemoryManager::createMemoryPlanner() { auto planner_id = util::getConfigString(util::config::CPU_MEMORY_PLANNER); - return cpu_common::MemoryPlannerFactory::get().create(planner_id); + return basic::MemoryPlannerFactory::get().create(planner_id); } -cpu_common::IMemoryPlanner *MemoryManager::createMemoryPlanner(const std::string planner_id) +basic::IMemoryPlanner<ir::OperandIndex> * +MemoryManager::createMemoryPlanner(const std::string planner_id) { - return cpu_common::MemoryPlannerFactory::get().create(planner_id); + return basic::MemoryPlannerFactory::get().create(planner_id); } void MemoryManager::claimPlan(const ir::OperandIndex &ind, uint32_t size) @@ -59,7 +61,7 @@ void MemoryManager::releasePlan(const ir::OperandIndex &ind) { _mem_planner->rel void MemoryManager::allocate(void) { - _mem_alloc = std::make_shared<cpu_common::Allocator>(_mem_planner->capacity()); + _mem_alloc = std::make_shared<basic::Allocator>(_mem_planner->capacity()); assert(_mem_alloc->base()); } @@ -70,20 +72,20 @@ uint8_t *MemoryManager::getBuffer(const ir::OperandIndex &ind) const return _mem_alloc->base() + mem_blk.offset; } -std::shared_ptr<cpu_common::Allocator> DynamicMemoryManager::allocate(const ir::OperandIndex &ind, - uint32_t capacity) 
+std::shared_ptr<basic::Allocator> DynamicMemoryManager::allocate(const ITensor *tensor, + uint32_t capacity) { - auto find = _mem_alloc_map.find(ind); + auto find = _mem_alloc_map.find(tensor); if (find != _mem_alloc_map.end()) throw std::runtime_error("Cannot allocate memory for a tensor. It was already allocated."); - _mem_alloc_map[ind] = std::make_shared<cpu_common::Allocator>(capacity); - return _mem_alloc_map[ind]; + _mem_alloc_map[tensor] = std::make_shared<basic::Allocator>(capacity); + return _mem_alloc_map[tensor]; } -void DynamicMemoryManager::deallocate(const ir::OperandIndex &ind) +void DynamicMemoryManager::deallocate(const ITensor *tensor) { - auto find = _mem_alloc_map.find(ind); + auto find = _mem_alloc_map.find(tensor); if (find == _mem_alloc_map.end()) throw std::runtime_error("Cannot find Allocator for the requested index"); @@ -93,7 +95,7 @@ void DynamicMemoryManager::deallocate(const ir::OperandIndex &ind) void DynamicMemoryManager::deallocate(void) { - for (auto &mem_alloc : _mem_alloc_map) + for (auto &&mem_alloc : _mem_alloc_map) { // Release memory buffer of mem_alloc mem_alloc.second->release(); @@ -102,6 +104,6 @@ void DynamicMemoryManager::deallocate(void) _mem_alloc_map.clear(); } -} // namespace cpu_common +} // namespace basic } // namespace backend } // namespace onert diff --git a/runtime/onert/core/src/backend/cpu_common/MemoryPlanner.cc b/runtime/onert/core/src/backend/basic/MemoryPlanner.cc index 75c2da7d2..1c048043c 100644 --- a/runtime/onert/core/src/backend/cpu_common/MemoryPlanner.cc +++ b/runtime/onert/core/src/backend/basic/MemoryPlanner.cc @@ -22,24 +22,21 @@ namespace onert { namespace backend { -namespace cpu_common +namespace basic { void BumpPlanner::claim(const ir::OperandIndex &ind, size_t size) { - assert(size != 0); - Block blk{_capacity, size}; _mem_plans[ind] = blk; _capacity += size; - VERBOSE(BP_PLANNER) << "CLAIM(#" << ind.value() << "): " << blk.offset << ", " << blk.size - << std::endl; + 
VERBOSE(BP_PLANNER) << "CLAIM(" << ind << "): " << blk.offset << ", " << blk.size << std::endl; } void BumpPlanner::release(const ir::OperandIndex &ind) { - VERBOSE(BP_PLANNER) << "RELEASE(#" << ind.value() << "): " + VERBOSE(BP_PLANNER) << "RELEASE(" << ind << "): " << "NOTHING does" << std::endl; } @@ -59,11 +56,9 @@ void BumpPlanner::release(const ir::OperandIndex &ind) // the previous claim_base_offset. void FirstFitPlanner::claim(const ir::OperandIndex &ind, size_t size) { - assert(size != 0); - // Find the right position for claiming uint32_t next_offset = 0; - for (auto &mem_claim : _claim_table) + for (const auto &mem_claim : _claim_table) { auto claimed_base_offset = mem_claim.first; auto claimed_size = _mem_plans[mem_claim.second].size; @@ -81,7 +76,7 @@ void FirstFitPlanner::claim(const ir::OperandIndex &ind, size_t size) _claim_table[next_offset] = ind; _mem_plans[ind] = {next_offset, size}; - VERBOSE(FF_PLANNER) << "claim(#" << ind.value() << "): [+" << next_offset << ", " << size << "sz]" + VERBOSE(FF_PLANNER) << "claim(" << ind << "): [+" << next_offset << ", " << size << "sz]" << std::endl; if (_capacity < next_offset + size) @@ -102,7 +97,7 @@ void FirstFitPlanner::release(const ir::OperandIndex &ind) _claim_table.erase(it); - VERBOSE(FF_PLANNER) << "release(#" << index << "): [+" << offset << ", " << size << "sz]" + VERBOSE(FF_PLANNER) << "release(" << index << "): [+" << offset << ", " << size << "sz]" << std::endl; return; } @@ -111,16 +106,14 @@ void FirstFitPlanner::release(const ir::OperandIndex &ind) } WICPlanner::WICPlanner() - : _initialized(false), _capacity(0), _mem_plans(), _live_operands(), _interference_graph(), - _operands() + : _initialized(false), _capacity(0), _mem_plans(), _live_operands(), _interference_graph(), + _operands() { // DO NOTHING } void WICPlanner::claim(const ir::OperandIndex &ind, size_t size) { - assert(size != 0); - _operands.emplace(size, ind); _interference_graph[ind].insert(_interference_graph[ind].end(), 
_live_operands.cbegin(), _live_operands.cend()); @@ -130,13 +123,13 @@ void WICPlanner::claim(const ir::OperandIndex &ind, size_t size) } _live_operands.emplace(ind); - VERBOSE(WIC_PLANNER) << "claim(#" << ind.value() << "): [" << size << "sz]" << std::endl; + VERBOSE(WIC_PLANNER) << "claim(" << ind << "): [" << size << "sz]" << std::endl; } void WICPlanner::release(const ir::OperandIndex &ind) { _live_operands.erase(ind); - VERBOSE(WIC_PLANNER) << "release(#" << ind.value() << ")" << std::endl; + VERBOSE(WIC_PLANNER) << "release(" << ind << ")" << std::endl; } /* @@ -154,7 +147,7 @@ void WICPlanner::buildMemoryPlans() { uint32_t size = operand.first; const ir::OperandIndex &ind = operand.second; - VERBOSE(WIC_PLANNER) << "build_plan(#" << ind.value() << "): [" << size << "sz]" << std::endl; + VERBOSE(WIC_PLANNER) << "build_plan(" << ind << "): [" << size << "sz]" << std::endl; uint32_t next_offset = 0; if (_interference_graph.count(ind)) @@ -190,8 +183,8 @@ void WICPlanner::buildMemoryPlans() } _mem_plans[ind] = {next_offset, size}; - VERBOSE(WIC_PLANNER) << "alloc(#" << ind.value() << "): [+" << next_offset << ", " << size - << "sz]" << std::endl; + VERBOSE(WIC_PLANNER) << "alloc(" << ind << "): [+" << next_offset << ", " << size << "sz]" + << std::endl; if (_capacity < next_offset + size) { @@ -210,6 +203,6 @@ WICPlanner::MemoryPlans &WICPlanner::memory_plans() return _mem_plans; } -} // namespace cpu_common +} // namespace basic } // namespace backend } // namespace onert diff --git a/runtime/onert/core/src/backend/cpu_common/MemoryPlanner.h b/runtime/onert/core/src/backend/basic/MemoryPlanner.h index 7c387e542..03e977500 100644 --- a/runtime/onert/core/src/backend/cpu_common/MemoryPlanner.h +++ b/runtime/onert/core/src/backend/basic/MemoryPlanner.h @@ -19,29 +19,29 @@ * @brief      This file contains Memory Planning related classes */ -#ifndef __ONERT_BACKEND_CPU_COMMON_MEMORY_PLANNER_H__ -#define __ONERT_BACKEND_CPU_COMMON_MEMORY_PLANNER_H__ +#ifndef 
__ONERT_BACKEND_BASIC_MEMORY_PLANNER_H__ +#define __ONERT_BACKEND_BASIC_MEMORY_PLANNER_H__ #include <map> #include <vector> #include <unordered_set> #include <memory> -#include "backend/cpu_common/Allocator.h" -#include "backend/cpu_common/IMemoryPlanner.h" +#include "backend/basic/Allocator.h" +#include "backend/basic/IMemoryPlanner.h" #include "ir/OperandIndexMap.h" namespace onert { namespace backend { -namespace cpu_common +namespace basic { /** * @brief Class to plan memory by bump way */ -class BumpPlanner : public IMemoryPlanner +class BumpPlanner : public IMemoryPlanner<ir::OperandIndex> { public: /** @@ -74,7 +74,7 @@ private: /** * @brief Class to plan memory by firstfit way */ -class FirstFitPlanner : public IMemoryPlanner +class FirstFitPlanner : public IMemoryPlanner<ir::OperandIndex> { public: /** @@ -109,7 +109,7 @@ private: /** * @brief Class to plan memory by Weighted Interval Color algorithm */ -class WICPlanner : public IMemoryPlanner +class WICPlanner : public IMemoryPlanner<ir::OperandIndex> { public: WICPlanner(); @@ -153,8 +153,8 @@ private: std::multimap<uint32_t, ir::OperandIndex, std::greater<uint32_t>> _operands; }; -} // namespace cpu_common +} // namespace basic } // namespace backend } // namespace onert -#endif // __ONERT_BACKEND_CPU_COMMON_MEMORY_PLANNER_H__ +#endif // __ONERT_BACKEND_BASIC_MEMORY_PLANNER_H__ diff --git a/runtime/onert/core/src/backend/cpu_common/MemoryPlanner.test.cc b/runtime/onert/core/src/backend/basic/MemoryPlanner.test.cc index 5208a94d4..a32228cbe 100644 --- a/runtime/onert/core/src/backend/cpu_common/MemoryPlanner.test.cc +++ b/runtime/onert/core/src/backend/basic/MemoryPlanner.test.cc @@ -21,13 +21,13 @@ TEST(Allocator, allocate_test) { - ::onert::backend::cpu_common::Allocator allocator(1024); + ::onert::backend::basic::Allocator allocator(1024); ASSERT_NE(allocator.base(), nullptr); } TEST(BumpPlanner, claim_test) { - ::onert::backend::cpu_common::BumpPlanner planner; + ::onert::backend::basic::BumpPlanner 
planner; auto claim = [&planner](uint32_t index, size_t size, uint32_t expected_offset) { onert::ir::OperandIndex mem_idx(index); @@ -44,7 +44,7 @@ TEST(BumpPlanner, claim_test) TEST(FirstFitPlanner, claim_release_test) { - ::onert::backend::cpu_common::FirstFitPlanner planner; + ::onert::backend::basic::FirstFitPlanner planner; auto claim = [&planner](uint32_t index, size_t size, uint32_t expected_offset) { onert::ir::OperandIndex mem_idx(index); @@ -128,7 +128,7 @@ TEST(FirstFitPlanner, claim_release_test) TEST(WICPlanner, claim_release_test) { - ::onert::backend::cpu_common::WICPlanner planner; + ::onert::backend::basic::WICPlanner planner; auto claim = [&planner](uint32_t index, size_t size) { onert::ir::OperandIndex mem_idx(index); diff --git a/runtime/onert/core/src/backend/cpu_common/MemoryPlannerFactory.cc b/runtime/onert/core/src/backend/basic/MemoryPlannerFactory.cc index ead4f3294..7338f87b6 100644 --- a/runtime/onert/core/src/backend/cpu_common/MemoryPlannerFactory.cc +++ b/runtime/onert/core/src/backend/basic/MemoryPlannerFactory.cc @@ -22,7 +22,7 @@ namespace onert { namespace backend { -namespace cpu_common +namespace basic { MemoryPlannerFactory &MemoryPlannerFactory::get() @@ -31,7 +31,7 @@ MemoryPlannerFactory &MemoryPlannerFactory::get() return instance; } -IMemoryPlanner *MemoryPlannerFactory::create(const std::string &key) +IMemoryPlanner<ir::OperandIndex> *MemoryPlannerFactory::create(const std::string &key) { if (key == "FirstFit") { @@ -48,6 +48,6 @@ IMemoryPlanner *MemoryPlannerFactory::create(const std::string &key) return new FirstFitPlanner; // Default Planner } -} // namespace cpu_common +} // namespace basic } // namespace backend } // namespace onert diff --git a/runtime/onert/core/src/backend/cpu_common/MemoryPlannerFactory.h b/runtime/onert/core/src/backend/basic/MemoryPlannerFactory.h index d14ec13ca..b4173f749 100644 --- a/runtime/onert/core/src/backend/cpu_common/MemoryPlannerFactory.h +++ 
b/runtime/onert/core/src/backend/basic/MemoryPlannerFactory.h @@ -14,10 +14,11 @@ * limitations under the License. */ -#ifndef __ONERT_BACKEND_CPU_COMMON_MEMORY_PLANNER_FACTORY_H__ -#define __ONERT_BACKEND_CPU_COMMON_MEMORY_PLANNER_FACTORY_H__ +#ifndef __ONERT_BACKEND_BASIC_MEMORY_PLANNER_FACTORY_H__ +#define __ONERT_BACKEND_BASIC_MEMORY_PLANNER_FACTORY_H__ -#include "backend/cpu_common/IMemoryPlanner.h" +#include "backend/basic/IMemoryPlanner.h" +#include "MemoryPlanner.h" #include <string> @@ -25,7 +26,7 @@ namespace onert { namespace backend { -namespace cpu_common +namespace basic { class MemoryPlannerFactory @@ -37,11 +38,11 @@ private: MemoryPlannerFactory() = default; public: - IMemoryPlanner *create(const std::string &key); + IMemoryPlanner<ir::OperandIndex> *create(const std::string &key); }; -} // namespace cpu_common +} // namespace basic } // namespace backend } // namespace onert -#endif // __ONERT_BACKEND_CPU_COMMON_MEMORY_PLANNER_FACTORY_H__ +#endif // __ONERT_BACKEND_BASIC_MEMORY_PLANNER_FACTORY_H__ diff --git a/runtime/onert/core/src/backend/cpu_common/StaticTensorManager.cc b/runtime/onert/core/src/backend/basic/StaticTensorManager.cc index 440f70c93..04dbc4a6b 100644 --- a/runtime/onert/core/src/backend/cpu_common/StaticTensorManager.cc +++ b/runtime/onert/core/src/backend/basic/StaticTensorManager.cc @@ -14,65 +14,55 @@ * limitations under the License. 
*/ -#include "backend/cpu_common/StaticTensorManager.h" +#include "backend/basic/StaticTensorManager.h" -#include "backend/cpu_common/DynamicTensorManager.h" +#include "backend/basic/DynamicTensorManager.h" +#include "backend/basic/Tensor.h" #include <util/logging.h> namespace onert { namespace backend { -namespace cpu_common +namespace basic { StaticTensorManager::StaticTensorManager(const std::shared_ptr<TensorRegistry> ®, - IDynamicTensorManager *dynamic_tensor_manager) - : _const_mgr{new DynamicMemoryManager()}, _nonconst_mgr{new MemoryManager()}, _tensors{reg}, - _dynamic_tensor_manager{dynamic_tensor_manager} + DynamicTensorManager *dynamic_tensor_manager) + : _nonconst_mgr{new MemoryManager()}, _tensors{reg}, + _dynamic_tensor_manager{dynamic_tensor_manager} { // DO NOTHING } -void StaticTensorManager::allocateConsts(void) +StaticTensorManager::StaticTensorManager(const std::shared_ptr<TensorRegistry> ®, + const std::string planner_id, + DynamicTensorManager *dynamic_tensor_manager) + : _nonconst_mgr{new MemoryManager(planner_id)}, _tensors{reg}, + _dynamic_tensor_manager{dynamic_tensor_manager} { - for (auto &pair : _tensors->native_tensors()) - { - const auto &ind = pair.first; - auto tensor = pair.second; - if (_as_constants[ind]) - { - auto mem_alloc = _const_mgr->allocate(ind, tensor->total_size()); - tensor->setBuffer(mem_alloc); - auto buffer = mem_alloc->base(); - VERBOSE(CPU_COMMON_StaticTensorManager) << "CONSTANT TENSOR(#" << ind.value() - << "): " << static_cast<void *>(buffer) - << "size : " << tensor->total_size() << std::endl; - } - } + // DO NOTHING } void StaticTensorManager::allocateNonconsts(void) { _nonconst_mgr->allocate(); - for (auto &pair : _tensors->native_tensors()) + for (auto &&pair : _tensors->native_tensors()) { const auto &ind = pair.first; - auto tensor = pair.second; + auto tensor = pair.second.get(); if (!_as_constants[ind] && !tensor->is_dynamic()) { auto *buffer = _nonconst_mgr->getBuffer(ind); tensor->setBuffer(buffer); - 
VERBOSE(CPU_COMMON_StaticTensorManager) << "TENSOR(#" << ind.value() - << "): " << static_cast<void *>(buffer) << std::endl; + VERBOSE(CPU_StaticTensorManager) + << "TENSOR " << ind << " : " << static_cast<void *>(buffer) << std::endl; } } } -void StaticTensorManager::deallocateConsts(void) { _const_mgr->deallocate(); } - void StaticTensorManager::deallocateNonconsts(void) { _nonconst_mgr->deallocate(); } void StaticTensorManager::buildTensor(const ir::OperandIndex &ind, @@ -80,8 +70,17 @@ void StaticTensorManager::buildTensor(const ir::OperandIndex &ind, bool as_const) { assert(!_tensors->getNativeTensor(ind)); - auto tensor = std::make_shared<Tensor>(tensor_info, backend_layout, _dynamic_tensor_manager); - _tensors->setNativeTensor(ind, tensor); + if (as_const) + { + auto tensor = std::make_unique<ExternalTensor>(tensor_info, backend_layout); + _tensors->setNativeTensor(ind, std::move(tensor)); + } + else + { + auto tensor = std::make_unique<Tensor>(tensor_info, backend_layout, + _dynamic_tensor_manager->dynamic_mem_mgr().get()); + _tensors->setNativeTensor(ind, std::move(tensor)); + } _as_constants[ind] = as_const; } @@ -113,6 +112,6 @@ void StaticTensorManager::iterate(const std::function<void(const ir::OperandInde fn(it.first); } -} // namespace cpu_common +} // namespace basic } // namespace backend } // namespace onert diff --git a/runtime/onert/core/src/backend/basic/Tensor.cc b/runtime/onert/core/src/backend/basic/Tensor.cc new file mode 100644 index 000000000..7f33d4d74 --- /dev/null +++ b/runtime/onert/core/src/backend/basic/Tensor.cc @@ -0,0 +1,104 @@ +/* + * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/basic/Tensor.h" + +#include "ir/DataType.h" +#include "backend/basic/MemoryManager.h" + +namespace onert +{ +namespace backend +{ +namespace basic +{ + +Tensor::~Tensor() {} + +void Tensor::setShape(const ir::Shape &new_shape) { _info.shape(new_shape); } + +bool Tensor::applyShape(const ir::Shape &new_shape) +{ + bool previously_dynamic = is_dynamic(); + + auto allocTensorMem = [&]() { + auto capacity = total_size(); + assert(_dynamic_mem_mgr); + auto alloc = _dynamic_mem_mgr->allocate(this, capacity); + setBuffer(alloc); + }; + + if (!previously_dynamic || buffer() == nullptr) + { + // Always set shape - when buffer with same size was already allocated, shape could differ + setShape(new_shape); + set_dynamic(); + allocTensorMem(); + } + else + { + auto previous_size = total_size(); + auto new_size = new_shape.num_elements() * ir::sizeOfDataType(data_type()); + if (previous_size != new_size) + { + assert(_dynamic_mem_mgr); + _dynamic_mem_mgr->deallocate(this); + + setShape(new_shape); + set_dynamic(); + allocTensorMem(); + } + else + { // when buffer with same size was already allocated, shape could differ + setShape(new_shape); + } + } + return true; +} + +void Tensor::deallocBuffer() +{ + if (_allocator) + { + _buffer = nullptr; + _allocator.reset(); + if (_dynamic_mem_mgr) + { + _dynamic_mem_mgr->deallocate(this); + } + } +} + +} // namespace basic +} // namespace backend +} // namespace onert + +// ExternalTensor + +namespace onert +{ +namespace backend +{ +namespace basic +{ + +// `dynamic_cast` not working across 
library boundaries on NDK +// With this as a key function, `dynamic_cast` works across dl +ExternalTensor::~ExternalTensor() {} + +} // namespace basic +} // namespace backend +} // namespace onert diff --git a/runtime/onert/core/src/backend/basic/TensorBuilder.cc b/runtime/onert/core/src/backend/basic/TensorBuilder.cc new file mode 100644 index 000000000..4912af1f5 --- /dev/null +++ b/runtime/onert/core/src/backend/basic/TensorBuilder.cc @@ -0,0 +1,91 @@ +/* + * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include <backend/basic/TensorBuilder.h> + +#include <util/logging.h> + +#include <cassert> + +namespace onert +{ +namespace backend +{ +namespace basic +{ + +TensorBuilder::TensorBuilder(const std::shared_ptr<TensorRegistry> &tensor_reg) + : _tensor_reg{tensor_reg}, _dynamic_tensor_mgr{new DynamicTensorManager(_tensor_reg)}, + _static_tensor_mgr{new StaticTensorManager(_tensor_reg, _dynamic_tensor_mgr.get())} +{ + /* empty */ +} + +TensorBuilder::TensorBuilder(const std::shared_ptr<TensorRegistry> &tensor_reg, + const std::string planner_id) + : _tensor_reg{tensor_reg}, _dynamic_tensor_mgr{new DynamicTensorManager(_tensor_reg)}, + _static_tensor_mgr{new StaticTensorManager(_tensor_reg, planner_id, _dynamic_tensor_mgr.get())} +{ + /* empty */ +} + +void TensorBuilder::registerTensorInfo(const ir::OperandIndex &ind, const ir::OperandInfo &info, + ir::Layout layout) +{ + _tensor_info_map.emplace(ind, info); + + // CPU backend supports only one layout as NHWC + assert(layout == ir::Layout::NHWC); + if (info.isDynamic()) + { + _dynamic_tensor_mgr->buildTensor(ind, info, layout); + } + else + { + _static_tensor_mgr->buildTensor(ind, info, layout, info.isConstant()); + } +} + +void TensorBuilder::notifyFirstUse(const ir::OperandIndex &ind) +{ + assert(_tensor_info_map.find(ind) != _tensor_info_map.end()); + const auto &tensor_info = _tensor_info_map.at(ind); + + if (!_tensor_reg->getNativeTensor(ind)->is_dynamic()) + { + const auto size = tensor_info.total_size(); + _static_tensor_mgr->claimPlan(ind, size); + } +} + +void TensorBuilder::notifyLastUse(const ir::OperandIndex &ind) +{ + if (!_tensor_reg->getNativeTensor(ind)->is_dynamic()) + { + _static_tensor_mgr->releasePlan(ind); + } +} + +bool TensorBuilder::isRegistered(const ir::OperandIndex &ind) const +{ + return _tensor_info_map.find(ind) != _tensor_info_map.end(); +} + +void TensorBuilder::allocate(void) { _static_tensor_mgr->allocateNonconsts(); } + +} // namespace basic +} // namespace backend +} // 
namespace onert diff --git a/runtime/onert/core/src/backend/basic/train/TrainableTensor.cc b/runtime/onert/core/src/backend/basic/train/TrainableTensor.cc new file mode 100644 index 000000000..d09604224 --- /dev/null +++ b/runtime/onert/core/src/backend/basic/train/TrainableTensor.cc @@ -0,0 +1,49 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <backend/basic/train/TrainableTensor.h> + +namespace onert +{ +namespace backend +{ +namespace basic +{ +namespace train +{ + +std::vector<ITensor *> TrainableTensor::optVars() +{ + std::vector<ITensor *> ret; + for (auto &&e : _opt_vars) + { + ret.emplace_back(e.get()); + } + return ret; +} + +void TrainableTensor::fillBuffer(const std::shared_ptr<ir::Data> &data) +{ + auto *buffer = _tensor.buffer(); + assert(buffer); + assert(total_size() == data->size()); + std::memcpy(buffer, data->base(), data->size()); +} + +} // namespace train +} // namespace basic +} // namespace backend +} // namespace onert diff --git a/runtime/onert/core/src/backend/controlflow/Backend.h b/runtime/onert/core/src/backend/builtin/Backend.h index 670f7750f..85d389505 100644 --- a/runtime/onert/core/src/backend/controlflow/Backend.h +++ b/runtime/onert/core/src/backend/builtin/Backend.h @@ -14,16 +14,20 @@ * limitations under the License. 
 */ -#ifndef __ONERT_BACKEND_CONTROLFLOW_BACKEND_H__ -#define __ONERT_BACKEND_CONTROLFLOW_BACKEND_H__ +#ifndef __ONERT_BACKEND_BUILTIN_BACKEND_H__ +#define __ONERT_BACKEND_BUILTIN_BACKEND_H__ +#include "BackendContext.h" #include "Config.h" -#include "ConstantInitializer.h" #include "KernelGenerator.h" #include "TensorBuilder.h" #include "Tensor.h" +#include "train/BackendContext.h" +#include "train/KernelGenerator.h" +#include "train/TensorRegistry.h" #include <backend/Backend.h> +#include <backend/train/ITrainableBackend.h> #include <memory> @@ -31,22 +35,19 @@ namespace onert { namespace backend { -namespace controlflow +namespace builtin { -class Backend : public ::onert::backend::Backend +class Backend : public ::onert::backend::Backend, public backend::train::ITrainableBackend { public: Backend() : _config{std::make_shared<Config>()} {} std::shared_ptr<IConfig> config() const override { return _config; } - std::unique_ptr<BackendContext> newContext(const ir::Graph &graph, - const std::shared_ptr<custom::IKernelBuilder> &, - bool) const override + std::unique_ptr<onert::backend::BackendContext> newContext(ContextData &&data) const override { - const auto &operands = graph.operands(); - auto context = std::make_unique<BackendContext>(this, &graph); + auto context = std::make_unique<BackendContext>(this, std::move(data)); // ControlFlow backend may not build tensors for itself because the backend's operation uses // tensors of other backends instead // But the backend builds tensors in case of that the controlflow operation may have constant @@ -68,10 +69,22 @@ public: auto tb = std::make_shared<TensorBuilder>(tr); context->tensor_registry = tr; context->tensor_builder = tb; - context->constant_initializer = std::make_shared<ConstantInitializer>(operands, tr); - context->kernel_gen = std::make_shared<KernelGenerator>(graph, tb->dynamicTensorManager(), tr); - context->tensor_register = nullptr; - context->optimizer = nullptr; + context->kernel_gen = 
std::make_shared<KernelGenerator>( + *context->graph(), tb->dynamicTensorManager(), tr, context->external_context()); + return context; + } + + std::unique_ptr<backend::train::TrainableBackendContext> + newContext(backend::train::TrainableContextData &&tdata) const override + { + const auto &tgraph = *tdata.tgraph; + auto tr = std::make_shared<train::TensorRegistry>(); + // TODO Create TensorBuilder if necessary + auto tdata_ptr = std::make_unique<backend::train::TrainableContextData>(std::move(tdata)); + auto context = std::make_unique<train::BackendContext>(this, std::move(tdata_ptr), tr); + + context->kernel_gen = + std::make_shared<train::KernelGenerator>(tgraph, tr, context->external_context()); return context; } @@ -79,8 +92,8 @@ private: std::shared_ptr<IConfig> _config; }; -} // namespace controlflow +} // namespace builtin } // namespace backend } // namespace onert -#endif // __ONERT_BACKEND_CONTROLFLOW_BACKEND_H__ +#endif // __ONERT_BACKEND_BUILTIN_BACKEND_H__ diff --git a/runtime/onert/core/src/backend/builtin/BackendContext.cc b/runtime/onert/core/src/backend/builtin/BackendContext.cc new file mode 100644 index 000000000..a66e97b6e --- /dev/null +++ b/runtime/onert/core/src/backend/builtin/BackendContext.cc @@ -0,0 +1,58 @@ +/* + * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "BackendContext.h" + +#include "KernelGenerator.h" +#include "backend/basic/BackendContextHelpers.h" + +namespace onert +{ +namespace backend +{ +namespace builtin +{ + +ITensorRegistry *BackendContext::genTensors() { return basic::genTensors(*this); } + +FunctionMap BackendContext::genKernels() +{ + FunctionMap ret; + + for (auto &&op_ind : _data.op_order) + { + auto fn_seq = kernel_gen->generate(op_ind); + ret.emplace(op_ind, std::move(fn_seq)); + } + + basic::initConsts(*this); + + // NOTE For memory optimization, we want to free some operand data + const_cast<ir::Graph *>(graph())->operands().iterate( + [&](const ir::OperandIndex &, ir::Operand &obj) { obj.releaseData(); }); + + for (auto &&it : ret) + { + auto &fn_seq = it.second; + fn_seq->iterate([&](exec::IFunction &ifunc) { ifunc.prepare(); }); + } + + return ret; +} + +} // namespace builtin +} // namespace backend +} // namespace onert diff --git a/runtime/onert/core/src/backend/builtin/BackendContext.h b/runtime/onert/core/src/backend/builtin/BackendContext.h new file mode 100644 index 000000000..93e825239 --- /dev/null +++ b/runtime/onert/core/src/backend/builtin/BackendContext.h @@ -0,0 +1,71 @@ +/* + * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef __ONERT_BACKEND_BUILTIN_BACKEND_CONTEXT_H__ +#define __ONERT_BACKEND_BUILTIN_BACKEND_CONTEXT_H__ + +#include <backend/BackendContext.h> +#include "TensorBuilder.h" +#include "KernelGenerator.h" +#include "ExternalContext.h" + +namespace onert +{ +namespace backend +{ +namespace builtin +{ + +class BackendContext : public onert::backend::BackendContext +{ +public: + BackendContext(const Backend *backend, ContextData &&data, + std::shared_ptr<ITensorRegistry> tensor_registry = nullptr, + std::shared_ptr<TensorBuilder> tensor_builder = nullptr, + std::shared_ptr<KernelGenerator> kernel_gen = nullptr) + : onert::backend::BackendContext(backend, std::move(data), tensor_registry), + tensor_builder{tensor_builder}, kernel_gen{kernel_gen}, + _external_context(std::make_shared<ExternalContext>()) + { + } + + ITensorRegistry *genTensors() override; + + FunctionMap genKernels() override; + + std::shared_ptr<ExternalContext> external_context() { return _external_context; } + +private: + void planTensors(const std::vector<onert::ir::OperationIndex> &order, + const compiler::GraphLowerInfo &lower_info); + +public: + // TODO Make it private + std::shared_ptr<TensorBuilder> tensor_builder; + std::shared_ptr<KernelGenerator> kernel_gen; + +private: + // NOTE ruy context has a thread pool, and when multiple ruy contexts are created, + // the thread pool is also created in duplicate + // TODO Create one ruy context for session + std::shared_ptr<ExternalContext> _external_context; +}; + +} // namespace builtin +} // namespace backend +} // namespace onert + +#endif // __ONERT_BACKEND_BUILTIN_BACKEND_CONTEXT_H__ diff --git a/runtime/onert/core/src/backend/controlflow/Config.cc b/runtime/onert/core/src/backend/builtin/Config.cc index 5ec01fe11..e5f6d4c21 100644 --- a/runtime/onert/core/src/backend/controlflow/Config.cc +++ b/runtime/onert/core/src/backend/builtin/Config.cc @@ -20,18 +20,18 @@ namespace onert { namespace backend { -namespace controlflow +namespace builtin 
{ -std::string Config::ID = "controlflow"; +std::string Config::ID = "builtin"; bool Config::initialize() { return true; } -ir::Layout Config::supportLayout(const ir::Operation &, ir::Layout frontend_layout) +ir::Layout Config::supportLayout(const ir::IOperation &, ir::Layout frontend_layout) { return frontend_layout; } -} // namespace controlflow +} // namespace builtin } // namespace backend } // namespace onert diff --git a/runtime/onert/core/src/backend/controlflow/Config.h b/runtime/onert/core/src/backend/builtin/Config.h index 6645ed59d..196b299d3 100644 --- a/runtime/onert/core/src/backend/controlflow/Config.h +++ b/runtime/onert/core/src/backend/builtin/Config.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef __ONERT_BACKEND_CONTROLFLOW_CONFIG_H__ -#define __ONERT_BACKEND_CONTROLFLOW_CONFIG_H__ +#ifndef __ONERT_BACKEND_BUILTIN_CONFIG_H__ +#define __ONERT_BACKEND_BUILTIN_CONFIG_H__ #include <backend/IConfig.h> #include <memory> @@ -25,7 +25,7 @@ namespace onert { namespace backend { -namespace controlflow +namespace builtin { class Config : public IConfig @@ -34,7 +34,7 @@ public: static std::string ID; std::string id() override { return ID; } bool initialize() override; - ir::Layout supportLayout(const ir::Operation &node, ir::Layout frontend_layout) override; + ir::Layout supportLayout(const ir::IOperation &node, ir::Layout frontend_layout) override; bool supportPermutation() override { return false; } bool supportDynamicTensor() override { @@ -46,8 +46,8 @@ public: std::unique_ptr<util::ITimer> timer() override { return std::make_unique<util::CPUTimer>(); } }; -} // namespace controlflow +} // namespace builtin } // namespace backend } // namespace onert -#endif // __ONERT_BACKEND_CONTROLFLOW_CONFIG_H__ +#endif // __ONERT_BACKEND_BUILTIN_CONFIG_H__ diff --git a/runtime/onert/core/src/backend/controlflow/UserTensor.cc b/runtime/onert/core/src/backend/builtin/ConstantInitializer.h index c8e2ebade..6b8eb3e9d 100644 --- 
a/runtime/onert/core/src/backend/controlflow/UserTensor.cc +++ b/runtime/onert/core/src/backend/builtin/ConstantInitializer.h @@ -14,27 +14,22 @@ * limitations under the License. */ -#include "UserTensor.h" +#ifndef __ONERT_COMPILER_BUILTIN_CONSTANT_INITIALIZER_H__ +#define __ONERT_COMPILER_BUILTIN_CONSTANT_INITIALIZER_H__ + +#include <backend/basic/ConstantInitializer.h> namespace onert { namespace backend { -namespace controlflow +namespace builtin { -size_t UserTensor::calcOffset(const ir::Coordinates &coords) const -{ - size_t rank = num_dimensions(); - size_t offset = 0; - for (size_t i = 0; i < rank; ++i) - { - offset = offset * dimension(i) + coords[i]; - } - offset *= sizeOfDataType(data_type()); - return offset; -} +using ConstantInitializer = basic::ConstantInitializer; -} // namespace controlflow +} // namespace builtin } // namespace backend } // namespace onert + +#endif // __ONERT_COMPILER_BUILTIN_CONSTANT_INITIALIZER_H__ diff --git a/runtime/onert/core/src/backend/controlflow/UserTensorRegistry.h b/runtime/onert/core/src/backend/builtin/DynamicTensorManager.h index fa2a2d54c..148948a9c 100644 --- a/runtime/onert/core/src/backend/controlflow/UserTensorRegistry.h +++ b/runtime/onert/core/src/backend/builtin/DynamicTensorManager.h @@ -14,23 +14,25 @@ * limitations under the License. 
*/ -#ifndef __ONERT_BACKEND_CONTROLFLOW_USER_TENSOR_REGISTRY__ -#define __ONERT_BACKEND_CONTROLFLOW_USER_TENSOR_REGISTRY__ +#ifndef __ONERT_BACKEND_BUILTIN_DYNAMICTENSOR_MANAGER_H__ +#define __ONERT_BACKEND_BUILTIN_DYNAMICTENSOR_MANAGER_H__ -#include "backend/ITensorRegistry.h" -#include "UserTensor.h" +#include "TensorRegistry.h" +#include "Tensor.h" + +#include <backend/basic/DynamicTensorManager.h> namespace onert { namespace backend { -namespace controlflow +namespace builtin { -using UserTensorRegistry = PortableTensorRegistryTemplate<UserTensor>; +using DynamicTensorManager = basic::DynamicTensorManager; -} // namespace controlflow +} // namespace builtin } // namespace backend } // namespace onert -#endif // __ONERT_BACKEND_CONTROLFLOW_USER_TENSOR_REGISTRY__ +#endif // __ONERT_BACKEND_BUILTIN_DYNAMICTENSOR_MANAGER_H__ diff --git a/runtime/onert/core/src/backend/builtin/ExternalContext.h b/runtime/onert/core/src/backend/builtin/ExternalContext.h new file mode 100644 index 000000000..390dbb579 --- /dev/null +++ b/runtime/onert/core/src/backend/builtin/ExternalContext.h @@ -0,0 +1,79 @@ +/* + * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef __ONERT_BACKEND_BUILTIN_EXTERNAL_CONTEXT_H__ +#define __ONERT_BACKEND_BUILTIN_EXTERNAL_CONTEXT_H__ + +#include <util/ConfigSource.h> + +#include <ruy/context.h> +#include <ruy/context_get_ctx.h> +#include <ruy/ctx.h> +#include <ruy/tune.h> + +#include <memory> + +namespace onert +{ +namespace backend +{ +namespace builtin +{ + +// TODO Unify this with cpu::ExternalContext +class ExternalContext +{ +private: + static const int kDefaultNumThreadpoolThreads = 1; + +public: + ExternalContext() : _ruy_context(std::make_unique<ruy::Context>()) + { + setMaxNumThreads(onert::util::getConfigInt(onert::util::config::RUY_THREADS)); + initPerThreadState(); + } + + void setMaxNumThreads(int max_num_threads) + { + const int target_num_threads = + max_num_threads > -1 ? max_num_threads : kDefaultNumThreadpoolThreads; + _ruy_context->set_max_num_threads(target_num_threads); + } + + ruy::Context *ruy_context() const { return _ruy_context.get(); } + +private: + void initPerThreadState() + { + // Initialize per-thread state. + const int thread_count = _ruy_context->max_num_threads(); + auto ctx = ruy::get_ctx(_ruy_context.get()); + ctx->EnsureThreadSpecificResources(thread_count); + for (int i = 0; i < thread_count; i++) + { + ctx->GetThreadSpecificTuningResolver(i)->SetTuning(ctx->explicit_tuning()); + } + } + +private: + const std::unique_ptr<ruy::Context> _ruy_context; +}; + +} // namespace builtin +} // namespace backend +} // namespace onert + +#endif // __ONERT_BACKEND_BUILTIN_EXTERNAL_CONTEXT_H__ diff --git a/runtime/onert/core/src/backend/builtin/IOTensor.cc b/runtime/onert/core/src/backend/builtin/IOTensor.cc new file mode 100644 index 000000000..e157a12e9 --- /dev/null +++ b/runtime/onert/core/src/backend/builtin/IOTensor.cc @@ -0,0 +1,60 @@ +/* + * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
 + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "IOTensor.h" + +#include <assert.h> + +namespace onert +{ +namespace backend +{ +namespace builtin +{ + +// `dynamic_cast` not working across library boundaries on NDK +// With this as a key function, `dynamic_cast` works across dl +IOTensor::~IOTensor() {} + +IOTensor::IOTensor(const ir::OperandInfo &info, ir::Layout layout) + : IPortableTensor{info}, _tensor{nullptr}, + _orig{std::make_unique<UserTensor>(info, layout, (uint8_t *)nullptr, 0)} +{ + _tensor = _orig.get(); +} + +void IOTensor::setTensor(IPortableTensor *tensor) +{ + assert(tensor); + assert(tensor != this); + assert(tensor->layout() == _orig->layout()); // Changing layout is not considered yet + _tensor = tensor; + if (_info.shape() != tensor->getShape()) + { + _info.shape(tensor->getShape()); + + // If input tensor shape is updated, other effective buffers use dynamic memory manager. + // Dynamic memory manager deallocates allocated memory after each execution. + // So we should keep the input tensor as dynamic if we mark it dynamic at least once. + // If dynamic memory manager maintains allocated memory after execution is finished, + // we may need to reset it as static for each setTensor call. 
+ _info.setDynamic(); + } +} + +} // namespace builtin +} // namespace backend +} // namespace onert diff --git a/runtime/onert/core/src/backend/builtin/IOTensor.h b/runtime/onert/core/src/backend/builtin/IOTensor.h new file mode 100644 index 000000000..3d684e07d --- /dev/null +++ b/runtime/onert/core/src/backend/builtin/IOTensor.h @@ -0,0 +1,114 @@ +/* + * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __ONERT_BACKEND_BUILTIN_IO_TENSOR_H__ +#define __ONERT_BACKEND_BUILTIN_IO_TENSOR_H__ + +#include "backend/IPortableTensor.h" +#include "UserTensor.h" + +namespace onert +{ +namespace backend +{ +namespace builtin +{ + +/** + * @brief Tensor object that indirects to the tensor it is pointing to. + * + * A executor's I/O tensor could be two types. + * + * 1. @c UserTensor, if it is the primary graph (package's input/output) + * 2. Any other derivative of @c IPortableTensor from another executor, otherwise + * + * To support these, this object indirects everything to the actual tensor pointer. + * + * IOTensor is derived from IPortableTensor, and it also have "_info" field. + * "_info" field is accessed by IPortableTensor's getter method. + * + * It assumes that IOTensor's info is always same with actual tensor's info except shape. + * setTensor() updates IOTensor's info's shape to actual tensor shape. 
+ * Actual tensor's info should not be updated directly after setTensor() call until + * executor's execution is finished, instead it is allowed to update actual tensor's info + * indirectly by IOTensor's setter methods. + */ +class IOTensor : public IPortableTensor +{ +public: + IOTensor(const ir::OperandInfo &info, ir::Layout layout); + ~IOTensor(); + +public: + void setTensor(IPortableTensor *tensor); + +public: + uint8_t *buffer() const override { return _tensor->buffer(); } + ir::Layout layout() const override { return _orig->layout(); } + void set_dynamic() override + { + _info.setDynamic(); + _tensor->set_dynamic(); + } + void setShape(const ir::Shape &shape) override + { + _info.shape(shape); + _tensor->setShape(shape); + } + + /* + * Changes tensor shape and allocate memory since its shape was changed + * perhaps by nnfw_set_input_tensorinfo() + * + * Cases are: + * 1) static operand -> nnfw_set_input_tensorinfo() -> execute() -> execute() + * (a) (b) + * + * at (a), operand is static, tensor is static - memory dealloc is not needed + * (DynamicTensorManager cannot dealloc memory allocated by StaticTensorManager) + * at (b), operand is static, tensor is dynamic - memory dealloc is needed + * + * 2) dynamic operand -> nnfw_set_input_tensorinfo() -> execute() -> execute() + * (a) (b) + * + * at (a), operand is dynamic, tensor is dynamic - memory dealloc is not needed + * since it has not been allocated yet + * at (b), operand is dynamic, tensor is dynamic - memory dealloc is needed + */ + bool applyShape(const ir::Shape &shape) override + { + auto return_val = _tensor->applyShape(shape); + if (return_val) + { + _info.shape(shape); + _info.setDynamic(); + } + return return_val; + } + +private: + IPortableTensor *_tensor{nullptr}; //< The actual tensor that is indirected + // "_orig" has UserTensor type original tensor's info with nullptr buffer and layout, + // and "_tensor" points to "_user_tensor". 
+ // After 1st setTensor(tensor) call, "_tensor" is updated to actual tensor + std::unique_ptr<UserTensor> _orig; //< If it is a user tensor, it is managed by this object +}; + +} // namespace builtin +} // namespace backend +} // namespace onert + +#endif // __ONERT_BACKEND_BUILTIN_IO_TENSOR_H__ diff --git a/runtime/onert/core/src/backend/builtin/KernelGenerator.cc b/runtime/onert/core/src/backend/builtin/KernelGenerator.cc new file mode 100644 index 000000000..00c200a92 --- /dev/null +++ b/runtime/onert/core/src/backend/builtin/KernelGenerator.cc @@ -0,0 +1,159 @@ +/* + * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "KernelGenerator.h" + +#include "kernel/IfLayer.h" +#include "kernel/PermuteLayer.h" +#include "kernel/WhileLayer.h" + +#include "exec/FunctionSequence.h" + +namespace onert +{ +namespace backend +{ +namespace builtin +{ + +KernelGenerator::KernelGenerator(const ir::Graph &graph, DynamicTensorManager *dyn_tensor_manager, + const std::shared_ptr<TensorRegistry> &tensor_reg, + const std::shared_ptr<ExternalContext> &external_context) + : basic::KernelGeneratorBase{graph}, _dyn_tensor_manager{dyn_tensor_manager}, + _tensor_reg{tensor_reg}, _tensor_registries{}, _executors{nullptr}, _model_index{}, + _external_context{external_context} +{ + UNUSED_RELEASE(_graph); + UNUSED_RELEASE(_tensor_registries); + UNUSED_RELEASE(_executors); +} + +std::unique_ptr<exec::FunctionSequence> KernelGenerator::generate(ir::OperationIndex ind) +{ + assert(_dyn_tensor_manager); + assert(_tensor_reg); + + auto ret = std::make_unique<exec::FunctionSequence>(); + + // Prepare to handle dynamic tensors later + auto dyn_ctx = std::make_shared<exec::FunctionSequence::DynamicTensorCtx>(); + { + dyn_ctx->op = &_graph.operations().at(ind); + dyn_ctx->dynamic_shape_inferer = + std::make_unique<exec::DynamicShapeInferer>(_graph.operands(), _tensor_reg); + } + ret->dynamic_tensor_ctx(dyn_ctx); + + auto &op = _graph.operations().at(ind); + op.accept(*this); + assert(_return_fn); // _return_fn must have been generated + ret->append(std::move(_return_fn)); + + return ret; +} + +void KernelGenerator::visit(const ir::operation::If &node) +{ + const auto then_subg_index = node.param().then_subg_index; + const auto else_subg_index = node.param().else_subg_index; + + std::vector<backend::IPortableTensor *> input_tensors; + for (const auto &input_index : node.getInputs()) + { + auto input_tensor = getPortableTensor(input_index); + input_tensors.emplace_back(input_tensor); + } + + std::vector<backend::IPortableTensor *> output_tensors; + for (const auto &output_index : node.getOutputs()) + { + 
 auto output_tensor = getPortableTensor(output_index); + output_tensors.emplace_back(output_tensor); + } + + // IfLayer just sets Executors instead of then and else executor to avoid complexity of + // creating executor recursively + const auto cond_tensor = input_tensors.front(); + input_tensors.erase(input_tensors.begin()); + auto fn = std::make_unique<::onert::backend::builtin::kernel::IfLayer>( + cond_tensor, input_tensors, output_tensors, then_subg_index, else_subg_index, _executors, + _model_index, _external_context); + + _return_fn = std::move(fn); +} + +void KernelGenerator::visit(const ir::operation::Permute &node) +{ + const auto output_index{node.getOutputs().at(0)}; + const auto input_index{node.getInputs().at(0)}; + + // Add PermuteLayer + std::vector<ITensor *> output_tensors{getTensor(output_index)}; + std::vector<ITensor *> input_tensors{getTensor(input_index)}; + + auto fn = + std::make_unique<kernel::PermuteLayer>(input_tensors, output_tensors, _external_context); + _return_fn = std::move(fn); +} + +void KernelGenerator::visit(const ir::operation::While &node) +{ + const auto cond_subg_index = node.param().cond_subg_index; + const auto body_subg_index = node.param().body_subg_index; + + // This op does not support input as a constant, because builtin backend does not have + // TensorBuilder + std::vector<backend::IPortableTensor *> input_tensors; + for (const auto &input_index : node.getInputs()) + { + auto input_tensor = getPortableTensor(input_index); + input_tensors.emplace_back(input_tensor); + } + + std::vector<backend::IPortableTensor *> output_tensors; + for (const auto &output_index : node.getOutputs()) + { + auto output_tensor = getPortableTensor(output_index); + output_tensors.emplace_back(output_tensor); + } + + // WhileLayer just sets Executors instead of cond and body executor to avoid complexity of + // creating executor recursively + auto fn = std::make_unique<::onert::backend::builtin::kernel::WhileLayer>( + input_tensors, 
output_tensors, cond_subg_index, body_subg_index, _executors, _model_index, + _dyn_tensor_manager->dynamic_mem_mgr().get(), _external_context); + + _return_fn = std::move(fn); +} + +backend::ITensor *KernelGenerator::getTensor(const ir::OperandIndex &index) +{ + // get Tensor from all tensor registries (for Permute op) + auto ret = _tensor_registries.getITensor(index); + assert(ret != nullptr); + return ret; +} + +backend::IPortableTensor *KernelGenerator::getPortableTensor(const ir::OperandIndex &index) +{ + auto ret = _tensor_reg->getPortableTensor(index); + assert(ret != nullptr); + return ret; +} + +} // namespace builtin +} // namespace backend +} // namespace onert diff --git a/runtime/onert/core/src/backend/controlflow/KernelGenerator.h b/runtime/onert/core/src/backend/builtin/KernelGenerator.h index b84a810e4..3c86fe306 100644 --- a/runtime/onert/core/src/backend/controlflow/KernelGenerator.h +++ b/runtime/onert/core/src/backend/builtin/KernelGenerator.h @@ -14,60 +14,66 @@ * limitations under the License. 
*/ -#ifndef __ONERT_BACKEND_CONTROLFLOW_KERNEL_GENERATOR_H__ -#define __ONERT_BACKEND_CONTROLFLOW_KERNEL_GENERATOR_H__ +#ifndef __ONERT_BACKEND_BUILTIN_KERNEL_GENERATOR_H__ +#define __ONERT_BACKEND_BUILTIN_KERNEL_GENERATOR_H__ -#include <backend/IKernelGenerator.h> -#include <backend/ITensorBuilder.h> -#include <exec/IExecutor.h> -#include <ir/Graph.h> -#include "TensorBuilder.h" -#include "compiler/TensorRegistries.h" +#include "DynamicTensorManager.h" +#include "ExternalContext.h" #include "TensorRegistry.h" +#include "../../compiler/TensorRegistries.h" + +#include "backend/basic/KernelGeneratorBase.h" +#include "exec/IExecutors.h" +#include "ir/Graph.h" namespace onert { namespace backend { -namespace controlflow +namespace builtin { -class KernelGenerator : public IKernelGenerator +class KernelGenerator : public basic::KernelGeneratorBase { public: - KernelGenerator(const ir::Graph &graph, IDynamicTensorManager *dyn_tensor_manager, - const std::shared_ptr<TensorRegistry> &tensor_reg); + KernelGenerator(const ir::Graph &graph, DynamicTensorManager *dyn_tensor_manager, + const std::shared_ptr<TensorRegistry> &tensor_reg, + const std::shared_ptr<ExternalContext> &external_context); void setTensorRegistries(const compiler::TensorRegistries &tensor_registries) { _tensor_registries = tensor_registries; } - void setExecutorMap(const std::shared_ptr<exec::ExecutorMap> &executor_map) + void setExecutors(const std::shared_ptr<exec::IExecutors> &executors) { // FIXME Using shared_ptr's raw pointer! 
- _executor_map = executor_map.get(); + _executors = executors.get(); } - using IKernelGenerator::visit; + void setModelIndex(const ir::ModelIndex &index) { _model_index = index; } + + std::unique_ptr<exec::FunctionSequence> generate(ir::OperationIndex ind) override; - void visit(const ir::OpSequence &) override; +private: void visit(const ir::operation::If &) override; void visit(const ir::operation::Permute &) override; void visit(const ir::operation::While &) override; private: - std::shared_ptr<backend::ITensor> getTensor(const ir::OperandIndex &index); + backend::ITensor *getTensor(const ir::OperandIndex &index); + backend::IPortableTensor *getPortableTensor(const ir::OperandIndex &index); private: - const ir::Graph &_graph; - IDynamicTensorManager *_dyn_tensor_manager; + DynamicTensorManager *_dyn_tensor_manager; std::shared_ptr<TensorRegistry> _tensor_reg; compiler::TensorRegistries _tensor_registries; - exec::ExecutorMap *_executor_map; + exec::IExecutors *_executors; + ir::ModelIndex _model_index; + const std::shared_ptr<ExternalContext> _external_context; }; -} // namespace controlflow +} // namespace builtin } // namespace backend } // namespace onert -#endif // __ONERT_BACKEND_CONTROLFLOW_KERNEL_GENERATOR_H__ +#endif // __ONERT_BACKEND_BUILTIN_KERNEL_GENERATOR_H__ diff --git a/runtime/onert/core/src/backend/controlflow/Tensor.h b/runtime/onert/core/src/backend/builtin/Tensor.h index ba5bafd75..d55e64161 100644 --- a/runtime/onert/core/src/backend/controlflow/Tensor.h +++ b/runtime/onert/core/src/backend/builtin/Tensor.h @@ -14,22 +14,23 @@ * limitations under the License. 
*/ -#ifndef __ONERT_BACKEND_CONTROLFLOW_TENSOR_H__ -#define __ONERT_BACKEND_CONTROLFLOW_TENSOR_H__ +#ifndef __ONERT_BACKEND_BUILTIN_TENSOR_H__ +#define __ONERT_BACKEND_BUILTIN_TENSOR_H__ -#include <backend/cpu_common/Tensor.h> +#include <backend/basic/Tensor.h> namespace onert { namespace backend { -namespace controlflow +namespace builtin { -using Tensor = cpu_common::Tensor; +using Tensor = basic::Tensor; +using ExternalTensor = basic::ExternalTensor; -} // namespace controlflow +} // namespace builtin } // namespace backend } // namespace onert -#endif // __ONERT_BACKEND_CONTROLFLOW_TENSOR_H__ +#endif // __ONERT_BACKEND_BUILTIN_TENSOR_H__ diff --git a/runtime/onert/core/src/backend/controlflow/TensorBuilder.cc b/runtime/onert/core/src/backend/builtin/TensorBuilder.cc index e5c3f5fd5..a2f7af3ea 100644 --- a/runtime/onert/core/src/backend/controlflow/TensorBuilder.cc +++ b/runtime/onert/core/src/backend/builtin/TensorBuilder.cc @@ -24,13 +24,13 @@ namespace onert { namespace backend { -namespace controlflow +namespace builtin { TensorBuilder::TensorBuilder(const std::shared_ptr<TensorRegistry> &tensor_reg) - : _tensor_reg{tensor_reg}, _dynamic_tensor_mgr{new DynamicTensorManager(_tensor_reg)}, - _static_tensor_mgr{ - new cpu_common::StaticTensorManager(_tensor_reg->base_reg(), _dynamic_tensor_mgr.get())} + : _tensor_reg{tensor_reg}, _dynamic_tensor_mgr{new DynamicTensorManager(_tensor_reg->base_reg())}, + _static_tensor_mgr{ + new basic::StaticTensorManager(_tensor_reg->base_reg(), _dynamic_tensor_mgr.get())} { /* empty */ } @@ -40,15 +40,14 @@ void TensorBuilder::registerTensorInfo(const ir::OperandIndex &ind, const ir::Op { _tensor_info_map.emplace(ind, info); - _tensor_layout_map.insert({ind, backend_layout}); - + VERBOSE_F() << "cpucommon REGISTER!! 
" << ind << std::endl; if (info.isDynamic()) { - _dynamic_tensor_mgr->buildTensor(ind, info, _tensor_layout_map[ind]); + _dynamic_tensor_mgr->buildTensor(ind, info, backend_layout); } else { - _static_tensor_mgr->buildTensor(ind, info, _tensor_layout_map[ind], info.isConstant()); + _static_tensor_mgr->buildTensor(ind, info, backend_layout, info.isConstant()); } } @@ -58,7 +57,7 @@ void TensorBuilder::notifyFirstUse(const ir::OperandIndex &ind) if (_tensor_info_map.find(ind) == _tensor_info_map.end()) // Do not proceed for user tensors return; - const auto tensor_info = _tensor_info_map.at(ind); + const auto &tensor_info = _tensor_info_map.at(ind); if (!nativeOwnTensorAt(ind)->is_dynamic()) { @@ -89,39 +88,18 @@ bool TensorBuilder::isRegistered(const ir::OperandIndex &ind) const return _tensor_info_map.find(ind) != _tensor_info_map.end(); } -void TensorBuilder::prepare(void) -{ - _static_tensor_mgr->allocateConsts(); - _static_tensor_mgr->allocateNonconsts(); -} +void TensorBuilder::allocate(void) { _static_tensor_mgr->allocateNonconsts(); } -void TensorBuilder::allocate() +DynamicTensorManager *TensorBuilder::dynamicTensorManager(void) { - // NOTE For now nothing to do. Allocation is done in prepare stage, which is not appropriate - // This is because CPU kernels require `ITensor`s to be allocated before Kernel Generation. 
+ return _dynamic_tensor_mgr.get(); } -std::shared_ptr<cpu_common::Tensor> TensorBuilder::nativeOwnTensorAt(const ir::OperandIndex &ind) +basic::Tensor *TensorBuilder::nativeOwnTensorAt(const ir::OperandIndex &ind) { return _tensor_reg->getNativeOwnTensor(ind); } -std::unique_ptr<ITensorManager> TensorBuilder::releaseStaticTensorManager(void) -{ - return std::move(_static_tensor_mgr); -} - -std::unique_ptr<ITensorManager> TensorBuilder::releaseDynamicTensorManager(void) -{ - return std::move(_dynamic_tensor_mgr); -} - -void TensorBuilder::setNativeUserTensor(const ir::OperandIndex &ind, - const std::shared_ptr<UserTensor> &tensor) -{ - _tensor_reg->setNativeUserTensor(ind, tensor); -} - -} // namespace controlflow +} // namespace builtin } // namespace backend } // namespace onert diff --git a/runtime/onert/core/src/backend/controlflow/TensorBuilder.h b/runtime/onert/core/src/backend/builtin/TensorBuilder.h index 2f2a2c47e..1e364c927 100644 --- a/runtime/onert/core/src/backend/controlflow/TensorBuilder.h +++ b/runtime/onert/core/src/backend/builtin/TensorBuilder.h @@ -14,29 +14,27 @@ * limitations under the License. 
*/ -#ifndef __ONERT_BACKEND_CONTROLFLOW_TENSOR_BUILDER_H__ -#define __ONERT_BACKEND_CONTROLFLOW_TENSOR_BUILDER_H__ +#ifndef __ONERT_BACKEND_BUILTIN_TENSOR_BUILDER_H__ +#define __ONERT_BACKEND_BUILTIN_TENSOR_BUILDER_H__ -#include <backend/cpu_common/StaticTensorManager.h> -#include <backend/cpu_common/TensorRegistry.h> -#include <backend/cpu_common/Tensor.h> +#include <backend/basic/StaticTensorManager.h> +#include <backend/basic/TensorRegistry.h> +#include <backend/basic/Tensor.h> -#include <backend/ITensorBuilder.h> #include <ir/OperandIndexMap.h> #include <unordered_map> #include "DynamicTensorManager.h" -#include "UserTensorRegistry.h" namespace onert { namespace backend { -namespace controlflow +namespace builtin { -class TensorBuilder : public ITensorBuilder +class TensorBuilder { public: TensorBuilder(const std::shared_ptr<TensorRegistry> &tensor_reg); @@ -48,42 +46,34 @@ public: * @param[in] layout Operand data layout */ void registerTensorInfo(const ir::OperandIndex &ind, const ir::OperandInfo &info, - ir::Layout backend_layout) override; + ir::Layout backend_layout); - void notifyFirstUse(const ir::OperandIndex &) override; - void notifyLastUse(const ir::OperandIndex &) override; + void notifyFirstUse(const ir::OperandIndex &); + void notifyLastUse(const ir::OperandIndex &); - bool isRegistered(const ir::OperandIndex &) const override; + bool isRegistered(const ir::OperandIndex &) const; - void prepare(void) override; - void allocate() override; - void postFunctionPrepare() override { /* DO NOTHING */} + void allocate(void); - std::unique_ptr<ITensorManager> releaseStaticTensorManager(void) override; - - IDynamicTensorManager *dynamicTensorManager(void) override { return _dynamic_tensor_mgr.get(); } - - std::unique_ptr<ITensorManager> releaseDynamicTensorManager(void) override; + DynamicTensorManager *dynamicTensorManager(void); /** * @brief Get tensor with a specific OperandIndex. * @param ind OperandIndex for the tensor. 
There must exist a tensor with this ind. * If not, program will crash with assert or exception. - * @return shared_ptr<operand::Tensor> + * @return operand::Tensor * */ - std::shared_ptr<cpu_common::Tensor> nativeOwnTensorAt(const ir::OperandIndex &ind); - void setNativeUserTensor(const ir::OperandIndex &ind, const std::shared_ptr<UserTensor> &tensor); + basic::Tensor *nativeOwnTensorAt(const ir::OperandIndex &ind); private: const std::shared_ptr<TensorRegistry> _tensor_reg; std::unique_ptr<DynamicTensorManager> _dynamic_tensor_mgr; - std::unique_ptr<cpu_common::StaticTensorManager> _static_tensor_mgr; + std::unique_ptr<basic::StaticTensorManager> _static_tensor_mgr; ir::OperandIndexMap<ir::OperandInfo> _tensor_info_map; - ir::OperandIndexMap<ir::Layout> _tensor_layout_map; }; -} // namespace controlflow +} // namespace builtin } // namespace backend } // namespace onert -#endif // __ONERT_BACKEND_CONTROLFLOW_TENSOR_BUILDER_H__ +#endif // __ONERT_BACKEND_BUILTIN_TENSOR_BUILDER_H__ diff --git a/runtime/onert/core/src/backend/builtin/TensorRegistry.h b/runtime/onert/core/src/backend/builtin/TensorRegistry.h new file mode 100644 index 000000000..ae68b1318 --- /dev/null +++ b/runtime/onert/core/src/backend/builtin/TensorRegistry.h @@ -0,0 +1,134 @@ +/* + * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef __ONERT_BACKEND_BUILTIN_TENSOR_REGISTRY_H__ +#define __ONERT_BACKEND_BUILTIN_TENSOR_REGISTRY_H__ + +#include "backend/basic/TensorRegistry.h" +#include "backend/ITensorRegistry.h" +#include "Tensor.h" +#include "IOTensor.h" +#include <assert.h> + +namespace onert +{ +namespace backend +{ +namespace builtin +{ + +/** + * @brief Tensor registry class for builtin backend + * + * This class contains three types of tensors. Two native tensors(tensors that are managed by this + * backend) and the other is migrant tensor. + * + * - NativeIOTensor - @c IOTensor managed by this backend ( in @c _base_reg ) + * - NOTE The tensor it actually points to can be from another backend + * - NativeOwnTensor - @c basic::Tensor managed by this backend ( in @c _base_reg ) + * - MigrantTensor - @c IPortableTensor managed by other backends + * + * @note @c _base_reg is used in implementation to reuse @c basic::StaticTensorManager + * + */ +class TensorRegistry : public ITensorRegistry +{ +public: + TensorRegistry() : _base_reg{new basic::TensorRegistry} {} + + ITensor *getITensor(const ir::OperandIndex &ind) override + { + auto base_tensor = _base_reg->getITensor(ind); + if (base_tensor) + return base_tensor; + return getNativeIOTensor(ind); + } + + ITensor *getNativeITensor(const ir::OperandIndex &ind) override + { + auto base_tensor = _base_reg->getNativeITensor(ind); + if (base_tensor) + return base_tensor; + return getNativeIOTensor(ind); + } + + IPortableTensor *getPortableTensor(const ir::OperandIndex &ind) + { + auto base_tensor = _base_reg->getPortableTensor(ind); + if (base_tensor) + return base_tensor; + return getNativeIOTensor(ind); + } + + IPortableTensor *getNativeTensor(const ir::OperandIndex &ind) + { + auto base_tensor = _base_reg->getNativeTensor(ind); + if (base_tensor) + return base_tensor; + return getNativeIOTensor(ind); + } + + Tensor *getNativeOwnTensor(const ir::OperandIndex &ind) + { + return _base_reg->getNativeTensor(ind); + } + + IOTensor 
*getNativeIOTensor(const ir::OperandIndex &ind) + { + auto tensor = _native_io_tensors.find(ind); + if (tensor != _native_io_tensors.end()) + return tensor->second.get(); + return nullptr; + } + + bool setMigrantTensor(const ir::OperandIndex &ind, IPortableTensor *tensor) override + { + assert(tensor); + assert(!getITensor(ind)); // For the ind, tensor is not registered yet + _base_reg->setMigrantTensor(ind, tensor); + return true; + } + + void setNativeOwnTensor(ir::OperandIndex ind, std::unique_ptr<Tensor> &&tensor) + { + assert(tensor); + assert(!getITensor(ind)); // For the ind, tensor is not registered yet + _base_reg->setNativeTensor(ind, std::move(tensor)); + } + + void setNativeIOTensor(ir::OperandIndex ind, std::unique_ptr<IOTensor> &&tensor) + { + assert(tensor); + assert(!getITensor(ind)); // For the ind, tensor is not registered yet + _native_io_tensors[ind] = std::move(tensor); + } + + const ir::OperandIndexMap<std::unique_ptr<IOTensor>> &native_io_tensors() + { + return _native_io_tensors; + } + std::shared_ptr<basic::TensorRegistry> base_reg() { return _base_reg; } + +private: + std::shared_ptr<basic::TensorRegistry> _base_reg; + ir::OperandIndexMap<std::unique_ptr<IOTensor>> _native_io_tensors; +}; + +} // namespace builtin +} // namespace backend +} // namespace onert + +#endif // ifndef __ONERT_BACKEND_BUILTIN_TENSOR_REGISTRY_H__ diff --git a/runtime/onert/core/src/backend/builtin/UserTensor.cc b/runtime/onert/core/src/backend/builtin/UserTensor.cc new file mode 100644 index 000000000..e260de275 --- /dev/null +++ b/runtime/onert/core/src/backend/builtin/UserTensor.cc @@ -0,0 +1,41 @@ +/* + * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "UserTensor.h" + +#include "util/Exceptions.h" +#include "ir/DataType.h" + +namespace onert +{ +namespace backend +{ +namespace builtin +{ + +bool UserTensor::applyShape(const ir::Shape &new_shape) +{ + // User tensors cannot be reallocated. + auto new_size = new_shape.num_elements() * ir::sizeOfDataType(data_type()); + if (_size < new_size) + throw InsufficientBufferSizeException{"User given buffer size is too small."}; + setShape(new_shape); + return true; +} + +} // namespace builtin +} // namespace backend +} // namespace onert diff --git a/runtime/onert/core/src/backend/builtin/UserTensor.h b/runtime/onert/core/src/backend/builtin/UserTensor.h new file mode 100644 index 000000000..b7f6ce091 --- /dev/null +++ b/runtime/onert/core/src/backend/builtin/UserTensor.h @@ -0,0 +1,63 @@ +/* + * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef __ONERT_BACKEND_BUILTIN_USER_TENSOR_H__ +#define __ONERT_BACKEND_BUILTIN_USER_TENSOR_H__ + +#include "ir/OperandInfo.h" +#include "backend/IPortableTensor.h" + +namespace onert +{ +namespace backend +{ +namespace builtin +{ + +/** + * @brief Tensor object that is for Input and Output tensors from the user. + * + * This class is a wrapped buffer that is allocated by the user. So it does not have responsibility + * on allocation nor deallocation. All the model input/output tensors are wrapped with this class + * for execution. + * + */ +class UserTensor : public IPortableTensor +{ +public: + UserTensor(const ir::OperandInfo &info, ir::Layout layout, uint8_t *buffer, size_t size) + : IPortableTensor{info}, _layout{layout}, _buffer{buffer}, _size{size} + { + } + +public: + uint8_t *buffer() const override { return _buffer; } + ir::Layout layout() const override { return _layout; } + void set_dynamic() override { _info.setDynamic(); } + void setShape(const ir::Shape &new_shape) override { _info.shape(new_shape); } + bool applyShape(const ir::Shape &) override; + +private: + ir::Layout _layout; + uint8_t *_buffer; + size_t _size; +}; + +} // namespace builtin +} // namespace backend +} // namespace onert + +#endif // __ONERT_BACKEND_BUILTIN_USER_TENSOR_H__ diff --git a/runtime/onert/core/src/backend/builtin/kernel/IfLayer.cc b/runtime/onert/core/src/backend/builtin/kernel/IfLayer.cc new file mode 100644 index 000000000..bf8c5fc68 --- /dev/null +++ b/runtime/onert/core/src/backend/builtin/kernel/IfLayer.cc @@ -0,0 +1,81 @@ +/* + * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "IfLayer.h" + +namespace onert +{ +namespace backend +{ +namespace builtin +{ +namespace kernel +{ + +IfLayer::IfLayer(backend::IPortableTensor *cond_tensor, + const std::vector<backend::IPortableTensor *> input_tensors, + const std::vector<backend::IPortableTensor *> output_tensors, + const ir::SubgraphIndex &then_subg_index, const ir::SubgraphIndex &else_subg_index, + exec::IExecutors *executors, const ir::ModelIndex &model_index, + const std::shared_ptr<ExternalContext> &external_context) + : _cond_tensor{cond_tensor}, _input_tensors{input_tensors}, _output_tensors{output_tensors}, + _then_subg_index{then_subg_index}, _else_subg_index{else_subg_index}, _executors{executors}, + _model_index{model_index}, _external_context{external_context} +{ + // At this point, executors may not have executors of then subg and else subg +} + +void IfLayer::run() +{ + // Check condition + // // If true + // // // Set _input_tensors -> then-subg's inputs + // // // Set outputs of then-subg -> _output_tensors + // // // Run then-subg + // // Else + // // // Set _input_tensors -> else-subg's inputs + // // // Set outputs of else-subg -> _output_tensors + // // // Run else-subg + + auto getResultCond = [](backend::IPortableTensor *tensor) -> bool { + bool ret = false; + tensor->access([&](ITensor &tensor) { ret = *reinterpret_cast<bool *>(tensor.buffer()); }); + return ret; + }; + + exec::IExecutor *subg_exec = nullptr; + bool cond_result = getResultCond(_cond_tensor); + if (cond_result) + { + VERBOSE(If) << "Call to $" << _then_subg_index << " 
(then)" << std::endl; + subg_exec = _executors->at(_model_index, _then_subg_index); + } + else + { + VERBOSE(If) << "Call to $" << _else_subg_index << " (else)" << std::endl; + subg_exec = _executors->at(_model_index, _else_subg_index); + } + + subg_exec->execute(_input_tensors, _output_tensors, + _executors->entryExecutor()->currentOptions()); + VERBOSE(If) << "Return from $" << (cond_result ? _then_subg_index : _else_subg_index) + << std::endl; +} + +} // namespace kernel +} // namespace builtin +} // namespace backend +} // namespace onert diff --git a/runtime/onert/core/src/backend/controlflow/kernel/IfLayer.h b/runtime/onert/core/src/backend/builtin/kernel/IfLayer.h index ef3a6e6f6..a9b8f2710 100644 --- a/runtime/onert/core/src/backend/controlflow/kernel/IfLayer.h +++ b/runtime/onert/core/src/backend/builtin/kernel/IfLayer.h @@ -14,17 +14,19 @@ * limitations under the License. */ -#ifndef __ONERT_BACKEND_CONTROLFLOW_KERNEL_IF_LAYER_H__ -#define __ONERT_BACKEND_CONTROLFLOW_KERNEL_IF_LAYER_H__ +#ifndef __ONERT_BACKEND_BUILTIN_KERNEL_IF_LAYER_H__ +#define __ONERT_BACKEND_BUILTIN_KERNEL_IF_LAYER_H__ -#include <backend/ITensor.h> -#include <exec/IExecutor.h> +#include <backend/IPortableTensor.h> +#include <exec/IExecutors.h> +#include <exec/IFunction.h> +#include "../ExternalContext.h" namespace onert { namespace backend { -namespace controlflow +namespace builtin { namespace kernel { @@ -32,32 +34,30 @@ namespace kernel class IfLayer : public ::onert::exec::IFunction { public: - IfLayer(const std::shared_ptr<backend::ITensor> &cond_tensor, - const std::vector<std::shared_ptr<backend::ITensor>> input_tensors, - const std::vector<std::shared_ptr<backend::ITensor>> output_tensors, - const ir::OperandIndexSequence &output_indices, const ir::Graph &graph, - const exec::DynAllocInfoMap &outputs_dyn_alloc_info, + IfLayer(backend::IPortableTensor *cond_tensor, + const std::vector<backend::IPortableTensor *> input_tensors, + const std::vector<backend::IPortableTensor *> 
output_tensors, const ir::SubgraphIndex &then_subg_index, const ir::SubgraphIndex &else_subg_index, - exec::ExecutorMap *executor_map); + exec::IExecutors *executors, const ir::ModelIndex &model_index, + const std::shared_ptr<ExternalContext> &external_context); public: void run() override; private: - const std::shared_ptr<backend::ITensor> _cond_tensor; - const std::vector<std::shared_ptr<backend::ITensor>> _input_tensors; - const std::vector<std::shared_ptr<backend::ITensor>> _output_tensors; - const ir::OperandIndexSequence &_output_indices; - const ir::Graph &_graph; - const exec::DynAllocInfoMap _outputs_dyn_alloc_info; + backend::IPortableTensor *_cond_tensor; + const std::vector<backend::IPortableTensor *> _input_tensors; + const std::vector<backend::IPortableTensor *> _output_tensors; const ir::SubgraphIndex _then_subg_index; const ir::SubgraphIndex _else_subg_index; - exec::ExecutorMap *_executor_map; + exec::IExecutors *_executors; + ir::ModelIndex _model_index; + const std::shared_ptr<ExternalContext> _external_context; }; } // namespace kernel -} // namespace controlflow +} // namespace builtin } // namespace backend } // namespace onert -#endif // __ONERT_BACKEND_CONTROLFLOW_KERNEL_IF_LAYER_H__ +#endif // __ONERT_BACKEND_BUILTIN_KERNEL_IF_LAYER_H__ diff --git a/runtime/onert/core/src/backend/builtin/kernel/PermuteLayer.cc b/runtime/onert/core/src/backend/builtin/kernel/PermuteLayer.cc new file mode 100644 index 000000000..600180077 --- /dev/null +++ b/runtime/onert/core/src/backend/builtin/kernel/PermuteLayer.cc @@ -0,0 +1,316 @@ +/* + * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "PermuteLayer.h" + +#include "../../../exec/ShapeConverter.h" + +#include <ruy/context.h> // from @ruy + +namespace onert +{ +namespace backend +{ +namespace builtin +{ +namespace kernel +{ + +PermuteLayer::PermuteLayer(const std::vector<ITensor *> &src_tensors, + const std::vector<ITensor *> &dst_tensors, + const std::shared_ptr<ExternalContext> &external_context) + : _external_context{external_context}, _tasks_map{} +{ + assert(src_tensors.size() == dst_tensors.size()); + _src_tensors = src_tensors; + _dst_tensors = dst_tensors; + _src_tensors_offsets.resize(src_tensors.size()); + _dst_tensors_offsets.resize(dst_tensors.size()); +} + +void PermuteLayer::optimize() +{ + // Remove copying of tensor as nullptr + auto src_it = _src_tensors.begin(); + auto dst_it = _dst_tensors.begin(); + auto src_offsets_it = _src_tensors_offsets.begin(); + auto dst_offsets_it = _dst_tensors_offsets.begin(); + while (src_it != _src_tensors.end()) + { + if ((*src_it == *dst_it) || (*src_it == nullptr || *dst_it == nullptr)) + { + src_it = _src_tensors.erase(src_it); + dst_it = _dst_tensors.erase(dst_it); + src_offsets_it = _src_tensors_offsets.erase(src_offsets_it); + dst_offsets_it = _dst_tensors_offsets.erase(dst_offsets_it); + } + else + { + auto src = *src_it; + auto dst = *dst_it; + src_offsets_it->resize(0); + dst_offsets_it->resize(0); + if (underlying_type(src->data_type()) != underlying_type(dst->data_type())) + continue; + const auto permute_type = [&]() -> PermuteType { + if (src->getShape().rank() == 4 && src->layout() == ir::Layout::NHWC 
&& + dst->layout() == ir::Layout::NCHW) + { + return PermuteType::NHWC_TO_NCHW; + } + else if (src->getShape().rank() == 4 && src->layout() == ir::Layout::NCHW && + dst->layout() == ir::Layout::NHWC) + { + return PermuteType::NCHW_TO_NHWC; + } + else + { + return PermuteType::COPY; + } + }(); + + // TODO Support different types + auto fn = [&](backend::ITensor &src_tensor) { + dst->access([&](backend::ITensor &dst_tensor) { + // NOTE The buffer of both tensor can be nullptr in this step + const auto data_size = ir::sizeOfDataType(src_tensor.data_type()); + + if (permute_type == PermuteType::COPY) + { + if ((!src_tensor.has_padding() && !dst_tensor.has_padding())) + { + const auto num_elements = src_tensor.getShape().num_elements(); + const int thread_count = + _external_context->ruy_context()->max_num_threads() < static_cast<int>(num_elements) + ? _external_context->ruy_context()->max_num_threads() + : num_elements; + + std::vector<PermuteWorkerTask> tasks; + auto start = 0; + for (auto i = 0; i < thread_count; ++i) + { + int end = start + (num_elements - start) / (thread_count - i); + tasks.emplace_back(src_tensor.buffer(), dst_tensor.buffer(), start * data_size, + start * data_size, (end - start) * data_size); + start = end; + } + assert(tasks.size() >= 1); + _tasks_map[src] = std::move(tasks); + } + else + { + auto loop_shape = src_tensor.getShape(); + + auto copy_axis = loop_shape.rank() - 1; + copy_axis = copy_axis < 0 ? 
1 : copy_axis; + const auto copy_len = loop_shape.dim(copy_axis) * data_size; + loop_shape.dim(copy_axis) = 1; + + appendPermuteTasks(src, dst, loop_shape, copy_len); + } + } + else + { + assert(src_tensor.getShape().rank() == 4 && + (permute_type == PermuteType::NHWC_TO_NCHW || + permute_type == PermuteType::NCHW_TO_NHWC)); + const auto loop_shape = src_tensor.getShape(); + const auto copy_len = data_size; + + appendPermuteTasks(src, dst, loop_shape, copy_len); + } + }); + }; + src->access(fn); + src_it++; + dst_it++; + src_offsets_it++; + dst_offsets_it++; + } + } +} + +void PermuteLayer::appendPermuteTasks(const ITensor *src_tensor, ITensor *dst_tensor, + const ir::Shape &loop_shape, size_t size) +{ + size_t distributed_dim = 0; + auto src_shape = src_tensor->getShape(); + if (src_tensor->layout() == dst_tensor->layout()) + { + for (int i = 1; i < src_shape.rank() - 1; ++i) + { + distributed_dim = src_shape.dim(distributed_dim) < src_shape.dim(i) ? i : distributed_dim; + } + } + const auto distributed_dim_val = src_shape.dim(distributed_dim); + const int thread_count = + _external_context->ruy_context()->max_num_threads() < static_cast<int>(distributed_dim_val) + ? _external_context->ruy_context()->max_num_threads() + : distributed_dim_val; + // NOTE Do not remove this assertion. 
It would cause performance degradation by new threads to be + // created in the context's thread pool + assert(thread_count <= _external_context->ruy_context()->max_num_threads()); + + std::vector<PermuteWorkerTask> tasks; + int start = 0; + auto one_thread_loop_shape = loop_shape; + for (auto i = 0; i < thread_count; ++i) + { + ir::Coordinates start_coords(one_thread_loop_shape.rank()); + start_coords.set(distributed_dim, start); + int end = start + (distributed_dim_val - start) / (thread_count - i); + one_thread_loop_shape.dim(distributed_dim) = end - start; + tasks.emplace_back(*src_tensor, *dst_tensor, start_coords, one_thread_loop_shape, size); + start = end; + } + assert(tasks.size() >= 1); + _tasks_map[src_tensor] = std::move(tasks); +} + +void PermuteLayer::runPermuteTasks(backend::ITensor *src, uint8_t *dst_buffer) +{ + assert(src->getShape().num_elements() * ir::sizeOfDataType(src->data_type()) <= + src->total_size()); + std::vector<PermuteWorkerTask> &tasks = _tasks_map.at(src); + for (size_t i = 0; i < tasks.size(); ++i) + { + tasks.at(i).setBuffers(src->buffer(), dst_buffer); + } + assert(tasks.size() >= 1); + _external_context->ruy_context()->mutable_thread_pool()->Execute(tasks.size(), tasks.data()); +} + +void PermuteLayer::run() +{ + assert(_src_tensors.size() == _dst_tensors.size()); + // PermuteLayer infers dynamic shape inside itself whenever run is called for the following + // reasons: + // 1. PermuteLayer has to access dynamic tensor manager for input/output tensors of other backends + // 2. Other controlflow operation(If/While) uses this layout for copying tensors of other + // subgraphs(with other backends) + // 3. 
This inferring code is placed here to avoid duplicated code that can be caused by above 2 + // reasons + + // check if output is not dynamic + for (size_t i = 0; i < _src_tensors.size(); ++i) + { + auto dst_tensor = _dst_tensors.at(i); + auto src_tensor = _src_tensors.at(i); + if (src_tensor->is_dynamic() || dst_tensor->is_dynamic()) + { + // getting output shape + auto src_shape = src_tensor->getShape(); + + // set output shape and output buffer + ir::Shape new_shape = + exec::convertShape(src_shape, src_tensor->layout(), dst_tensor->layout()); + + try + { + if (!dst_tensor->applyShape(new_shape)) + throw std::runtime_error{ + "Error: PermuteLayer: output's TensorManager does not support dynamic tensor"}; + assert(dst_tensor->buffer() != nullptr); + } + catch (const std::out_of_range &e) + { + std::cerr << "Error: out_of_range in PermuteLayer: output's TensorManager does not support " + "dynamic tensor" + << '\n'; + throw; + } + } + assert(exec::convertShape(src_tensor->getShape(), src_tensor->layout(), dst_tensor->layout()) == + dst_tensor->getShape()); + } + assert(_src_tensors.size() == _dst_tensors.size()); + assert(_src_tensors.size() == _src_tensors_offsets.size()); + assert(_dst_tensors.size() == _dst_tensors_offsets.size()); + auto src_it = _src_tensors.begin(); + auto dst_it = _dst_tensors.begin(); + auto src_offsets_it = _src_tensors_offsets.begin(); + auto dst_offsets_it = _dst_tensors_offsets.begin(); + while (src_it != _src_tensors.end()) + { + auto src = *src_it; + auto dst = *dst_it; + auto &src_offsets = *src_offsets_it; + auto &dst_offsets = *dst_offsets_it; + + if (src->total_size() == 0) + { + assert(dst->total_size() == 0); + } + else + { + if (src != dst) + { + // Conditions to run permutation with multithreading + // 1. The tasks for multithreading was created + // 2. The tasks' size > 1 + // 3. Both tensors are not dynamic + // 4. 
Data types of both tensors are the same + if (_tasks_map.find(src) == _tasks_map.end() || _tasks_map.at(src).size() == 1 || + src->is_dynamic() || dst->is_dynamic() || + underlying_type(src->data_type()) != underlying_type(dst->data_type())) + { + permute(src, dst, src->getShape().rank(), src_offsets, dst_offsets); + } + // If dst is subtensor, we have to use clEnqueueMapBuffer instead of clEnqueueWriteBuffer + else if (dst->needMemoryMap() && !dst->is_subtensor()) + { + if (!src->has_padding() && !dst->has_padding() && src->layout() == dst->layout()) + { + // This is more effective than multi-threading + src->access([&](backend::ITensor &) { dst->enqueueWriteBuffer(src->buffer(), false); }); + } + else + { + // TODO Optimize this block in case the padding size of dst is big. + _buffers_map[dst].reserve(dst->total_size()); + auto dst_buffer = _buffers_map[dst].data(); + + src->access([&](backend::ITensor &) { runPermuteTasks(src, dst_buffer); }); + dst->enqueueWriteBuffer(dst_buffer, false); + } + } + else if (src->needMemoryMap() && !src->is_subtensor() && !src->has_padding() && + !dst->has_padding() && src->layout() == dst->layout()) + { + // This is more effective than multi-threading + assert(!dst->needMemoryMap()); + dst->access([&](backend::ITensor &) { src->enqueueReadBuffer(dst->buffer(), true); }); + } + else + { + auto fn = [&](backend::ITensor &) { + dst->access([&](backend::ITensor &) { runPermuteTasks(src, dst->buffer()); }); + }; + src->access(fn); + } + } + } + src_it++; + dst_it++; + src_offsets_it++; + dst_offsets_it++; + } +} + +} // namespace kernel +} // namespace builtin +} // namespace backend +} // namespace onert diff --git a/runtime/onert/core/src/backend/builtin/kernel/PermuteLayer.h b/runtime/onert/core/src/backend/builtin/kernel/PermuteLayer.h new file mode 100644 index 000000000..cf25f5447 --- /dev/null +++ b/runtime/onert/core/src/backend/builtin/kernel/PermuteLayer.h @@ -0,0 +1,150 @@ +/* + * Copyright (c) 2020 Samsung 
Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __ONERT_BACKEND_BUILTIN_KERNEL_PERMUTELAYER_H__ +#define __ONERT_BACKEND_BUILTIN_KERNEL_PERMUTELAYER_H__ + +#include "../ExternalContext.h" +#include "../../../exec/IPermuteFunction.h" + +#include <ruy/thread_pool.h> // from @ruy + +namespace onert +{ +namespace backend +{ +namespace builtin +{ +namespace kernel +{ + +class PermuteLayer : public onert::exec::IPermuteFunction +{ +public: + PermuteLayer(const std::vector<ITensor *> &src_tensors, const std::vector<ITensor *> &dst_tensors, + const std::shared_ptr<ExternalContext> &external_context); + + void optimize() override; + + void run() override; + +private: + std::shared_ptr<ExternalContext> _external_context; + +private: + void appendPermuteTasks(const ITensor *src_tensor, ITensor *dst_tensor, + const ir::Shape &loop_shape, size_t size); + + void runPermuteTasks(backend::ITensor *src, uint8_t *dst_buffer); + + struct PermuteWorkerTask : ruy::Task + { + using Strides = ir::Coordinates; + + PermuteWorkerTask(const ITensor &src_tensor, ITensor &dst_tensor, + const ir::Coordinates &start_coords, const ir::Shape &loop_shape, size_t size) + : _src_buffer{src_tensor.buffer()}, _dst_buffer{dst_tensor.buffer()}, + _src_start_offset{src_tensor.calcOffset(start_coords)}, + _dst_start_offset{dst_tensor.calcOffset(start_coords)}, _src_strides{}, _dst_strides{}, + _loop_shape{loop_shape}, _size{size}, 
_src_layout{src_tensor.layout()}, + _dst_layout{dst_tensor.layout()}, _is_permutation{true} + { + // Set strides + setStrides(src_tensor, &_src_strides); + setStrides(dst_tensor, &_dst_strides); + + _is_permutation = (_src_layout != _dst_layout && loop_shape.rank() == 4); + } + // Constructor for a copy + PermuteWorkerTask(const uint8_t *src_buffer, uint8_t *dst_buffer, uint32_t src_start_offset, + uint32_t dst_start_offset, size_t size) + : _src_buffer{src_buffer}, _dst_buffer{dst_buffer}, _src_start_offset{src_start_offset}, + _dst_start_offset{dst_start_offset}, _src_strides{0}, _dst_strides{0}, _loop_shape{1}, + _size{size}, _src_layout{}, _dst_layout{}, _is_permutation{false} + { + // DO NOTHING + } + void setBuffers(const uint8_t *src_buffer, uint8_t *dst_buffer) + { + _src_buffer = src_buffer; + _dst_buffer = dst_buffer; + } + void Run() override + { + ShapeLoop(_loop_shape, [&](const onert::ir::Coordinates &coords) { + size_t src_offset = _src_start_offset; + size_t dst_offset = _dst_start_offset; + assert(static_cast<size_t>(_loop_shape.rank()) == coords.size()); + ir::Coordinates dst_coords = coords; + if (_is_permutation) + { + dst_coords = ir::convertCoordinates(coords, _src_layout, _dst_layout); + } + for (auto i = 0; i < _loop_shape.rank(); ++i) + { + assert(coords[i] >= 0 && dst_coords[i] >= 0); + src_offset += coords[i] * _src_strides[i]; + dst_offset += dst_coords[i] * _dst_strides[i]; + } + memcpy(_dst_buffer + dst_offset, _src_buffer + src_offset, _size); + }); + } + + private: + void setStrides(const ITensor &tensor, Strides *strides) + { + auto shape = tensor.getShape(); + const size_t rank = shape.rank(); + for (size_t i = 0; i < rank; ++i) + { + ir::Coordinates no_step(rank), one_step(rank); + one_step.set(i, 1); + if (shape.dim(i) > 1) + { + strides->set(i, tensor.calcOffset(one_step) - tensor.calcOffset(no_step)); + } + else + { + // If dimension value is 0 or 1, the stride of the dimension will be not used + // Do not call calcOffset() 
with coordinate value that is greater than dimension value + strides->set(i, 0); + } + assert((*strides)[i] >= 0); + } + } + + private: + const uint8_t *_src_buffer; + uint8_t *_dst_buffer; + size_t _src_start_offset; + size_t _dst_start_offset; + Strides _src_strides; + Strides _dst_strides; + const ir::Shape _loop_shape; + const size_t _size; + const ir::Layout _src_layout; + const ir::Layout _dst_layout; + bool _is_permutation; + }; + std::unordered_map<const ITensor *, std::vector<PermuteWorkerTask>> _tasks_map; +}; + +} // namespace kernel +} // namespace builtin +} // namespace backend +} // namespace onert + +#endif // __ONERT_BACKEND_BUILTIN_KERNEL_PERMUTELAYER_H__ diff --git a/runtime/onert/core/src/backend/builtin/kernel/WhileLayer.cc b/runtime/onert/core/src/backend/builtin/kernel/WhileLayer.cc new file mode 100644 index 000000000..06e5722c8 --- /dev/null +++ b/runtime/onert/core/src/backend/builtin/kernel/WhileLayer.cc @@ -0,0 +1,151 @@ +/* + * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "WhileLayer.h" + +#include "PermuteLayer.h" +#include "../../../exec/ExecutorBase.h" + +#include <misc/polymorphic_downcast.h> + +#include <algorithm> + +namespace onert +{ +namespace backend +{ +namespace builtin +{ +namespace kernel +{ + +WhileLayer::WhileLayer(const std::vector<backend::IPortableTensor *> input_tensors, + const std::vector<backend::IPortableTensor *> output_tensors, + const ir::SubgraphIndex &cond_subg_index, + const ir::SubgraphIndex &body_subg_index, exec::IExecutors *executors, + const ir::ModelIndex &model_index, + basic::DynamicMemoryManager *dyn_memory_manager, + const std::shared_ptr<ExternalContext> &external_context) + : _cond_subg_index{cond_subg_index}, _body_subg_index{body_subg_index}, + _input_tensors{input_tensors}, _output_tensors{output_tensors}, _executors{executors}, + _model_index{model_index}, _dyn_memory_manager{dyn_memory_manager}, + _external_context{external_context} +{ + // At this point, executors may not have executors of cond subg and body subg +} + +void WhileLayer::run() +{ + // Copy "_input_tensors" -> "cond subg inputs" + // Run cond subg + // Start loop while output of cond subg is ture + // // Copy "_input_tensors" -> "body subg inputs" in the first iteration, then copy "body subg + // outputs" -> "body subg inputs" in the second or more iterations + // // Run body subg + // // Copy "body subg outputs" -> "cond subg inputs" + // // Run cond subg + // If there is no loop copy "_input_tensors" -> "_dst_tensors", else copy "cond subg inputs" -> + // "_dst_tensors" + auto cond_exec = _executors->at(_model_index, _cond_subg_index); + auto body_exec = _executors->at(_model_index, _body_subg_index); + + // Need a temp tensor to hold the cond subgraph output + assert(cond_exec->outputSize() == 1); + auto cond_output_tensor = [&]() { + auto tensor = std::make_unique<Tensor>(cond_exec->outputInfo(0), cond_exec->outputLayout(0), + _dyn_memory_manager); + tensor->set_dynamic(); + 
tensor->setBuffer(_dyn_memory_manager->allocate(tensor.get(), tensor->total_size())); + return tensor; + }(); + + VERBOSE(While) << "Call to $" << _cond_subg_index << " (cond)" << std::endl; + const auto &options = _executors->entryExecutor()->currentOptions(); + cond_exec->execute(_input_tensors, {cond_output_tensor.get()}, options); + VERBOSE(While) << "Return from $" << _cond_subg_index << std::endl; + + auto getResultCond = [](backend::ITensor *tensor) -> bool { + bool ret = false; + tensor->access([&](ITensor &tensor) { ret = *reinterpret_cast<bool *>(tensor.buffer()); }); + return ret; + }; + + std::vector<ITensor *> op_inputs(_input_tensors.begin(), _input_tensors.end()); + std::vector<ITensor *> op_outputs(_output_tensors.begin(), _output_tensors.end()); + // Copying body inputs to outputs when the loop body is never executed + if (!getResultCond(cond_output_tensor.get())) + { + PermuteLayer copy_body_inputs_to_op_outputs{op_inputs, op_outputs, _external_context}; + copy_body_inputs_to_op_outputs.run(); + return; + } + + // Need some temp tensors to hold the body subgraph output + std::vector<std::unique_ptr<Tensor>> temp_outputs_o; + std::vector<IPortableTensor *> temp_outputs; + for (uint32_t i = 0; i < body_exec->outputSize(); i++) + { + auto tensor = std::make_unique<Tensor>(body_exec->outputInfo(i), body_exec->outputLayout(i), + _dyn_memory_manager); + tensor->set_dynamic(); + tensor->setBuffer(_dyn_memory_manager->allocate(tensor.get(), tensor->total_size())); + temp_outputs.push_back(tensor.get()); + temp_outputs_o.push_back(std::move(tensor)); + } + + std::vector<ITensor *> body_outputs(temp_outputs.begin(), temp_outputs.end()); + PermuteLayer copy_body_outputs_to_op_outputs{body_outputs, op_outputs, _external_context}; + + const auto body_execute_with_op_inputs = [&]() { + VERBOSE(While) << "Call to $" << _body_subg_index << " (body)" << std::endl; + body_exec->execute(_input_tensors, temp_outputs, options); + VERBOSE(While) << "Return from $" << 
_body_subg_index << std::endl; + }; + + const auto body_execute_with_body_outputs = [&]() { + VERBOSE(While) << "Call to $" << _body_subg_index << " (body)" << std::endl; + body_exec->execute(_output_tensors, temp_outputs, options); + VERBOSE(While) << "Return from $" << _body_subg_index << std::endl; + }; + + std::function<void()> body_execute = body_execute_with_op_inputs; + const auto cond_execute = [&]() { + VERBOSE(While) << "Call to $" << _cond_subg_index << " (cond)" << std::endl; + cond_exec->execute(_output_tensors, {cond_output_tensor.get()}, options); + VERBOSE(While) << "Return from $" << _cond_subg_index << std::endl; + }; + + // Loop while Cond subgraph's output is true + while (getResultCond(cond_output_tensor.get())) + { + body_execute(); + copy_body_outputs_to_op_outputs.run(); + cond_execute(); + body_execute = body_execute_with_body_outputs; + } + + // Clean-up the temp tensors + _dyn_memory_manager->deallocate(cond_output_tensor.get()); + for (auto &&tensor : temp_outputs) + { + _dyn_memory_manager->deallocate(tensor); + } +} + +} // namespace kernel +} // namespace builtin +} // namespace backend +} // namespace onert diff --git a/runtime/onert/core/src/backend/controlflow/kernel/WhileLayer.h b/runtime/onert/core/src/backend/builtin/kernel/WhileLayer.h index ebca8acdc..40ca4fe23 100644 --- a/runtime/onert/core/src/backend/controlflow/kernel/WhileLayer.h +++ b/runtime/onert/core/src/backend/builtin/kernel/WhileLayer.h @@ -14,20 +14,23 @@ * limitations under the License. 
*/ -#ifndef __ONERT_BACKEND_CONTROLFLOW_KERNEL_WHILE_LAYER_H__ -#define __ONERT_BACKEND_CONTROLFLOW_KERNEL_WHILE_LAYER_H__ +#ifndef __ONERT_BACKEND_BUILTIN_KERNEL_WHILE_LAYER_H__ +#define __ONERT_BACKEND_BUILTIN_KERNEL_WHILE_LAYER_H__ -#include <backend/ITensor.h> -#include <exec/IExecutor.h> +#include <backend/IPortableTensor.h> +#include <exec/IExecutors.h> #include <exec/IFunction.h> #include <ir/OperandIndexSequence.h> #include <ir/Graph.h> +#include "../ExternalContext.h" + +#include "backend/basic/MemoryManager.h" namespace onert { namespace backend { -namespace controlflow +namespace builtin { namespace kernel { @@ -35,12 +38,12 @@ namespace kernel class WhileLayer : public ::onert::exec::IFunction { public: - WhileLayer(const std::vector<std::shared_ptr<backend::ITensor>> &input_tensors, - const std::vector<std::shared_ptr<backend::ITensor>> &output_tensors, - const ir::OperandIndexSequence &output_indices, const ir::Graph &graph, - const exec::DynAllocInfoMap &outputs_dyn_alloc_info, + WhileLayer(const std::vector<backend::IPortableTensor *> input_tensors, + const std::vector<backend::IPortableTensor *> output_tensors, const ir::SubgraphIndex &cond_subg_index, const ir::SubgraphIndex &body_subg_index, - exec::ExecutorMap *executor_map); + exec::IExecutors *executors, const ir::ModelIndex &model_index, + basic::DynamicMemoryManager *dyn_memory_manager, + const std::shared_ptr<ExternalContext> &external_context); public: void run() override; @@ -48,17 +51,17 @@ public: private: const ir::SubgraphIndex _cond_subg_index; const ir::SubgraphIndex _body_subg_index; - const ir::OperandIndexSequence &_output_indices; - const ir::Graph &_graph; - const std::vector<std::shared_ptr<backend::ITensor>> _input_tensors; - const std::vector<std::shared_ptr<backend::ITensor>> _output_tensors; - const exec::DynAllocInfoMap _outputs_dyn_alloc_info; - exec::ExecutorMap *_executor_map; + const std::vector<backend::IPortableTensor *> _input_tensors; + const 
std::vector<backend::IPortableTensor *> _output_tensors; + exec::IExecutors *_executors; + const ir::ModelIndex _model_index; + basic::DynamicMemoryManager *_dyn_memory_manager; // For generating temp tensors + const std::shared_ptr<ExternalContext> _external_context; }; } // namespace kernel -} // namespace controlflow +} // namespace builtin } // namespace backend } // namespace onert -#endif // __ONERT_BACKEND_CONTROLFLOW_KERNEL_WHILE_LAYER_H__ +#endif // __ONERT_BACKEND_BUILTIN_KERNEL_WHILE_LAYER_H__ diff --git a/runtime/onert/core/src/backend/builtin/train/BackendContext.cc b/runtime/onert/core/src/backend/builtin/train/BackendContext.cc new file mode 100644 index 000000000..69483eade --- /dev/null +++ b/runtime/onert/core/src/backend/builtin/train/BackendContext.cc @@ -0,0 +1,78 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "BackendContext.h" + +#include "backend/basic/train/TrainableBackendContextHelpers.h" +#include "exec/FunctionSequence.h" + +namespace onert +{ +namespace backend +{ +namespace builtin +{ +namespace train +{ + +backend::ITensorRegistry *BackendContext::genTensors() +{ + // For now, there is no need to generate tensors for forwarding. + // builtin train backend handles 3 operators: `Permute`, `IF`, `WHILE`. + // `Permute`: Tensor generation is not required. 
+ // `IF`, `WHILE`: Not supported yet + return tensor_registry().get(); +} + +backend::train::ITensorRegistry *BackendContext::genTrainingTensors() +{ + // For now, there is no need to generate tensors for backwarding. + return tensor_registry().get(); +} + +backend::train::FunctionMap BackendContext::genKernels() +{ + backend::train::FunctionMap ret; + + for (auto &&op_ind : _tdata->op_order) + { + auto tn_seq = kernel_gen->generate(op_ind); + ret.emplace(op_ind, std::move(tn_seq)); + } + + trainable_graph()->operands().iterate( + [&](const ir::OperandIndex &ind, const ir::Operand &operand) { + if (!external_operands().contains(ind) && operand.isConstant()) + { + throw std::runtime_error( + "BackendContext: builtin backend does not support updatable weights yet"); + } + }); + + // TODO Enable prepare() + // for (auto &&it : ret) + // { + // auto &fn_seq = it.second; + // fn_seq->iterate([&](exec::IFunction &ifunc) { ifunc.prepare(); }); + // } + + return ret; +} + +} // namespace train +} // namespace builtin +} // namespace backend +} // namespace onert diff --git a/runtime/onert/core/src/backend/builtin/train/BackendContext.h b/runtime/onert/core/src/backend/builtin/train/BackendContext.h new file mode 100644 index 000000000..4782756c3 --- /dev/null +++ b/runtime/onert/core/src/backend/builtin/train/BackendContext.h @@ -0,0 +1,76 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __ONERT_BACKEND_BUILTIN_TRAIN_BACKEND_CONTEXT_H__ +#define __ONERT_BACKEND_BUILTIN_TRAIN_BACKEND_CONTEXT_H__ + +#include <backend/train/TrainableBackendContext.h> + +#include "KernelGenerator.h" +#include "../ExternalContext.h" +#include "../TensorBuilder.h" + +namespace onert +{ +namespace backend +{ +namespace builtin +{ +namespace train +{ + +class BackendContext : public backend::train::TrainableBackendContext +{ +public: + BackendContext(const backend::train::ITrainableBackend *backend, + std::unique_ptr<backend::train::TrainableContextData> &&data, + std::shared_ptr<backend::train::ITensorRegistry> tensor_registry = nullptr, + std::shared_ptr<TensorBuilder> tensor_builder = nullptr, + std::shared_ptr<KernelGenerator> kernel_gen = nullptr) + : backend::train::TrainableBackendContext(backend, std::move(data), tensor_registry), + kernel_gen{kernel_gen}, _external_context(new ExternalContext), + _tensor_builder{tensor_builder} + { + } + + backend::ITensorRegistry *genTensors() override; + backend::train::ITensorRegistry *genTrainingTensors() override; + +public: + backend::train::FunctionMap genKernels() override; + + std::shared_ptr<ExternalContext> external_context() { return _external_context; } + +public: + // TODO Make it private + std::shared_ptr<KernelGenerator> kernel_gen; + +private: + // NOTE ruy context has a thread pool, and when multiple ruy contexts are created, + // the thread pool is also created in duplicate + // TODO Create one ruy context for session + std::shared_ptr<ExternalContext> _external_context; + +private: + std::shared_ptr<TensorBuilder> _tensor_builder; +}; + +} // namespace train +} // namespace builtin +} // namespace backend +} // namespace onert + +#endif // __ONERT_BACKEND_BUILTIN_TRAIN_BACKEND_CONTEXT_H__ diff --git a/runtime/onert/core/src/backend/builtin/train/KernelGenerator.cc 
b/runtime/onert/core/src/backend/builtin/train/KernelGenerator.cc new file mode 100644 index 000000000..32032de4a --- /dev/null +++ b/runtime/onert/core/src/backend/builtin/train/KernelGenerator.cc @@ -0,0 +1,104 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "KernelGenerator.h" + +#include "kernel/PermuteLayer.h" + +namespace onert +{ +namespace backend +{ +namespace builtin +{ +namespace train +{ + +KernelGenerator::KernelGenerator(const ir::train::TrainableGraph &tgraph, + const std::shared_ptr<TensorRegistry> &tensor_reg, + const std::shared_ptr<ExternalContext> &external_context) + : KernelGeneratorBase{tgraph}, _tensor_reg{tensor_reg}, _external_context(external_context) +{ +} + +std::unique_ptr<exec::train::TrainableFnSequence> KernelGenerator::generate(ir::OperationIndex ind) +{ + auto ret = std::make_unique<exec::train::TrainableFnSequence>(); + const auto &op = _tgraph.operation(ind); + op.accept(*this); + // _return_fn must have been generated + if (_return_fn == nullptr) + { + throw std::runtime_error(op.name() + " op does not supported trainable kernel yet"); + } + + ret->_functions.emplace_back(std::move(_return_fn)); + + return ret; +} + +void KernelGenerator::visit(const ir::train::operation::Permute &node) +{ + const auto output_index{node.getOutputs().at(0)}; + const auto input_index{node.getInputs().at(0)}; + + // Add PermuteLayer + 
std::vector<ITensor *> output_tensors{getTensor(output_index)}; + std::vector<ITensor *> input_tensors{getTensor(input_index)}; + + std::vector<ITensor *> output_back_prop_tensors; + std::vector<ITensor *> input_back_prop_tensors; + + auto input_back_prop_tensor = getBackPropTensor(input_index); + auto output_back_prop_tensor = getBackPropTensor(output_index); + output_back_prop_tensors.emplace_back(output_back_prop_tensor); + input_back_prop_tensors.emplace_back(input_back_prop_tensor); + + // NOTE The output buffers of IOTensors are not essential for training. If there + // is no output buffer provided by the user, permute is not performed. + bool ignore_forward_in_training = false; + for (const auto dst_tensor : output_tensors) + { + if (dst_tensor->buffer() == nullptr || dst_tensor->total_size() == 0) + ignore_forward_in_training = true; + } + + auto fn = std::make_unique<kernel::PermuteLayer>( + input_tensors, output_tensors, input_back_prop_tensors, output_back_prop_tensors, + ignore_forward_in_training, _external_context); + + _return_fn = std::move(fn); +} + +backend::ITensor *KernelGenerator::getTensor(const ir::OperandIndex &index) +{ + // Get Tensor from all tensor registries (for Permute op) + auto ret = _tensor_registries.getITensor(index); + assert(ret != nullptr); + return ret; +} + +backend::ITensor *KernelGenerator::getBackPropTensor(const ir::OperandIndex &index) +{ + // Get back propagation Tensor from all tensor registries (for Permute op) + auto ret = _tensor_registries.getBackPropITensor(index); + return ret; +} + +} // namespace train +} // namespace builtin +} // namespace backend +} // namespace onert diff --git a/runtime/onert/core/src/backend/builtin/train/KernelGenerator.h b/runtime/onert/core/src/backend/builtin/train/KernelGenerator.h new file mode 100644 index 000000000..162955b6d --- /dev/null +++ b/runtime/onert/core/src/backend/builtin/train/KernelGenerator.h @@ -0,0 +1,75 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. 
All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __ONERT_BACKEND_BUTIN_TRAIN_KERNEL_GENERATOR_H__ +#define __ONERT_BACKEND_BUTIN_TRAIN_KERNEL_GENERATOR_H__ + +#include "../ExternalContext.h" +#include "../train/TensorRegistry.h" +#include "../../../compiler/train/TensorRegistries.h" + +#include <backend/train/KernelGeneratorBase.h> +#include <exec/train/TrainableFnSequence.h> +#include <ir/train/TrainableGraph.h> + +namespace onert +{ +namespace backend +{ +namespace builtin +{ +namespace train +{ + +class KernelGenerator : public backend::train::KernelGeneratorBase +{ +public: + KernelGenerator(const ir::train::TrainableGraph &tgraph, + const std::shared_ptr<TensorRegistry> &tensor_reg, + const std::shared_ptr<ExternalContext> &external_context); + + std::unique_ptr<exec::train::TrainableFnSequence> generate(ir::OperationIndex ind) override; + + void setTensorRegistries(const compiler::train::TensorRegistries &tensor_registries) + { + _tensor_registries = tensor_registries; + } + + void setWholeGraphOutputs(const ir::OperandIndexSequence &outputs) + { + _whole_graph_outputs = outputs; + } + +private: + void visit(const ir::train::operation::Permute &) override; + +private: + backend::ITensor *getTensor(const ir::OperandIndex &index); + backend::ITensor *getBackPropTensor(const ir::OperandIndex &index); + +private: + std::shared_ptr<TensorRegistry> _tensor_reg; + compiler::train::TensorRegistries _tensor_registries; + 
const std::shared_ptr<ExternalContext> _external_context; + ir::OperandIndexSequence _whole_graph_outputs; +}; + +} // namespace train +} // namespace builtin +} // namespace backend +} // namespace onert + +#endif // __ONERT_BACKEND_BUTIN_TRAIN_KERNEL_GENERATOR_H__ diff --git a/runtime/onert/core/src/backend/builtin/train/Tensor.h b/runtime/onert/core/src/backend/builtin/train/Tensor.h new file mode 100644 index 000000000..baf42796c --- /dev/null +++ b/runtime/onert/core/src/backend/builtin/train/Tensor.h @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef __ONERT_BACKEND_BUILTIN_TRAIN_TRAINABLE_TENSOR_H__ +#define __ONERT_BACKEND_BUILTIN_TRAIN_TRAINABLE_TENSOR_H__ + +#include <backend/basic/train/TrainableTensor.h> + +namespace onert +{ +namespace backend +{ +namespace builtin +{ +namespace train +{ + +using TrainableTensor = basic::train::TrainableTensor; +using BackPropTensor = basic::Tensor; +using GradientTensor = basic::Tensor; + +} // namespace train +} // namespace builtin +} // namespace backend +} // namespace onert + +#endif // __ONERT_BACKEND_BUILTIN_TRAIN_TRAINABLE_TENSOR_H__ diff --git a/runtime/onert/core/src/backend/builtin/train/TensorRegistry.h b/runtime/onert/core/src/backend/builtin/train/TensorRegistry.h new file mode 100644 index 000000000..7c8166bde --- /dev/null +++ b/runtime/onert/core/src/backend/builtin/train/TensorRegistry.h @@ -0,0 +1,140 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef __ONERT_BACKEND_BUILTIN_TRAIN_TENSOR_REGISTRY_H__ +#define __ONERT_BACKEND_BUILTIN_TRAIN_TENSOR_REGISTRY_H__ + +#include <backend/train/ITensorRegistry.h> + +#include "../IOTensor.h" +#include "../Tensor.h" +#include "Tensor.h" + +namespace onert +{ +namespace backend +{ +namespace builtin +{ +namespace train +{ + +using BaseTensorRegistry = + backend::train::PortableTensorRegistryTemplate<Tensor, TrainableTensor, BackPropTensor, + GradientTensor>; + +class TensorRegistry : public backend::train::ITensorRegistry +{ +public: + TensorRegistry() : _base_reg{new BaseTensorRegistry} {} + + ITensor *getITensor(const ir::OperandIndex &index) override + { + auto base_tensor = _base_reg->getITensor(index); + if (base_tensor) + return base_tensor; + return getNativeIOTensor(index); + } + + ITensor *getNativeITensor(const ir::OperandIndex &index) override + { + auto base_tensor = _base_reg->getNativeITensor(index); + if (base_tensor) + return base_tensor; + return getNativeIOTensor(index); + } + + IPortableTensor *getPortableTensor(const ir::OperandIndex &index) + { + auto base_tensor = _base_reg->getPortableTensor(index); + if (base_tensor) + return base_tensor; + return getNativeIOTensor(index); + } + + IOTensor *getNativeIOTensor(const ir::OperandIndex &index) + { + auto tensor = _native_io_tensors.find(index); + if (tensor != _native_io_tensors.end()) + return tensor->second.get(); + return nullptr; + } + + ITensor *getBackPropITensor(const ir::OperandIndex &index) override + { + return _base_reg->getBackPropTensor(index); + } + + ITensor *getGradientITensor(const ir::OperandIndex &index) override + { + return _base_reg->getGradientTensor(index); + } + + BackPropTensor *getBackPropTensor(const ir::OperandIndex &index) + { + return _base_reg->getBackPropTensor(index); + } + + bool setMigrantTensor(const ir::OperandIndex &index, IPortableTensor *tensor) override + { + assert(tensor); + assert(!getITensor(index)); // For the index, tensor is not registered 
yet + _base_reg->setMigrantTensor(index, tensor); + return true; + } + + void iterateTrainableTensors( + const std::function<void(const ir::OperandIndex &, const backend::train::ITrainableTensor *)> &) + const override + { + // DO NOTHING + // Builtin tensor registry does not have trainable tensor. + } + + void setBackPropTensor(const ir::OperandIndex &index, std::unique_ptr<BackPropTensor> tensor) + { + _base_reg->setBackPropTensor(index, std::move(tensor)); + } + + void setGradientTensor(const ir::OperandIndex &index, std::unique_ptr<GradientTensor> tensor) + { + _base_reg->setGradientTensor(index, std::move(tensor)); + } + + void setNativeIOTensor(ir::OperandIndex index, std::unique_ptr<IOTensor> &&tensor) + { + assert(tensor); + assert(!getITensor(index)); // For the index, tensor is not registered yet + _native_io_tensors[index] = std::move(tensor); + } + + const ir::OperandIndexMap<std::unique_ptr<IOTensor>> &native_io_tensors() + { + return _native_io_tensors; + } + std::shared_ptr<BaseTensorRegistry> base_reg() { return _base_reg; } + +private: + std::shared_ptr<BaseTensorRegistry> _base_reg; + ir::OperandIndexMap<std::unique_ptr<IOTensor>> _native_io_tensors; +}; + +} // namespace train +} // namespace builtin +} // namespace backend +} // namespace onert + +#endif // __ONERT_BACKEND_BUILTIN_TRAIN_TENSOR_REGISTRY_H__ diff --git a/runtime/onert/core/src/backend/builtin/train/kernel/PermuteLayer.cc b/runtime/onert/core/src/backend/builtin/train/kernel/PermuteLayer.cc new file mode 100644 index 000000000..dce7482e2 --- /dev/null +++ b/runtime/onert/core/src/backend/builtin/train/kernel/PermuteLayer.cc @@ -0,0 +1,87 @@ + + +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "PermuteLayer.h" + +namespace onert +{ +namespace backend +{ +namespace builtin +{ +namespace train +{ +namespace kernel +{ + +PermuteLayer::PermuteLayer(const std::vector<ITensor *> &src_tensors, + const std::vector<ITensor *> &dst_tensors, + const std::vector<ITensor *> &input_back_prop_tensors, + const std::vector<ITensor *> &output_back_prop_tensors, + bool ignore_forward_in_training, + const std::shared_ptr<ExternalContext> &external_context) + : builtin::kernel::PermuteLayer{src_tensors, dst_tensors, external_context}, + _input_back_prop_tensors{input_back_prop_tensors}, + _output_back_prop_tensors{output_back_prop_tensors}, + _ignore_forward_in_training{ignore_forward_in_training} +{ + assert(input_back_prop_tensors.size() == output_back_prop_tensors.size()); + assert(src_tensors.size() == dst_tensors.size()); +} + +void PermuteLayer::optimize() +{ + builtin::kernel::PermuteLayer::optimize(); + + // TODO Calculate offsets of back propagation tensors if necessary +} + +void PermuteLayer::forward(bool) +{ + if (_ignore_forward_in_training) + return; + + builtin::kernel::PermuteLayer::run(); +} + +void PermuteLayer::backward() +{ + for (uint32_t i = 0; i < _output_back_prop_tensors.size(); ++i) + { + auto src_back_prop = _output_back_prop_tensors.at(i); + auto dst_back_prop = _input_back_prop_tensors.at(i); + + // NOTE The back propagation tensors corresponding to inputs/outputs of model are nullptr + // because permuting those tensors is meaningless + if (src_back_prop && dst_back_prop) + { + const auto rank = 
src_back_prop->getShape().rank(); + auto output_offsets = _dst_tensors_offsets.at(i); + auto input_offsets = _src_tensors_offsets.at(i); + + exec::IPermuteFunction::permute(src_back_prop, dst_back_prop, rank, output_offsets, + input_offsets); + } + } +} + +} // namespace kernel +} // namespace train +} // namespace builtin +} // namespace backend +} // namespace onert diff --git a/runtime/onert/core/src/backend/builtin/train/kernel/PermuteLayer.h b/runtime/onert/core/src/backend/builtin/train/kernel/PermuteLayer.h new file mode 100644 index 000000000..1dc221b09 --- /dev/null +++ b/runtime/onert/core/src/backend/builtin/train/kernel/PermuteLayer.h @@ -0,0 +1,61 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef __ONERT_BACKEND_BUILTIN_TRAIN_KERNEL_PERMUTELAYER_H__ +#define __ONERT_BACKEND_BUILTIN_TRAIN_KERNEL_PERMUTELAYER_H__ + +#include "../../kernel/PermuteLayer.h" + +#include "exec/train/ITrainableFunction.h" + +namespace onert +{ +namespace backend +{ +namespace builtin +{ +namespace train +{ +namespace kernel +{ + +class PermuteLayer : public builtin::kernel::PermuteLayer, public exec::train::ITrainableFunction +{ +public: + PermuteLayer(const std::vector<ITensor *> &src_tensors, const std::vector<ITensor *> &dst_tensors, + const std::vector<ITensor *> &input_back_prop_tensors, + const std::vector<ITensor *> &output_back_prop_tensors, + bool ignore_forward_in_training, + const std::shared_ptr<ExternalContext> &external_context); + + void optimize() override; + + void forward(bool training) override; + void backward() override; + +private: + std::vector<ITensor *> _input_back_prop_tensors; + std::vector<ITensor *> _output_back_prop_tensors; + bool _ignore_forward_in_training; +}; + +} // namespace kernel +} // namespace train +} // namespace builtin +} // namespace backend +} // namespace onert + +#endif // __ONERT_BACKEND_BUILTIN_TRAIN_KERNEL_PERMUTELAYER_H__ diff --git a/runtime/onert/core/src/backend/controlflow/ConstantInitializer.h b/runtime/onert/core/src/backend/controlflow/ConstantInitializer.h deleted file mode 100644 index e21a8f357..000000000 --- a/runtime/onert/core/src/backend/controlflow/ConstantInitializer.h +++ /dev/null @@ -1,52 +0,0 @@ -/* - * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef __ONERT_COMPILER_CONTROLFLOW_CONSTANT_INITIALIZER_H__ -#define __ONERT_COMPILER_CONTROLFLOW_CONSTANT_INITIALIZER_H__ - -#include "TensorRegistry.h" - -#include <backend/IConstantInitializer.h> -#include <ir/Operands.h> - -namespace onert -{ -namespace backend -{ -namespace controlflow -{ - -class ConstantInitializer : public IConstantInitializer -{ -public: - ConstantInitializer(const ir::Operands &operands, - const std::shared_ptr<ITensorRegistry> &tensor_reg) - : IConstantInitializer{operands}, _tensor_reg{tensor_reg} - { - } - -private: - std::shared_ptr<ITensorRegistry> tensor_registry() const override { return _tensor_reg; } - -private: - std::shared_ptr<ITensorRegistry> _tensor_reg; -}; - -} // namespace controlflow -} // namespace backend -} // namespace onert - -#endif // __ONERT_COMPILER_CONTROLFLOW_CONSTANT_INITIALIZER_H__ diff --git a/runtime/onert/core/src/backend/controlflow/DynamicTensorManager.cc b/runtime/onert/core/src/backend/controlflow/DynamicTensorManager.cc deleted file mode 100644 index 1288e4c96..000000000 --- a/runtime/onert/core/src/backend/controlflow/DynamicTensorManager.cc +++ /dev/null @@ -1,144 +0,0 @@ -/* - * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "DynamicTensorManager.h" - -#include "util/logging.h" -#include "util/Exceptions.h" -#include "ir/DataType.h" - -namespace onert -{ -namespace backend -{ -namespace controlflow -{ - -DynamicTensorManager::DynamicTensorManager(const std::shared_ptr<TensorRegistry> &tensors) - : _dynamic_mem_mgr{new cpu_common::DynamicMemoryManager()}, _tensors{tensors} -{ - // DO NOTHING -} - -void DynamicTensorManager::applyShape(const ir::OperandIndex &ind, const ir::Shape &new_shape) -{ - // NOTE Handle user tensors first - auto user_tensor = _tensors->getNativeUserTensor(ind); - if (user_tensor) - { - // User tensors cannot be reallocated. 
- auto buffer_size = user_tensor->total_size(); - auto new_size = new_shape.num_elements() * sizeOfDataType(user_tensor->data_type()); - if (buffer_size < new_size) - throw InsufficientBufferSizeException{"Output buffer size is less than output tensor size"}; - user_tensor->setShape(new_shape); - return; - } - - // NOTE Then handle own tensors - auto tensor = _tensors->getNativeOwnTensor(ind); - assert(tensor); - - bool previously_dynamic = tensor->is_dynamic(); - - auto allocTensorMem = [&](bool overwrite = false) { - auto capacity = tensor->total_size(); - auto alloc = _dynamic_mem_mgr->allocate(ind, capacity); - - if (overwrite) - tensor->overwriteBuffer(alloc); - else - tensor->setBuffer(alloc); - }; - - if (!previously_dynamic) - { - // TODO deallocate tensor->buffer() - // issue is that staticTensorManager might have allocate this memory - tensor->setShape(new_shape); - tensor->set_dynamic(); - allocTensorMem(true); - } - else if (tensor->buffer() == nullptr) - { - tensor->setShape(new_shape); - tensor->set_dynamic(); - allocTensorMem(); - } - // when buffer was already allocated and new_shape requires different size - else - { - auto previous_size = tensor->total_size(); - auto new_size = new_shape.num_elements() * sizeOfDataType(tensor->data_type()); - if (previous_size != new_size) - { - _dynamic_mem_mgr->deallocate(ind); - - tensor->setShape(new_shape); - tensor->set_dynamic(); - allocTensorMem(true); - } - else - { // when buffer with same size was already allocated, shape could differ - tensor->setShape(new_shape); - } - } -} - -void DynamicTensorManager::buildTensor(const ir::OperandIndex &ind, - const ir::OperandInfo &tensor_info, - ir::Layout backend_layout) -{ - auto tensor = std::make_shared<cpu_common::Tensor>(tensor_info, backend_layout, this); - _tensors->setNativeOwnTensor(ind, tensor); -} - -void DynamicTensorManager::planDealloc(ir::OperationIndex op_ind, ir::OperandIndex operand_ind) -{ - _dealloc_tensor_map[op_ind].emplace(operand_ind); -} 
- -void DynamicTensorManager::deallocInput(ir::OperationIndex op_ind) -{ - auto find = _dealloc_tensor_map.find(op_ind); - if (find == _dealloc_tensor_map.end()) - return; - - auto &input_set = find->second; - for (auto input_ind : input_set) - { - if (!_tensors->getNativeTensor(input_ind)->is_dynamic()) - continue; - - _dynamic_mem_mgr->deallocate(input_ind); - VERBOSE(DynamicTensorManager) << "Deallocating #" << input_ind.value() - << " (input of op_ind: " << op_ind.value() << ")" << std::endl; - } -} - -void DynamicTensorManager::deallocSubgraphOutput(ir::OperandIndex output_ind) -{ - if (!_tensors->getNativeTensor(output_ind)->is_dynamic()) - return; - - _dynamic_mem_mgr->deallocate(output_ind); - VERBOSE(DynamicTensorManager) << "Deallocating #" << output_ind.value() - << " (output of a subgraph)" << std::endl; -} - -} // namespace controlflow -} // namespace backend -} // namespace onert diff --git a/runtime/onert/core/src/backend/controlflow/DynamicTensorManager.h b/runtime/onert/core/src/backend/controlflow/DynamicTensorManager.h deleted file mode 100644 index dbe388ba2..000000000 --- a/runtime/onert/core/src/backend/controlflow/DynamicTensorManager.h +++ /dev/null @@ -1,72 +0,0 @@ -/* - * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef __ONERT_BACKEND_CONTROLFLOW_DYNAMICTENSOR_MANAGER_H__ -#define __ONERT_BACKEND_CONTROLFLOW_DYNAMICTENSOR_MANAGER_H__ - -#include "TensorRegistry.h" -#include "Tensor.h" - -#include <backend/IDynamicTensorManager.h> -#include <backend/cpu_common/MemoryManager.h> -#include <ir/OperandInfo.h> -#include <ir/Operation.h> -#include <ir/Index.h> - -namespace onert -{ -namespace backend -{ -namespace controlflow -{ - -/** - * @brief Class to manage dynamic tensor and its memory - */ -class DynamicTensorManager : public backend::IDynamicTensorManager -{ -public: - DynamicTensorManager(const std::shared_ptr<TensorRegistry> &tensors); - - virtual ~DynamicTensorManager() = default; - - void applyShape(const ir::OperandIndex &ind, const ir::Shape &new_shape) override; - - void buildTensor(const ir::OperandIndex &ind, const ir::OperandInfo &tensor_info, - ir::Layout backend_layout); - - void planDealloc(ir::OperationIndex op_ind, ir::OperandIndex operand_ind) override; - void deallocInput(ir::OperationIndex op_ind) override; - void deallocSubgraphOutput(ir::OperandIndex ind) override; - -private: - /** - * @brief Memory manager for dynamic tensor. - * @todo DynamicMemoryManager is not optimized. Optimized one is needed - */ - std::shared_ptr<cpu_common::DynamicMemoryManager> _dynamic_mem_mgr; - const std::shared_ptr<TensorRegistry> _tensors; - - // contains list of dynamic tensor index, which can be deallocated after running operation - // note: this map could contain static tensor index too. Careful use is required. 
- std::unordered_map<ir::OperationIndex, std::unordered_set<ir::OperandIndex>> _dealloc_tensor_map; -}; - -} // namespace controlflow -} // namespace backend -} // namespace onert - -#endif // __ONERT_BACKEND_CONTROLFLOW_DYNAMICTENSOR_MANAGER_H__ diff --git a/runtime/onert/core/src/backend/controlflow/KernelGenerator.cc b/runtime/onert/core/src/backend/controlflow/KernelGenerator.cc deleted file mode 100644 index de5a6a5f6..000000000 --- a/runtime/onert/core/src/backend/controlflow/KernelGenerator.cc +++ /dev/null @@ -1,171 +0,0 @@ -/* - * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "KernelGenerator.h" - -#include <backend/BackendContext.h> -#include <util/Utils.h> -#include "kernel/IfLayer.h" -#include "kernel/WhileLayer.h" -#include "kernel/PermuteLayer.h" -#include "exec/ExecutorBase.h" -#include "exec/FunctionSequence.h" - -namespace onert -{ -namespace backend -{ -namespace controlflow -{ - -KernelGenerator::KernelGenerator(const ir::Graph &graph, IDynamicTensorManager *dyn_tensor_manager, - const std::shared_ptr<TensorRegistry> &tensor_reg) - : _graph{graph}, _dyn_tensor_manager{dyn_tensor_manager}, _tensor_reg{tensor_reg}, - _tensor_registries{}, _executor_map{nullptr} -{ - UNUSED_RELEASE(_graph); - UNUSED_RELEASE(_tensor_registries); - UNUSED_RELEASE(_executor_map); -} - -void KernelGenerator::visit(const ir::OpSequence &op_seq) -{ - assert(!_return_fn_seq); - assert(_dyn_tensor_manager); - assert(_tensor_reg); - - auto dyn_shape_inferer = - std::make_unique<exec::DynamicShapeInferer>(_graph.operands(), _tensor_reg); - - _return_fn_seq = std::make_unique<exec::FunctionSequence>(); - - // Prepare to handle dynamic tensors later - auto dyn_ctx = std::make_shared<exec::FunctionSequence::DynamicTensorCtx>(); - { - dyn_ctx->op_seq = &op_seq; - dyn_ctx->operations = &_graph.operations(); - dyn_ctx->dynamic_shape_inferer = std::move(dyn_shape_inferer); - dyn_ctx->tensor_registry = _tensor_reg; - dyn_ctx->dynamic_tensor_manager = _dyn_tensor_manager; - - _return_fn_seq->dynamic_tensor_ctx(dyn_ctx); - } - _return_fn_seq->enableDynamicShapeInferer(true); - - for (const auto &op_idx : op_seq.operations()) - { - const auto &node = _graph.operations().at(op_idx); - node.accept(*this); - _return_fn_seq->append(releaseFunction()); - } -} - -void KernelGenerator::visit(const ir::operation::If &node) -{ - const auto then_subg_index = node.param().then_subg_index; - const auto else_subg_index = node.param().else_subg_index; - - std::vector<std::shared_ptr<backend::ITensor>> input_tensors; - for (const auto input_index : 
node.getInputs()) - { - auto input_tensor = getTensor(input_index); - - input_tensors.emplace_back(input_tensor); - } - - std::vector<std::shared_ptr<backend::ITensor>> output_tensors; - exec::DynAllocInfoMap outputs_dyn_alloc_info; - for (const auto output_index : node.getOutputs()) - { - auto output_tensor = getTensor(output_index); - - output_tensors.emplace_back(output_tensor); - outputs_dyn_alloc_info[output_tensor] = exec::DynAllocInfo{output_index}; - } - - // IfLayer just set ExecutorMap instead of then and else executor to avoid complexity of - // creating executor recusively - const auto cond_tensor = input_tensors.front(); - input_tensors.erase(input_tensors.begin()); - auto fn = std::make_unique<::onert::backend::controlflow::kernel::IfLayer>( - cond_tensor, input_tensors, output_tensors, node.getOutputs(), _graph, outputs_dyn_alloc_info, - then_subg_index, else_subg_index, _executor_map); - - _return_fn = std::move(fn); -} - -void KernelGenerator::visit(const ir::operation::Permute &node) -{ - const auto output_index{node.getOutputs().at(0)}; - const auto input_index{node.getInputs().at(0)}; - - // Add PermuteLayer - std::vector<std::shared_ptr<ITensor>> output_tensors{getTensor(output_index)}; - std::vector<std::shared_ptr<ITensor>> input_tensors{getTensor(input_index)}; - std::unordered_map<std::shared_ptr<ITensor>, exec::DynAllocInfo> outputs_dyn_alloc_info; - outputs_dyn_alloc_info[output_tensors.at(0)] = exec::DynAllocInfo{output_index}; - - auto fn = - std::make_unique<kernel::PermuteLayer>(input_tensors, output_tensors, outputs_dyn_alloc_info); - - _return_fn = std::move(fn); -} - -void KernelGenerator::visit(const ir::operation::While &node) -{ - const auto cond_subg_index = node.param().cond_subg_index; - const auto body_subg_index = node.param().body_subg_index; - - // This op does not support input as a constant, because controlflow backend does not have - // TensorBuilder - std::vector<std::shared_ptr<backend::ITensor>> input_tensors; - for 
(const auto input_index : node.getInputs()) - { - auto input_tensor = getTensor(input_index); - - input_tensors.emplace_back(input_tensor); - } - - std::vector<std::shared_ptr<backend::ITensor>> output_tensors; - std::unordered_map<std::shared_ptr<ITensor>, exec::DynAllocInfo> outputs_dyn_alloc_info; - for (const auto output_index : node.getOutputs()) - { - auto output_tensor = getTensor(output_index); - - output_tensors.emplace_back(output_tensor); - - outputs_dyn_alloc_info[output_tensor] = exec::DynAllocInfo{output_index}; - } - - // WhileLayer just set ExecutorMap instead of cond and body executor to avoid complexity of - // creating executor recusively - auto fn = std::make_unique<::onert::backend::controlflow::kernel::WhileLayer>( - input_tensors, output_tensors, node.getOutputs(), _graph, outputs_dyn_alloc_info, - cond_subg_index, body_subg_index, _executor_map); - - _return_fn = std::move(fn); -} - -std::shared_ptr<backend::ITensor> KernelGenerator::getTensor(const ir::OperandIndex &index) -{ - std::shared_ptr<backend::ITensor> ret = _tensor_registries.getITensor(index); - assert(ret != nullptr); - return ret; -} - -} // namespace controlflow -} // namespace backend -} // namespace onert diff --git a/runtime/onert/core/src/backend/controlflow/TensorRegistry.h b/runtime/onert/core/src/backend/controlflow/TensorRegistry.h deleted file mode 100644 index 678c5b73b..000000000 --- a/runtime/onert/core/src/backend/controlflow/TensorRegistry.h +++ /dev/null @@ -1,134 +0,0 @@ -/* - * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef __ONERT_BACKEND_CONTROLFLOW_TENSOR_REGISTRY_H__ -#define __ONERT_BACKEND_CONTROLFLOW_TENSOR_REGISTRY_H__ - -#include "backend/cpu_common/TensorRegistry.h" -#include "backend/ITensorRegistry.h" -#include "Tensor.h" -#include "UserTensor.h" -#include <assert.h> - -namespace onert -{ -namespace backend -{ -namespace controlflow -{ - -/** - * @brief Tensor registry class for controlflow backend - * - * This class contains three types of tensors. Two native tensors(tensors that are managed by this - * backend) and the other is migrant tensor. - * - * - NativeUserTensor - @c UserTensor managed by this backend, buffer is user-given - * - NativeOwnTensor - @c cpu_common::Tensor managed by this backend ( in @c _base_reg ) - * - MigrantTensor - @c IPortableTensor managed by other backends ( in @c _base_reg ) - * - * @note @c _base_reg is used in implementation to reuse @c cpu_common::StaticTensorManager - * - */ -class TensorRegistry : public ITensorRegistry -{ -public: - TensorRegistry() : _base_reg{new cpu_common::TensorRegistry} {} - - std::shared_ptr<ITensor> getITensor(const ir::OperandIndex &ind) override - { - auto base_tensor = _base_reg->getITensor(ind); - if (base_tensor) - return base_tensor; - return getNativeUserTensor(ind); - } - - std::shared_ptr<ITensor> getNativeITensor(const ir::OperandIndex &ind) override - { - auto base_tensor = _base_reg->getNativeITensor(ind); - if (base_tensor) - return base_tensor; - return getNativeUserTensor(ind); - } - - std::shared_ptr<IPortableTensor> getPortableTensor(const ir::OperandIndex &ind) 
- { - auto base_tensor = _base_reg->getPortableTensor(ind); - if (base_tensor) - return base_tensor; - return getNativeUserTensor(ind); - } - - std::shared_ptr<IPortableTensor> getNativeTensor(const ir::OperandIndex &ind) - { - auto base_tensor = _base_reg->getNativeTensor(ind); - if (base_tensor) - return base_tensor; - return getNativeUserTensor(ind); - } - - std::shared_ptr<Tensor> getNativeOwnTensor(const ir::OperandIndex &ind) - { - return _base_reg->getNativeTensor(ind); - } - - std::shared_ptr<UserTensor> getNativeUserTensor(const ir::OperandIndex &ind) - { - auto tensor = _native_user_tensors.find(ind); - if (tensor != _native_user_tensors.end()) - return tensor->second; - return nullptr; - } - - bool setMigrantTensor(const ir::OperandIndex &ind, - const std::shared_ptr<IPortableTensor> &tensor) override - { - assert(tensor); - assert(!getITensor(ind)); // For the ind, tensor is not registered yet - _base_reg->setMigrantTensor(ind, tensor); - return true; - } - - void setNativeOwnTensor(ir::OperandIndex ind, const std::shared_ptr<Tensor> &tensor) - { - assert(tensor); - assert(!getITensor(ind)); // For the ind, tensor is not registered yet - _base_reg->setNativeTensor(ind, tensor); - } - - void setNativeUserTensor(ir::OperandIndex ind, const std::shared_ptr<UserTensor> &tensor) - { - assert(tensor); - assert(!getITensor(ind)); // For the ind, tensor is not registered yet - _native_user_tensors[ind] = tensor; - } - - const ir::OperandIndexMap<std::shared_ptr<UserTensor>> &native_user_tensors() - { - return _native_user_tensors; - } - std::shared_ptr<cpu_common::TensorRegistry> base_reg() { return _base_reg; } - -private: - std::shared_ptr<cpu_common::TensorRegistry> _base_reg; - ir::OperandIndexMap<std::shared_ptr<UserTensor>> _native_user_tensors; -}; - -} // namespace controlflow -} // namespace backend -} // namespace onert - -#endif // ifndef __ONERT_BACKEND_CONTROLFLOW_TENSOR_REGISTRY_H__ diff --git 
a/runtime/onert/core/src/backend/controlflow/UserTensor.h b/runtime/onert/core/src/backend/controlflow/UserTensor.h deleted file mode 100644 index 9be33595d..000000000 --- a/runtime/onert/core/src/backend/controlflow/UserTensor.h +++ /dev/null @@ -1,91 +0,0 @@ -/* - * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef __ONERT_BACKEND_CONTROLFLOW_USER_TENSOR_H__ -#define __ONERT_BACKEND_CONTROLFLOW_USER_TENSOR_H__ - -#include "ir/OperandInfo.h" -#include "backend/IPortableTensor.h" - -namespace onert -{ -namespace backend -{ -namespace controlflow -{ - -/** - * @brief Tensor object that is for Input and Output tensors from the user. - * - * This class is a wrapped buffer that is allocated by the user. So it does not have resposibility - * on allocation nor deallocation. All the model input/output tensors are wrapped with this class - * for execution. 
- * - */ -class UserTensor : public IPortableTensor -{ -public: - UserTensor(const ir::OperandInfo &info, ir::Layout layout, uint8_t *buffer, size_t size, - IDynamicTensorManager *dynamic_tensor_manager) - : _info{info}, _layout{layout}, _buffer{buffer}, _size{size}, _dynamic{false}, - _dynamic_tensor_manager{dynamic_tensor_manager} - { - } - - UserTensor(const ir::OperandInfo &info, ir::Layout layout, - IDynamicTensorManager *dynamic_tensor_manager) - : UserTensor{info, layout, nullptr, 0, dynamic_tensor_manager} - { - } - -public: - void setBuffer(uint8_t *buffer, size_t size) - { - _buffer = buffer; - _size = size; - } - -public: - uint8_t *buffer() const override { return _buffer; } - size_t total_size() const override { return _size; } - size_t dimension(size_t index) const override { return _info.shape().dim(index); } - size_t num_dimensions() const override { return _info.shape().rank(); } - size_t calcOffset(const ir::Coordinates &coords) const override; - ir::Layout layout() const override { return _layout; } - ir::DataType data_type() const override { return _info.typeInfo().type(); } - float data_scale() const override { return _info.typeInfo().scale(); } - int32_t data_offset() const override { return _info.typeInfo().offset(); } - bool is_dynamic() const override { return _dynamic; } - void set_dynamic() override { _dynamic = true; } - ir::Shape getShape() const override { return _info.shape(); } - void setShape(const ir::Shape &new_shape) override { _info.shape(new_shape); } - bool is_constant() const override { return false; } - IDynamicTensorManager *dynamic_tensor_manager() override { return _dynamic_tensor_manager; } - -private: - ir::OperandInfo _info; - ir::Layout _layout; - uint8_t *_buffer; - size_t _size; - bool _dynamic; - IDynamicTensorManager *_dynamic_tensor_manager; -}; - -} // namespace controlflow -} // namespace backend -} // namespace onert - -#endif // __ONERT_BACKEND_CONTROLFLOW_USER_TENSOR_H__ diff --git 
a/runtime/onert/core/src/backend/controlflow/kernel/IfLayer.cc b/runtime/onert/core/src/backend/controlflow/kernel/IfLayer.cc deleted file mode 100644 index 8377c7183..000000000 --- a/runtime/onert/core/src/backend/controlflow/kernel/IfLayer.cc +++ /dev/null @@ -1,128 +0,0 @@ -/* - * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "IfLayer.h" - -#include <backend/ITensor.h> -#include "exec/ExecutorBase.h" -#include <misc/polymorphic_downcast.h> -#include "PermuteLayer.h" - -namespace onert -{ -namespace backend -{ -namespace controlflow -{ -namespace kernel -{ - -IfLayer::IfLayer(const std::shared_ptr<backend::ITensor> &cond_tensor, - const std::vector<std::shared_ptr<backend::ITensor>> input_tensors, - const std::vector<std::shared_ptr<backend::ITensor>> output_tensors, - const ir::OperandIndexSequence &output_indices, const ir::Graph &graph, - const exec::DynAllocInfoMap &outputs_dyn_alloc_info, - const ir::SubgraphIndex &then_subg_index, const ir::SubgraphIndex &else_subg_index, - exec::ExecutorMap *executor_map) - : _cond_tensor{cond_tensor}, _input_tensors{input_tensors}, _output_tensors{output_tensors}, - _output_indices{output_indices}, _graph{graph}, - _outputs_dyn_alloc_info{outputs_dyn_alloc_info}, _then_subg_index{then_subg_index}, - _else_subg_index{else_subg_index}, _executor_map{executor_map} -{ - // At this point, executor_map may not have executors of then 
subg and else subg -} - -void IfLayer::run() -{ - // Check condition - // // If true - // // // Copy _input_tensors -> then subg's inputs - // // // Run then subg - // // // Copy outputs of then subg -> _output_tensors - // // Else - // // // Copy _input_tensors -> else subg's inputs if false - // // // Run else subg - // // // Copy outputs of else subg -> _output_tensors - auto getResultCond = [](backend::ITensor *tensor) -> bool { - bool ret = false; - tensor->access([&](ITensor &tensor) { ret = *reinterpret_cast<bool *>(tensor.buffer()); }); - return ret; - }; - - exec::ExecutorBase *subg_exec = nullptr; - if (getResultCond(_cond_tensor.get())) - { - subg_exec = nnfw::misc::polymorphic_downcast<exec::ExecutorBase *>( - _executor_map->at(_then_subg_index).get()); - } - else - { - subg_exec = nnfw::misc::polymorphic_downcast<exec::ExecutorBase *>( - _executor_map->at(_else_subg_index).get()); - } - - const auto &subg_graph = subg_exec->graph(); - - std::vector<std::shared_ptr<backend::ITensor>> src_tensors; - std::vector<std::shared_ptr<backend::ITensor>> dst_tensors; - // Add tensors used in subgraph or contained in outputs of subgraph - assert(subg_graph.getInputs().size() == _input_tensors.size()); - assert(subg_graph.getInputs().size() == subg_exec->getInputTensors().size()); - for (uint32_t i = 0; i < subg_graph.getInputs().size(); ++i) - { - const auto &subg_input_index = subg_graph.getInputs().at(i); - const auto &subg_input = subg_graph.operands().at(subg_input_index); - if (subg_input.getUses().size() > 0 || subg_graph.getOutputs().contains(subg_input_index)) - { - src_tensors.emplace_back(_input_tensors.at(i)); - dst_tensors.emplace_back(subg_exec->getInputTensors().at(i)); - } - } - const auto &subg_inputs_dyn_alloc_info = subg_exec->getInputsDynamicAllocInfo(); - const auto permute_op_input_to_subg_input = - std::make_shared<PermuteLayer>(src_tensors, dst_tensors, subg_inputs_dyn_alloc_info); - - // Add tensors used as output of operation or contained 
in outputs of operation - src_tensors.clear(); - dst_tensors.clear(); - assert(_output_indices.size() == subg_exec->getOutputTensors().size()); - assert(_output_indices.size() == _output_tensors.size()); - for (uint32_t i = 0; i < _output_indices.size(); ++i) - { - const auto &output_index = _output_indices.at(i); - const auto &output = _graph.operands().at(output_index); - if (output.getUses().size() > 0 || _graph.getOutputs().contains(output_index)) - { - src_tensors.emplace_back(subg_exec->getOutputTensors().at(i)); - dst_tensors.emplace_back(_output_tensors.at(i)); - } - } - const auto permute_subg_output_to_op_output = - std::make_shared<PermuteLayer>(src_tensors, dst_tensors, _outputs_dyn_alloc_info); - - // Remove copying of unused tensor - permute_op_input_to_subg_input->prepare(); - permute_subg_output_to_op_output->prepare(); - - // Copy & run - subg_exec->execute(_input_tensors, permute_op_input_to_subg_input); - permute_subg_output_to_op_output->run(); -} - -} // namespace kernel -} // namespace controlflow -} // namespace backend -} // namespace onert diff --git a/runtime/onert/core/src/backend/controlflow/kernel/PermuteLayer.cc b/runtime/onert/core/src/backend/controlflow/kernel/PermuteLayer.cc deleted file mode 100644 index e8f1ea679..000000000 --- a/runtime/onert/core/src/backend/controlflow/kernel/PermuteLayer.cc +++ /dev/null @@ -1,82 +0,0 @@ -/* - * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "PermuteLayer.h" - -#include "exec/ShapeConverter.h" - -namespace onert -{ -namespace backend -{ -namespace controlflow -{ -namespace kernel -{ - -void PermuteLayer::run() -{ - assert(_src_tensors.size() == _dst_tensors.size()); - // PermuteLayer infers dynamic shape inside itself whenever run is called for the following - // reasons: - // 1. PermuteLayer has to access dynamic tensor manager for input/output tensors of other backends - // 2. Other controlflow operation(If/While) uses this layout for copying tensors of other - // subgraphs(with other backends) - // 3. This infering code is placed here to avoid duplicated code that can be caused by above 2 - // reasons - - // check if output is not dynamic - for (size_t i = 0; i < _src_tensors.size(); ++i) - { - auto dst_tensor = _dst_tensors.at(i); - auto src_tensor = _src_tensors.at(i); - if (src_tensor->is_dynamic() || dst_tensor->is_dynamic()) - { - // getting output shape - auto src_shape = src_tensor->getShape(); - - // set output shape and output buffer - ir::Shape new_shape = - exec::convertShape(src_shape, src_tensor->layout(), dst_tensor->layout()); - - try - { - const auto dst_index = _dst_dyn_alloc_info_map.at(dst_tensor).ind; - auto dyn_tensor_manager = dst_tensor->dynamic_tensor_manager(); - if (!dyn_tensor_manager) - throw std::runtime_error{ - "Error: PermuteLayer: output's TensorManager does not support dynamic tensor"}; - dyn_tensor_manager->applyShape(dst_index, new_shape); - assert(dst_tensor->buffer() != nullptr); - } - catch (const std::out_of_range &e) - { - std::cerr << "Error: out_of_range in PermuteLayer: output's TensorManager does not support " - "dynamic tensor" - << '\n'; - throw; - } - } - assert(exec::convertShape(src_tensor->getShape(), src_tensor->layout(), dst_tensor->layout()) == - dst_tensor->getShape()); - } - IPermuteFunction::run(); -} - -} // namespace 
kernel -} // namespace controlflow -} // namespace backend -} // namespace onert diff --git a/runtime/onert/core/src/backend/controlflow/kernel/PermuteLayer.h b/runtime/onert/core/src/backend/controlflow/kernel/PermuteLayer.h deleted file mode 100644 index 403ac770d..000000000 --- a/runtime/onert/core/src/backend/controlflow/kernel/PermuteLayer.h +++ /dev/null @@ -1,77 +0,0 @@ -/* - * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef __ONERT_BACKEND_CONTROLFLOW_KERNEL_PERMUTELAYER_H__ -#define __ONERT_BACKEND_CONTROLFLOW_KERNEL_PERMUTELAYER_H__ - -#include "backend/ITensorBuilder.h" -#include "exec/IPermuteFunction.h" -#include "exec/IExecutor.h" - -namespace onert -{ -namespace backend -{ -namespace controlflow -{ -namespace kernel -{ - -class PermuteLayer : public onert::exec::IPermuteFunction -{ -public: - PermuteLayer(const std::vector<std::shared_ptr<ITensor>> &src_tensors, - const std::vector<std::shared_ptr<ITensor>> &dst_tensors, - const exec::DynAllocInfoMap &dst_dyn_alloc_info_map) - : _dst_dyn_alloc_info_map{dst_dyn_alloc_info_map} - { - assert(src_tensors.size() == dst_tensors.size()); - _src_tensors = src_tensors; - _dst_tensors = dst_tensors; - } - - void optimize() override - { - // Remove copying of tensor as nullptr - auto src_it = _src_tensors.begin(); - auto dst_it = _dst_tensors.begin(); - while (src_it != _src_tensors.end()) - { - if ((*src_it == *dst_it) || (*src_it == nullptr || *dst_it == nullptr)) - { - src_it = _src_tensors.erase(src_it); - dst_it = _dst_tensors.erase(dst_it); - } - else - { - ++src_it; - ++dst_it; - } - } - } - - void run() override; - -private: - const exec::DynAllocInfoMap _dst_dyn_alloc_info_map; -}; - -} // namespace kernel -} // namespace controlflow -} // namespace backend -} // namespace onert - -#endif // __ONERT_BACKEND_CONTROLFLOW_KERNEL_PERMUTELAYER_H__ diff --git a/runtime/onert/core/src/backend/controlflow/kernel/WhileLayer.cc b/runtime/onert/core/src/backend/controlflow/kernel/WhileLayer.cc deleted file mode 100644 index 50936e5f6..000000000 --- a/runtime/onert/core/src/backend/controlflow/kernel/WhileLayer.cc +++ /dev/null @@ -1,216 +0,0 @@ -/* - * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "WhileLayer.h" - -#include <backend/ITensor.h> -#include "exec/ExecutorBase.h" -#include <misc/polymorphic_downcast.h> -#include "PermuteLayer.h" - -namespace onert -{ -namespace backend -{ -namespace controlflow -{ -namespace kernel -{ - -WhileLayer::WhileLayer(const std::vector<std::shared_ptr<backend::ITensor>> &input_tensors, - const std::vector<std::shared_ptr<backend::ITensor>> &output_tensors, - const ir::OperandIndexSequence &output_indices, const ir::Graph &graph, - const exec::DynAllocInfoMap &outputs_dyn_alloc_info, - const ir::SubgraphIndex &cond_subg_index, - const ir::SubgraphIndex &body_subg_index, exec::ExecutorMap *executor_map) - : _cond_subg_index{cond_subg_index}, _body_subg_index{body_subg_index}, - _output_indices{output_indices}, _graph{graph}, _input_tensors{input_tensors}, - _output_tensors{output_tensors}, _outputs_dyn_alloc_info{outputs_dyn_alloc_info}, - _executor_map{executor_map} -{ - // At this point, executor_map may not have executors of cond subg and body subg -} - -void WhileLayer::run() -{ - // Copy "_input_tensors" -> "cond subg inputs" - // Run cond subg - // Start loop while output of cond subg is ture - // // Copy "_input_tensors" -> "body subg inputs" in the first iteration, then copy "body subg - // outputs" -> "body subg inputs" in the second or more iterations - // // Run body subg - // // Copy "body subg outputs" -> "cond subg inputs" - // // Run cond subg - // If there is no loop copy "_input_tensors" -> "_dst_tensors", else copy "cond subg inputs" -> - // "_dst_tensors" - auto 
cond_exec = nnfw::misc::polymorphic_downcast<exec::ExecutorBase *>( - _executor_map->at(_cond_subg_index).get()); - auto body_exec = nnfw::misc::polymorphic_downcast<exec::ExecutorBase *>( - _executor_map->at(_body_subg_index).get()); - - const auto &cond_graph = cond_exec->graph(); - const auto &cond_inputs_dyn_alloc = cond_exec->getInputsDynamicAllocInfo(); - const auto &body_graph = body_exec->graph(); - const auto &body_inputs_dyn_alloc = body_exec->getInputsDynamicAllocInfo(); - - std::vector<std::shared_ptr<backend::ITensor>> input_tensors; - std::vector<std::shared_ptr<backend::ITensor>> cond_input_tensors; - std::vector<std::shared_ptr<backend::ITensor>> body_input_tensors; - std::vector<std::shared_ptr<backend::ITensor>> body_output_tensors; - std::vector<std::shared_ptr<backend::ITensor>> output_tensors; - - // Add only used tensors in cond subgraph - assert(cond_graph.getInputs().size() == _input_tensors.size()); - assert(cond_graph.getInputs().size() == cond_exec->getInputTensors().size()); - for (uint32_t i = 0; i < cond_graph.getInputs().size(); ++i) - { - const auto &cond_input = cond_graph.operands().at(cond_graph.getInputs().at(i)); - if (cond_input.getUses().size() > 0) - { - input_tensors.emplace_back(_input_tensors.at(i)); - cond_input_tensors.emplace_back(cond_exec->getInputTensors().at(i)); - } - } - const auto permute_op_input_to_cond_input = - std::make_shared<PermuteLayer>(input_tensors, cond_input_tensors, cond_inputs_dyn_alloc); - - // Add only used tensors among outputs of while operation - assert(_output_indices.size() == _input_tensors.size()); - assert(_output_indices.size() == _output_tensors.size()); - input_tensors.clear(); - output_tensors.clear(); - for (size_t i = 0; i < _output_indices.size(); ++i) - { - const auto &output_index = _output_indices.at(i); - const auto &output = _graph.operands().at(output_index); - if (output.getUses().size() > 0 || _graph.getOutputs().contains(output_index)) - { - 
input_tensors.emplace_back(_input_tensors.at(i)); - output_tensors.emplace_back(_output_tensors.at(i)); - } - } - const auto permute_op_input_to_op_output = - std::make_shared<PermuteLayer>(input_tensors, output_tensors, _outputs_dyn_alloc_info); - - // Add all tensors with unused tensors in body subgraph because unused input tensors will be - // copied output tensors in body subgraph - assert(_input_tensors.size() == body_exec->getInputTensors().size()); - input_tensors = _input_tensors; - body_input_tensors = body_exec->getInputTensors(); - const auto permute_op_input_to_body_input = - std::make_shared<PermuteLayer>(input_tensors, body_input_tensors, body_inputs_dyn_alloc); - - // Add only used tensors in cond subgraph - assert(cond_graph.getInputs().size() == body_exec->getOutputTensors().size()); - assert(cond_graph.getInputs().size() == cond_exec->getInputTensors().size()); - body_output_tensors.clear(); - cond_input_tensors.clear(); - for (uint32_t i = 0; i < cond_graph.getInputs().size(); ++i) - { - const auto &cond_input = cond_graph.operands().at(cond_graph.getInputs().at(i)); - if (cond_input.getUses().size() > 0) - { - body_output_tensors.emplace_back(body_exec->getOutputTensors().at(i)); - cond_input_tensors.emplace_back(cond_exec->getInputTensors().at(i)); - } - } - const auto permute_body_output_to_cond_input = std::make_shared<PermuteLayer>( - body_output_tensors, cond_input_tensors, cond_inputs_dyn_alloc); - - // Add only used tensors in body subgraph - assert(body_graph.getInputs().size() == body_exec->getOutputTensors().size()); - assert(body_graph.getInputs().size() == body_exec->getInputTensors().size()); - body_output_tensors.clear(); - body_input_tensors.clear(); - for (uint32_t i = 0; i < body_graph.getInputs().size(); ++i) - { - const auto &body_input_index = body_graph.getInputs().at(i); - const auto &body_input = body_graph.operands().at(body_input_index); - if (body_input.getUses().size() > 0 && - 
!body_exec->graph().getOutputs().contains(body_input_index)) - { - body_output_tensors.emplace_back(body_exec->getOutputTensors().at(i)); - body_input_tensors.emplace_back(body_exec->getInputTensors().at(i)); - } - } - const auto permute_body_output_to_body_input = std::make_shared<PermuteLayer>( - body_output_tensors, body_input_tensors, body_inputs_dyn_alloc); - - // Add only used tensors among outputs of while operation - assert(_output_indices.size() == body_exec->getOutputTensors().size()); - assert(_output_indices.size() == _output_tensors.size()); - body_output_tensors.clear(); - output_tensors.clear(); - for (size_t i = 0; i < _output_indices.size(); ++i) - { - const auto &output_index = _output_indices.at(i); - const auto &output = _graph.operands().at(output_index); - if (output.getUses().size() > 0 || _graph.getOutputs().contains(output_index)) - { - body_output_tensors.emplace_back(body_exec->getOutputTensors().at(i)); - output_tensors.emplace_back(_output_tensors.at(i)); - } - } - const auto permute_body_output_to_op_output = - std::make_shared<PermuteLayer>(body_output_tensors, output_tensors, _outputs_dyn_alloc_info); - - // Remove copying of unused tensor - permute_op_input_to_cond_input->prepare(); - permute_op_input_to_op_output->prepare(); - permute_op_input_to_body_input->prepare(); - permute_body_output_to_cond_input->prepare(); - permute_body_output_to_body_input->prepare(); - permute_body_output_to_op_output->prepare(); - - cond_exec->execute(_input_tensors, permute_op_input_to_cond_input); - - assert(cond_exec->getOutputTensors().size() == 1); - auto &cond_output_tensor = cond_exec->getOutputTensors().at(0); - auto getResultCond = [](backend::ITensor *tensor) -> bool { - bool ret = false; - tensor->access([&](ITensor &tensor) { ret = *reinterpret_cast<bool *>(tensor.buffer()); }); - return ret; - }; - - const auto body_execute_with_op_inputs = [&]() { - body_exec->execute(_input_tensors, permute_op_input_to_body_input); - }; - - const auto 
body_execute_with_body_outputs = [&]() { - body_exec->execute(body_exec->getOutputTensors(), permute_body_output_to_body_input); - }; - - std::function<void()> body_execute = body_execute_with_op_inputs; - const auto cond_execute = [&]() { - cond_exec->execute(body_exec->getOutputTensors(), permute_body_output_to_cond_input); - }; - auto permute_to_outputs_fn = permute_op_input_to_op_output; - - // Loop while Cond subgraph's output is true - while (getResultCond(cond_output_tensor.get())) - { - body_execute(); - cond_execute(); - body_execute = body_execute_with_body_outputs; - permute_to_outputs_fn = permute_body_output_to_op_output; - } - permute_to_outputs_fn->run(); -} - -} // namespace kernel -} // namespace controlflow -} // namespace backend -} // namespace onert diff --git a/runtime/onert/core/src/backend/cpu_common/DynamicTensorManager.cc b/runtime/onert/core/src/backend/cpu_common/DynamicTensorManager.cc deleted file mode 100644 index f7ce3d011..000000000 --- a/runtime/onert/core/src/backend/cpu_common/DynamicTensorManager.cc +++ /dev/null @@ -1,137 +0,0 @@ -/* - * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "backend/cpu_common/DynamicTensorManager.h" - -#include "util/logging.h" - -namespace onert -{ -namespace backend -{ -namespace cpu_common -{ - -DynamicTensorManager::DynamicTensorManager(const std::shared_ptr<TensorRegistry> ®) - : _dynamic_mem_mgr{new DynamicMemoryManager()}, _tensors{reg} -{ - // DO NOTHING -} - -void DynamicTensorManager::applyShape(const ir::OperandIndex &ind, const ir::Shape &new_shape) -{ - VERBOSE_F() << ind << std::endl; - - auto tensor = _tensors->getNativeTensor(ind); - assert(tensor); - - bool previously_dynamic = tensor->is_dynamic(); - - auto allocTensorMem = [&](bool overwrite = false) { - auto capacity = tensor->total_size(); - auto alloc = _dynamic_mem_mgr->allocate(ind, capacity); - - if (overwrite) - tensor->overwriteBuffer(alloc); - else - tensor->setBuffer(alloc); - }; - - if (!previously_dynamic) - { - // TODO deallocate tensor->buffer() - // issue is that staticTensorManager might have allocate this memory - tensor->setShape(new_shape); - tensor->set_dynamic(); - allocTensorMem(true); - } - else if (tensor->buffer() == nullptr) - { - tensor->setShape(new_shape); - tensor->set_dynamic(); - allocTensorMem(); - } - // when buffer was already allocated and new_shape requires different size - else - { - auto previous_size = tensor->total_size(); - auto new_size = new_shape.num_elements() * sizeOfDataType(tensor->data_type()); - if (previous_size != new_size) - { - _dynamic_mem_mgr->deallocate(ind); - - tensor->setShape(new_shape); - tensor->set_dynamic(); - allocTensorMem(true); - } - else - { // when buffer with same size was already allocated, shape could differ - tensor->setShape(new_shape); - } - } -} - -void DynamicTensorManager::buildTensor(const ir::OperandIndex &ind, - const ir::OperandInfo &tensor_info, - ir::Layout backend_layout) -{ - assert(_tensors->getNativeTensor(ind) == nullptr); - auto tensor = std::make_shared<Tensor>(tensor_info, backend_layout, this); - _tensors->setNativeTensor(ind, tensor); 
-} - -void DynamicTensorManager::planDealloc(ir::OperationIndex op_ind, ir::OperandIndex operand_ind) -{ - _dealloc_tensor_map[op_ind].emplace(operand_ind); -} - -void DynamicTensorManager::deallocInput(ir::OperationIndex op_ind) -{ - auto find = _dealloc_tensor_map.find(op_ind); - if (find == _dealloc_tensor_map.end()) - return; - - auto &input_set = find->second; - for (auto input_ind : input_set) - { - auto *tensor = _tensors->getNativeTensor(input_ind).get(); - if (!tensor->is_dynamic()) - continue; - - _dynamic_mem_mgr->deallocate(input_ind); - tensor->resetBuffer(); - - VERBOSE(DynamicTensorManager) << "Deallocating #" << input_ind.value() - << " (input of op_ind: " << op_ind.value() << ")" << std::endl; - } -} - -void DynamicTensorManager::deallocSubgraphOutput(ir::OperandIndex output_ind) -{ - auto *tensor = _tensors->getNativeTensor(output_ind).get(); - if (!tensor->is_dynamic()) - return; - - _dynamic_mem_mgr->deallocate(output_ind); - tensor->resetBuffer(); - - VERBOSE(DynamicTensorManager) << "Deallocating #" << output_ind.value() - << " (output of a subgraph)" << std::endl; -} - -} // namespace cpu_common -} // namespace backend -} // namespace onert diff --git a/runtime/onert/core/src/compiler/BackendManager.cc b/runtime/onert/core/src/compiler/BackendManager.cc index db7a14a96..44442c065 100644 --- a/runtime/onert/core/src/compiler/BackendManager.cc +++ b/runtime/onert/core/src/compiler/BackendManager.cc @@ -16,22 +16,17 @@ #include "compiler/BackendManager.h" -#include <memory> -#include <dlfcn.h> +#include "../backend/builtin/Backend.h" +#include "../backend/builtin/Config.h" -#include "backend/Backend.h" -#include "backend/controlflow/Backend.h" -#include "backend/controlflow/Config.h" -#include "backend/IConfig.h" -#include "util/logging.h" -#include "util/ConfigSource.h" -#include "misc/string_helpers.h" +#include <dlfcn.h> +#include <memory> static const char *SHARED_LIB_EXT = #if defined(__APPLE__) && defined(__MACH__) - ".dylib"; + ".dylib"; 
#else - ".so"; + ".so"; #endif namespace onert @@ -45,20 +40,20 @@ BackendManager &BackendManager::get() return object; } -BackendManager::BackendManager() { loadControlflowBackend(); } +BackendManager::BackendManager() { loadBuiltinBackend(); } -void BackendManager::loadControlflowBackend() +void BackendManager::loadBuiltinBackend() { - auto backend_object = std::unique_ptr<backend::controlflow::Backend, backend_destroy_t>( - new backend::controlflow::Backend, [](backend::Backend *backend) { delete backend; }); + auto backend_object = std::unique_ptr<backend::builtin::Backend, backend_destroy_t>( + new backend::builtin::Backend, [](backend::Backend *backend) { delete backend; }); bool initialized = backend_object->config()->initialize(); // Call initialize here? if (!initialized) { - throw std::runtime_error(backend::controlflow::Config::ID + " backend initialization failed"); + throw std::runtime_error(backend::builtin::Config::ID + " backend initialization failed"); } - _controlflow = backend_object.get(); // Save the controlflow backend implementation pointer - assert(_controlflow); + _builtin = backend_object.get(); // Save the builtin backend implementation pointer + assert(_builtin); _gen_map.emplace(backend_object->config()->id(), std::move(backend_object)); } @@ -69,68 +64,67 @@ void BackendManager::loadBackend(const std::string &backend) return; } - // TODO Remove indentation - // Workaround If backend have dynamic library with "-boost" suffix naming, - // BackendManager load library with "-boost" suffix instead of library without suffix - // This feature is used for custom backend extension to support additional operations - { - const std::string backend_boost_so = "libbackend_" + backend + "-boost" + SHARED_LIB_EXT; - const std::string backend_so = "libbackend_" + backend + SHARED_LIB_EXT; + const std::string backend_so = "libbackend_" + backend + SHARED_LIB_EXT; + void *handle = dlopen(backend_so.c_str(), RTLD_LAZY | RTLD_LOCAL); - void *handle = 
dlopen(backend_boost_so.c_str(), RTLD_LAZY | RTLD_LOCAL); - if (handle == nullptr) - { - handle = dlopen(backend_so.c_str(), RTLD_LAZY | RTLD_LOCAL); + if (handle == nullptr) + { + VERBOSE(BackendManager) << "Failed to load backend '" << backend << "' - " << dlerror() << "\n"; + return; + } - if (handle == nullptr) - { - VERBOSE_F() << "Failed to load backend '" << backend << "' - " << dlerror() << std::endl; - return; - } + VERBOSE(BackendManager) << "Successfully loaded '" << backend << "'(" << backend_so << ")\n"; - VERBOSE_F() << "Successfully loaded '" << backend << "' - " << backend_so << "\n"; + { + // load object creator function + auto backend_create = (backend_create_t)dlsym(handle, "onert_backend_create"); + if (backend_create == nullptr) + { + // TODO replace `fprintf` with `VERBOSE` + fprintf(stderr, "BackendManager: unable to find function `onert_backend_create` : %s\n", + dlerror()); + dlclose(handle); + return; } - else + + // load object creator function + auto backend_destroy = (backend_destroy_t)dlsym(handle, "onert_backend_destroy"); + if (backend_destroy == nullptr) { - VERBOSE_F() << "Successfully loaded '" << backend << "' - " << backend_boost_so << "\n"; + // TODO replace `fprintf` with `VERBOSE` + fprintf(stderr, "BackendManager: unable to find `function onert_backend_destroy` : %s\n", + dlerror()); + dlclose(handle); + return; } + auto backend_object = + std::unique_ptr<backend::Backend, backend_destroy_t>(backend_create(), backend_destroy); + bool initialized = backend_object->config()->initialize(); // Call initialize here? + if (!initialized) { - // load object creator function - auto backend_create = (backend_create_t)dlsym(handle, "onert_backend_create"); - if (backend_create == nullptr) - { - fprintf(stderr, "BackendManager: unable to open function onert_backend_create : %s\n", - dlerror()); - abort(); - } + VERBOSE(BackendManager) << backend.c_str() + << " backend initialization failed. 
Don't use this backend" + << std::endl; + dlclose(handle); + return; + } + _gen_map.emplace(backend_object->config()->id(), std::move(backend_object)); + } - // load object creator function - auto backend_destroy = (backend_destroy_t)dlsym(handle, "onert_backend_destroy"); - if (backend_destroy == nullptr) + // Save backend handle (avoid warning by handle lost without dlclose()) + auto u_handle = std::unique_ptr<void, dlhandle_destroy_t>{ + handle, [id = backend, filename = backend_so](void *h) { + if (dlclose(h) == 0) { - fprintf(stderr, "BackendManager: unable to open function onert_backend_destroy : %s\n", - dlerror()); - abort(); + VERBOSE(BackendManager) << "Successfully unloaded '" << id << "'(" << filename << ")\n"; } - - auto backend_object = - std::unique_ptr<backend::Backend, backend_destroy_t>(backend_create(), backend_destroy); - bool initialized = backend_object->config()->initialize(); // Call initialize here? - if (!initialized) + else { - VERBOSE_F() << backend.c_str() << " backend initialization failed. 
Don't use this backend" - << std::endl; - dlclose(handle); - return; + VERBOSE(BackendManager) << "Failed to unload backend '" << id << "'- " << dlerror() << "\n"; } - _gen_map.emplace(backend_object->config()->id(), std::move(backend_object)); - } - - // Save backend handle (avoid warning by handle lost without dlclose()) - auto u_handle = std::unique_ptr<void, dlhandle_destroy_t>{handle, [](void *h) { dlclose(h); }}; - _handle_map.emplace(backend, std::move(u_handle)); - } + }}; + _handle_map.emplace(backend, std::move(u_handle)); } backend::Backend *BackendManager::get(const std::string &key) @@ -153,7 +147,7 @@ const backend::Backend *BackendManager::get(const std::string &key) const return nullptr; } -const backend::controlflow::Backend *BackendManager::getControlflow() const { return _controlflow; } +const backend::Backend *BackendManager::getBuiltin() const { return _builtin; } } // namespace compiler } // namespace onert diff --git a/runtime/onert/core/src/compiler/Compiler.cc b/runtime/onert/core/src/compiler/Compiler.cc index 93dbbc3b5..63667a063 100644 --- a/runtime/onert/core/src/compiler/Compiler.cc +++ b/runtime/onert/core/src/compiler/Compiler.cc @@ -16,284 +16,177 @@ #include "compiler/Compiler.h" -#include "ParamChecker.h" +#include "CompilerHelpers.h" #include "ExecutorFactory.h" -#include "OperationValidator.h" -#include "Fp32ToFp16Converter.h" +#include "ShapeValidator.h" +#include "pass/ConstantOutputPass.h" +#include "pass/OddOutputPass.h" +#include "pass/PassRunner.h" +#include "pass/UnusedOperandEliminationPass.h" +#include "../dumper/dot/DotDumper.h" +#include "../exec/SingleModelExecutors.h" +#include "../ir/OperationDumper.h" +#include "../ir/verifier/Verifier.h" -#include <backend/controlflow/Config.h> -#include "compiler/BackendManager.h" -#include "compiler/IScheduler.h" -#include "compiler/ManualScheduler.h" -#include "compiler/HEScheduler.h" -#include "compiler/StaticShapeInference.h" -#include "exec/ExecTime.h" -#include 
"ir/operation/LowerInfo.h" -#include "dumper/dot/DotDumper.h" -#include "compiler/Linear.h" -#include "interp/InterpExecutor.h" -#include "util/ConfigSource.h" -#include "util/logging.h" -#include "ir/OperationDumper.h" -#include "misc/string_helpers.h" +#include "compiler/StaticShapeInferer.h" + +#include <misc/string_helpers.h> +#include <misc/polymorphic_downcast.h> namespace onert { - namespace compiler { -CompilerOptions fetchCompilerOptionsFromGlobalConfig(const ir::Subgraphs &subgs) +Compiler::Compiler(const std::shared_ptr<ir::Model> &model, CompilerOptions *copts) + : _model{model}, _options{copts} { - CompilerOptions options; - options.backend_list = nnfw::misc::split(util::getConfigString(util::config::BACKENDS), ';'); - options.is_primary_subgraph = false; - options.trace_filepath = util::getConfigString(util::config::TRACE_FILEPATH); - options.graph_dump_level = util::getConfigInt(util::config::GRAPH_DOT_DUMP); - options.op_seq_max_node = util::getConfigInt(util::config::OP_SEQ_MAX_NODE); - options.executor = util::getConfigString(util::config::EXECUTOR); - options.he_scheduler = util::getConfigBool(util::config::USE_SCHEDULER); - options.he_profiling_mode = util::getConfigBool(util::config::PROFILING_MODE); - options.disable_compile = util::getConfigBool(util::config::DISABLE_COMPILE); - options.fp16_enable = util::getConfigBool(util::config::FP16_ENABLE); -#ifdef RUY_PROFILER - options.op_seq_max_node = 1; -#endif - - { - // Backend for all - auto &ms_options = options.manual_scheduler_options; - - // Default value for op_backend_all is first element in the backend list - ms_options.backend_for_all = util::getConfigString(util::config::OP_BACKEND_ALLOPS); - -// Opcode to Backend -#define OP(OpName) \ - { \ - const auto &backend_str = util::getConfigString(util::config::OP_BACKEND_##OpName); \ - if (!backend_str.empty()) \ - { \ - ms_options.opcode_to_backend[ir::OpCode::OpName] = backend_str; \ - } \ - } -#include "ir/Operations.lst" -#undef OP - - 
// Index to Backend - // TODO Support multiple subgraphs for manual scheduling - auto map_str = util::getConfigString(util::config::OP_BACKEND_MAP); - auto key_val_list = nnfw::misc::split(map_str, ';'); - for (const auto &key_val_str : key_val_list) - { - if (key_val_str.empty()) - { - continue; - } - - auto key_val = nnfw::misc::split(key_val_str, '='); - const auto &key_str = key_val.at(0); - const auto &val = key_val.at(1); - auto key = static_cast<uint32_t>(std::stoi(key_str)); - - subgs.at(ir::SubgraphIndex{0}) - ->operations() - .at(ir::OperationIndex{key}); // Check if exist, or this wil throw - ms_options.index_to_backend.emplace(ir::OperationIndex{key}, val); - } - } - return options; + // DO NOTHING } -Compiler::Compiler(const std::shared_ptr<ir::Subgraphs> &subgs) - : _subgraphs{subgs}, _state{State::CREATED} +Compiler::Compiler(const std::shared_ptr<ir::NNPkg> &nnpkg, CompilerOptions *copts) + : _model{nnpkg->primary_model()}, _options{copts} { - // Set default values for CompilerOptions - // All these default values should not be fetched from Env, when we stop supporting Android NN - // API. 
- _options = fetchCompilerOptionsFromGlobalConfig(*subgs); + // Use for single model only + assert(nnpkg->model_count() == 1); } -void Compiler::enableToFp16() { _options.fp16_enable = true; } - -void Compiler::checkProfilerConditions() +std::shared_ptr<CompilerArtifact> Compiler::compile(void) { - if (!_options.he_scheduler) - throw std::runtime_error("Heterogeneous scheduler must be enabled during profiling."); - - if (_options.executor != "Dataflow") - throw std::runtime_error("Profiling mode works only with 'Dataflow' executor"); -} + /*************************************************** + * Prepare compilation phase + ***************************************************/ + if (!_options) + throw std::runtime_error{"Empty compile option"}; -std::shared_ptr<exec::ExecutorMap> Compiler::compile(void) -{ - // Set control flow backend for control flow operators + // Mode check + // TODO handle option for each model + if (_options->he_profiling_mode) { - _options.manual_scheduler_options.opcode_to_backend[ir::OpCode::If] = - backend::controlflow::Config::ID; - _options.manual_scheduler_options.opcode_to_backend[ir::OpCode::While] = - backend::controlflow::Config::ID; - } + if (!_options->he_scheduler) + throw std::runtime_error("Heterogeneous scheduler must be enabled during profiling."); - // FIXME This is a workaround for bcq operations, should remove it - { - _options.manual_scheduler_options.opcode_to_backend[ir::OpCode::BCQFullyConnected] = "bcq"; - _options.manual_scheduler_options.opcode_to_backend[ir::OpCode::BCQGather] = "bcq"; + if (_options->executor != "Dataflow") + throw std::runtime_error("Profiling mode works only with 'Dataflow' executor"); } + if (!_model->hasOnly<ir::Graph>()) { - VERBOSE(Compiler) << std::boolalpha; - VERBOSE(Compiler) << "==== Compiler Options ====" << std::endl; - VERBOSE(Compiler) << "backend_list : " - << nnfw::misc::join(_options.backend_list.begin(), - _options.backend_list.end(), "/") - << std::endl; - VERBOSE(Compiler) << 
"trace_filepath : " << _options.trace_filepath << std::endl; - VERBOSE(Compiler) << "graph_dump_level : " << _options.graph_dump_level << std::endl; - VERBOSE(Compiler) << "op_seq_max_node : " << _options.op_seq_max_node << std::endl; - VERBOSE(Compiler) << "executor : " << _options.executor << std::endl; - VERBOSE(Compiler) << "manual_scheduler_options : (Too many things to print)" << std::endl; - VERBOSE(Compiler) << "he_scheduler : " << _options.he_scheduler << std::endl; - VERBOSE(Compiler) << "he_profiling_mode : " << _options.he_profiling_mode << std::endl; - VERBOSE(Compiler) << "disable_compile : " << _options.disable_compile << std::endl; - VERBOSE(Compiler) << "fp16_enable : " << _options.fp16_enable << std::endl; - VERBOSE(Compiler) << std::noboolalpha; + throw std::runtime_error("Compiler can only compile models for inference."); } - /*************************************************** - * Prepare compilation phase - ***************************************************/ + _options->forceInternalOptions(); + _options->verboseOptions(); - auto executors = std::make_shared<exec::ExecutorMap>(); + auto custom_kernel_builder = _model->getKernelBuilder(); - // Compilable check - // TODO: Support hybrid execution - - // execution between interpreter and compiled executor (including control flow) - if (!checkCompilable()) - { - _subgraphs->iterate([&](const ir::SubgraphIndex &index, ir::Graph &subg) { - executors->emplace(index, std::make_unique<interp::InterpExecutor>(subg)); - }); - _state = State::COMPILED; - return executors; - } + _model->iterate([&](const ir::SubgraphIndex &, ir::IGraph &graph) { + auto &subg = nnfw::misc::polymorphic_downcast<ir::Graph &>(graph); - // Mode check - if (_options.he_profiling_mode) - checkProfilerConditions(); + // Mandatory passes + pass::PassRunner{} + .append(std::make_unique<pass::ConstantOutputPass>(subg)) + .append(std::make_unique<pass::OddOutputPass>(subg)) + .run(); + + // Optimizations + 
pass::PassRunner{}.append(std::make_unique<pass::UnusedOperandEliminationPass>(subg)).run(); + }); /*************************************************** * Backend independent analysis & optimization phase ***************************************************/ - auto dump_level = static_cast<dumper::dot::DotDumper::Level>(_options.graph_dump_level); + // TODO Handle dump level for each model + auto dump_level = static_cast<dumper::dot::DotDumper::Level>(_options->graph_dump_level); + onert::dumper::dot::DotDumper dot_dumper(dump_level); + + // Tracing context + auto tracing_ctx = std::make_unique<util::TracingCtx>(); // Lower: Assign backend std::unordered_map<ir::SubgraphIndex, std::unique_ptr<compiler::LoweredGraph>> lowered_subgs; - _subgraphs->iterate([&](const ir::SubgraphIndex &index, ir::Graph &subg) { - _options.is_primary_subgraph = (index == ir::SubgraphIndex{0}); - onert::dumper::dot::DotDumper dot_dumper(subg, dump_level); - dot_dumper.dump(nnfw::misc::str("before_lower_subg-", index.value())); - - // Lower: Assign backend - lowered_subgs[index] = std::make_unique<compiler::LoweredGraph>(subg, _options); - - // Check backend(s) for subgraph support FP16 - bool backends_support_fp16 = true; - auto &contexts = (*lowered_subgs[index]).backend_contexts(); - for (auto it = contexts.begin(); it != contexts.end(); it++) - { - // Controlflow backend is not for actual computaion of operations so it is an exception - if (it->first->config()->id() != backend::controlflow::Config::ID) - backends_support_fp16 &= it->first->config()->supportFP16(); - } + { + _model->iterate([&](const ir::SubgraphIndex &subg_index, ir::IGraph &graph) { + auto &subg = nnfw::misc::polymorphic_downcast<ir::Graph &>(graph); - if (_options.fp16_enable && backends_support_fp16) - { - // NOTE: the only acl_cl backend enables fp16 mode - Fp32ToFp16Converter(*lowered_subgs[index]).run(); - } + // Lower: Assign backend + lowered_subgs[subg_index] = std::make_unique<compiler::LoweredGraph>(subg, 
*_options); + // Set tracing_ctx for copied graph + tracing_ctx->setSubgraphIndex(&(lowered_subgs[subg_index]->graph()), subg_index.value()); + }); + } - subg.setSubgraphs(nullptr); - }); + _model.reset(); - _subgraphs.reset(); + for (const auto &pair : lowered_subgs) + { + const auto &subg_index = pair.first; + const auto &lowered_subg = pair.second; + dot_dumper.dump(*lowered_subg, nnfw::misc::str("after_lower_subg-", subg_index.value())); + } // Shape inference. { + // Run the StaticShapeInfer of primary subg. All child StaticShapeInferers are called + // recursively + std::unordered_map<ir::SubgraphIndex, std::unique_ptr<StaticShapeInferer>> inferers = + createStaticShapeInferers(lowered_subgs); + const auto primary_subg_idx = ir::SubgraphIndex{0}; - StaticShapeInferer inferer(primary_subg_idx, lowered_subgs); - lowered_subgs.at(primary_subg_idx) - ->iterateTopolOpSeqs([&](const ir::OpSequenceIndex &, ir::OpSequence &op_seq) { - auto has_dynamic_tensor = inferer.infer(op_seq); - op_seq.has_dynamic_tensor(has_dynamic_tensor); - }); - inferer.dump(); - } + inferers.at(primary_subg_idx)->infer(); - /************************************************************* - * Backend independent analysis & optimization phase finished - *************************************************************/ + for (const auto &pair_inferer : inferers) + { + const auto inferer = pair_inferer.second.get(); + inferer->dump(); + } + } - // operation validation - for (auto &pair : lowered_subgs) + // Shape validation + // TODO Move shape independent feature check from ShapeValidator to OperationValidator + // TODO Move ShapeValidator into shape inference + // - Check input tensor shape validation + // - Check parameter value validation which valid value is depend on input tensor shape + // - Output tensor shape validation check is needless because + // static/dynamic shape inferer will make valid output shape + for (const auto &pair : lowered_subgs) { auto &lowered_subg = pair.second; - 
compiler::OperationValidator{lowered_subg->graph()}(); + compiler::ShapeValidator{lowered_subg->graph()}(); } - executors = std::make_shared<exec::ExecutorMap>(); - for (auto &pair : lowered_subgs) + /************************************************************* + * Backend independent analysis & optimization phase finished + *************************************************************/ + auto executors = std::make_shared<exec::SingleModelExecutors>(); + for (auto &&pair : lowered_subgs) { - const auto &subg_index = pair.first; + auto const model_index = ir::ModelIndex{0}; + auto const subg_index = pair.first; auto &lowered_subg = pair.second; - auto indexed_ranks = lowered_subg->indexed_ranks(); - - _options.is_primary_subgraph = (subg_index == ir::SubgraphIndex{0}); + auto const indexed_ranks = lowered_subg->indexed_ranks(); - onert::dumper::dot::DotDumper dot_dumper_lowered(lowered_subg.get(), dump_level); - dot_dumper_lowered.dump("after_lower_subg-" + std::to_string(subg_index.value())); - - ir::OperationDumper dumper("START SUBGRAPH " + std::to_string(subg_index.value())); + ir::OperationDumper dumper("Executor generation of Subgraph " + + std::to_string(subg_index.value())); lowered_subg->graph().operations().iterate( - [&](const ir::OperationIndex &, const ir::Operation &op) { op.accept(dumper); }); + [&](const ir::OperationIndex &, const ir::IOperation &op) { op.accept(dumper); }); + + ExecutorFactoryArgs args; + args.tracing_ctx = tracing_ctx.get(); + args.options = _options; + args.model_index = model_index; + args.custom_kernel_builder = custom_kernel_builder; auto executor = std::unique_ptr<exec::IExecutor>{ - ExecutorFactory::get().create(std::move(lowered_subg), _options, executors)}; + ExecutorFactory::get().create(std::move(lowered_subg), executors, args)}; executor->setIndexedRanks(indexed_ranks); - executors->insert(std::make_pair(subg_index, std::move(executor))); + executors->emplace(model_index, subg_index, std::move(executor)); } 
/******************************** * Code generation phase finished ********************************/ - _state = State::COMPILED; - return executors; -} - -bool Compiler::checkCompilable() -{ - // Disable compile phase - // When ready to use interpreter backend, remove this config and use backend setting - if (_options.disable_compile) - { - return false; - } - - // TODO check unspecified operand shape - - // Check compilable parameter - for (uint32_t i = 0; i < _subgraphs->count(); ++i) - { - auto graph = _subgraphs->at(ir::SubgraphIndex{i}); - ParamChecker paramChecker{graph}; - paramChecker(); - if (paramChecker.haveNoneConstParam()) - { - return false; - } - } - - return true; + return std::make_shared<CompilerArtifact>(executors, std::move(tracing_ctx)); } } // namespace compiler - } // namespace onert diff --git a/runtime/onert/core/src/compiler/CompilerFactory.cc b/runtime/onert/core/src/compiler/CompilerFactory.cc new file mode 100644 index 000000000..3e1209a52 --- /dev/null +++ b/runtime/onert/core/src/compiler/CompilerFactory.cc @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "compiler/CompilerFactory.h" + +#include "MultiModelCompiler.h" +#include "train/TrainingCompiler.h" +#include "compiler/Compiler.h" + +namespace onert +{ +namespace compiler +{ + +CompilerFactory &CompilerFactory::get() +{ + static CompilerFactory singleton; + return singleton; +} + +std::unique_ptr<ICompiler> CompilerFactory::create(const std::shared_ptr<ir::NNPkg> &nnpkg, + CompilerOptions *copts, + const ir::train::TrainingInfo *training_info) +{ + // Returing compiler for training + if (training_info) + return std::make_unique<train::TrainingCompiler>(nnpkg, copts, *training_info); + + // Returing compiler for inference + if (nnpkg->model_count() == 1) + return std::make_unique<Compiler>(nnpkg, copts); + + return std::make_unique<MultiModelCompiler>(nnpkg, copts); +} + +} // namespace compiler +} // namespace onert diff --git a/runtime/onert/core/src/compiler/CompilerHelpers.h b/runtime/onert/core/src/compiler/CompilerHelpers.h new file mode 100644 index 000000000..798334b3b --- /dev/null +++ b/runtime/onert/core/src/compiler/CompilerHelpers.h @@ -0,0 +1,52 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef __ONERT_COMPILER_COMPILER_HELPERS_H__ +#define __ONERT_COMPILER_COMPILER_HELPERS_H__ + +#include <compiler/ILoweredGraph.h> +#include <compiler/StaticShapeInferer.h> +#include <ir/Index.h> + +#include <memory> +#include <unordered_map> + +namespace onert +{ +namespace compiler +{ + +/** + * @brief Create a shape inferer map for a lowered model + * @param[in] lowered_subgs lowered model map + * @return Shape inferer map + */ +template <typename LoweredGraphType, + typename = std::enable_if_t<std::is_base_of<ILoweredGraph, LoweredGraphType>::value>> +static std::unordered_map<ir::SubgraphIndex, std::unique_ptr<StaticShapeInferer>> +createStaticShapeInferers( + const std::unordered_map<ir::SubgraphIndex, std::unique_ptr<LoweredGraphType>> &lowered_subgs) +{ + std::unordered_map<ir::SubgraphIndex, ILoweredGraph *> lsubgs; + for (auto &&e : lowered_subgs) + lsubgs[e.first] = e.second.get(); + return StaticShapeInferer::createStaticShapeInferers(lsubgs); +} + +} // namespace compiler +} // namespace onert + +#endif // __ONERT_COMPILER_COMPILER_HELPERS_H__ diff --git a/runtime/onert/core/src/compiler/CompilerOptions.cc b/runtime/onert/core/src/compiler/CompilerOptions.cc new file mode 100644 index 000000000..c5aee1956 --- /dev/null +++ b/runtime/onert/core/src/compiler/CompilerOptions.cc @@ -0,0 +1,147 @@ +/* + * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "compiler/CompilerOptions.h" + +#include "../backend/builtin/Backend.h" + +#include "util/ConfigSource.h" +#include "util/logging.h" + +#include <misc/string_helpers.h> + +namespace +{ + +using namespace onert; + +std::string getOpBackends(std::unordered_map<ir::OpCode, std::string> &opcode_to_backend) +{ + std::unordered_map<ir::OpCode, std::string>::iterator it; + std::string opbackends; + + for (it = opcode_to_backend.begin(); it != opcode_to_backend.end(); ++it) + { + if (!opbackends.empty()) + opbackends = opbackends + ", "; + + auto opcode = it->first; + const std::string opname = ir::toString(opcode); + opbackends += opname + "=" + it->second; + } + return opbackends; +} + +} // namespace + +namespace onert +{ +namespace compiler +{ + +void ManualSchedulerOptions::setBackendMap(const std::string &str) +{ + // TODO Support multiple subgraphs for manual scheduling + auto key_val_list = nnfw::misc::split(str, ';'); + for (const auto &key_val_str : key_val_list) + { + if (key_val_str.empty()) + { + continue; + } + + auto key_val = nnfw::misc::split(key_val_str, '='); + if (key_val.size() != 2) + throw std::runtime_error{"Invalid key-value pair"}; + + const auto &key_str = key_val.at(0); + const auto &val = key_val.at(1); + auto key = static_cast<uint32_t>(std::stoi(key_str)); + this->index_to_backend.emplace(ir::OperationIndex{key}, val); + } +} + +std::unique_ptr<CompilerOptions> CompilerOptions::fromGlobalConfig() +{ + auto o = std::make_unique<CompilerOptions>(); + o->backend_list = nnfw::misc::split(util::getConfigString(util::config::BACKENDS), ';'); + o->graph_dump_level = util::getConfigInt(util::config::GRAPH_DOT_DUMP); + o->executor = util::getConfigString(util::config::EXECUTOR); + o->he_scheduler = util::getConfigBool(util::config::USE_SCHEDULER); + o->he_profiling_mode = util::getConfigBool(util::config::PROFILING_MODE); + o->fp16_enable = util::getConfigBool(util::config::FP16_ENABLE); + o->workspace_dir = 
util::getConfigString(util::config::WORKSPACE_DIR); + { + // Backend for all + auto &ms_options = o->manual_scheduler_options; + + // Default value for op_backend_all is first element in the backend list + ms_options.backend_for_all = util::getConfigString(util::config::OP_BACKEND_ALLOPS); + +// Opcode to Backend +#define OP(OpName) \ + { \ + const auto &backend_str = util::getConfigString(util::config::OP_BACKEND_##OpName); \ + if (!backend_str.empty()) \ + { \ + ms_options.opcode_to_backend[ir::OpCode::OpName] = backend_str; \ + } \ + } +#include "ir/Operations.lst" +#undef OP + + // Index to Backend + auto map_str = util::getConfigString(util::config::OP_BACKEND_MAP); + ms_options.setBackendMap(map_str); + } + return o; +} + +void CompilerOptions::forceInternalOptions() +{ + // Set control flow backend for control flow operators + auto &builtin_id = backend::builtin::Config::ID; + manual_scheduler_options.opcode_to_backend[ir::OpCode::If] = builtin_id; + manual_scheduler_options.opcode_to_backend[ir::OpCode::While] = builtin_id; + manual_scheduler_options.opcode_to_backend[ir::OpCode::Permute] = builtin_id; + + // FIXME This is a workaround for bcq operations, should remove it + manual_scheduler_options.opcode_to_backend[ir::OpCode::BCQFullyConnected] = "bcq"; + manual_scheduler_options.opcode_to_backend[ir::OpCode::BCQGather] = "bcq"; + + // FIXME This is a workaround for bulk operations, should remove it + manual_scheduler_options.opcode_to_backend[ir::OpCode::Bulk] = "trix"; +} + +void CompilerOptions::verboseOptions() +{ + VERBOSE(Compiler) << std::boolalpha << "==== Compiler Options ====" << std::endl; + VERBOSE(Compiler) << "backend_list : " + << nnfw::misc::join(backend_list.begin(), backend_list.end(), "/") << std::endl; + VERBOSE(Compiler) << "graph_dump_level : " << graph_dump_level << std::endl; + VERBOSE(Compiler) << "executor : " << executor << std::endl; + VERBOSE(Compiler) << "manual backend_for_all : " << manual_scheduler_options.backend_for_all 
+ << std::endl; + VERBOSE(Compiler) << "manual_scheduler_options : " + << getOpBackends(manual_scheduler_options.opcode_to_backend) << std::endl; + VERBOSE(Compiler) << "he_scheduler : " << he_scheduler << std::endl; + VERBOSE(Compiler) << "he_profiling_mode : " << he_profiling_mode << std::endl; + VERBOSE(Compiler) << "fp16_enable : " << fp16_enable << std::endl + << std::noboolalpha; +} + +} // namespace compiler +} // namespace onert diff --git a/runtime/onert/core/src/compiler/ExecutorFactory.cc b/runtime/onert/core/src/compiler/ExecutorFactory.cc index 062c6c9c3..eff3f5abe 100644 --- a/runtime/onert/core/src/compiler/ExecutorFactory.cc +++ b/runtime/onert/core/src/compiler/ExecutorFactory.cc @@ -16,24 +16,29 @@ #include "ExecutorFactory.h" +#include "Linear.h" +#include "../backend/builtin/BackendContext.h" +#include "../backend/builtin/Config.h" +#include "../backend/builtin/UserTensor.h" +#include "../backend/builtin/train/BackendContext.h" +#include "../dumper/text/GraphDumper.h" +#include "../exec/DataflowExecutor.h" +#include "../exec/ExecTime.h" +#include "../exec/ExecutionObservers.h" +#include "../exec/LinearExecutor.h" +#include "../exec/MinMaxRecorder.h" +#include "../exec/ParallelExecutor.h" +#include "../exec/train/TrainableExecutor.h" +#include "../ir/OperationCloner.h" + +#include <backend/IPortableTensor.h> +#include <backend/train/TrainableBackendContext.h> +#include <backend/train/ITrainableBackend.h> +#include <compiler/BackendManager.h> +#include <compiler/ExecutionBuilder.h> +#include <util/TracingCtx.h> + #include <functional> -#include "exec/ExecutionObservers.h" -#include "exec/LinearExecutor.h" -#include "exec/DataflowExecutor.h" -#include "exec/ParallelExecutor.h" -#include "compiler/BackendManager.h" -#include "compiler/ExecutionBuilder.h" -#include "exec/ExecTime.h" -#include "compiler/Linear.h" -#include "compiler/TensorBuilders.h" -#include "backend/IConstantInitializer.h" -#include "backend/IKernelGenerator.h" -#include 
"backend/IOptimizer.h" -#include "backend/ITensorRegister.h" -#include "backend/controlflow/Config.h" -#include "backend/controlflow/KernelGenerator.h" -#include "backend/controlflow/UserTensor.h" -#include "backend/controlflow/TensorBuilder.h" #include <memory> namespace onert @@ -46,7 +51,7 @@ class SyncFunction final : public exec::IFunction public: virtual ~SyncFunction() = default; SyncFunction(std::unique_ptr<exec::IFunction> fn, const std::shared_ptr<backend::IConfig> config) - : _fn{std::move(fn)}, _config{config} + : _fn{std::move(fn)}, _config{config} { assert(_fn); assert(_config); @@ -65,21 +70,221 @@ private: std::shared_ptr<backend::IConfig> _config; }; -// TODO Think of a better way to manage TensorManagers -backend::TensorManagerSet createTensorManagerSet(const compiler::TensorBuilders &tensor_builders) +using DeallocList = std::vector<backend::ITensor *>; +// Deallocation after execution of an operation used by Linear Executor +class DeallocFunction final : public exec::IFunction +{ +public: + DeallocFunction(const DeallocList &tensors) : _dealloc_list{tensors} {} + + void run() override + { + for (auto &&tensor : _dealloc_list) + { + if (!tensor->is_dynamic()) + continue; + tensor->deallocBuffer(); + } + } + +private: + DeallocList _dealloc_list; +}; + +// TODO Unify initializeSubgraphIOTensors +void initializeSubgraphIOTensors(compiler::ILoweredGraph &lowered_graph, + const backend::BackendContexts &backend_contexts, + const ir::OperandIndexSequence &indices) +{ + // TODO Store builtin backend in BackendContext + std::shared_ptr<backend::builtin::TensorRegistry> builtin_tensor_reg; + for (const auto &e : backend_contexts) + { + auto backend = e.first; + auto &context = e.second; + if (backend->config()->id() == backend::builtin::Config::ID) + { + builtin_tensor_reg = + std::dynamic_pointer_cast<backend::builtin::TensorRegistry>(context->tensor_registry); + } + } + assert(builtin_tensor_reg); + + for (auto &&ind : indices) + { + const auto 
&operand = lowered_graph.graph().operands().at(ind); + auto tensor = std::make_unique<backend::builtin::IOTensor>( + operand.info(), + ir::Layout::NHWC /* FIXME find operation for this operand and use frontend_layout */ + ); + + // Add tensor to builtin TensorRegistry. + builtin_tensor_reg->setNativeIOTensor(ind, std::move(tensor)); + } +} + +void initializeSubgraphIOTensors(compiler::ILoweredGraph &lowered_graph, + const backend::train::TrainableBackendContexts &backend_contexts, + const ir::OperandIndexSequence &indices) +{ + std::shared_ptr<backend::builtin::train::TensorRegistry> builtin_tensor_reg; + for (const auto &e : backend_contexts) + { + auto backend = e.first; + auto &context = e.second; + if (backend->config()->id() == backend::builtin::Config::ID) + { + builtin_tensor_reg = std::dynamic_pointer_cast<backend::builtin::train::TensorRegistry>( + context->tensor_registry()); + } + } + assert(builtin_tensor_reg); + + for (auto &&ind : indices) + { + const auto &operand = lowered_graph.graph().operands().at(ind); + auto tensor = std::make_unique<backend::builtin::IOTensor>( + operand.info(), + ir::Layout::NHWC /* FIXME find operation for this operand and use frontend_layout */ + ); + + // Add tensor to builtin TensorRegistry. 
+ builtin_tensor_reg->setNativeIOTensor(ind, std::move(tensor)); + } +} + +backend::BackendContexts +createBackendContexts(compiler::ILoweredGraph &lgraph, bool linear_executor, + std::shared_ptr<backend::custom::IKernelBuilder> custom_kernel_builder) { - backend::TensorManagerSet tensor_mgrs; - for (auto &tensor_builder : tensor_builders) + backend::BackendContexts contexts; + std::unordered_map<const backend::Backend *, backend::ContextData> context_data_map; + + // Generate partial graphs for each backend + auto init_context_data = [&](const backend::Backend *backend) { + auto &data = context_data_map[backend]; + auto graph = std::make_unique<ir::Graph>(); + graph->setLayout(lgraph.graph().layout()); + data.graph = std::move(graph); + }; + + auto &whole_graph = lgraph.graph(); + // Separate operands into partial graphs + whole_graph.operands().iterate([&](const ir::OperandIndex &operand_ind, ir::Operand &operand) { + auto &operand_li = lgraph.lower_info().operand; + const auto &def_factors = operand_li.at(operand_ind).def_factors(); + if (def_factors.size() == 0) // Ignore unused tensor + return; + const auto &def_factor = def_factors.getOnlyElement(); + const auto backend = def_factor.backend(); + if (context_data_map.find(backend) == context_data_map.end()) + init_context_data(backend); + + auto &partial_graph = *context_data_map[backend].graph; + auto &operand_layouts = context_data_map[backend].operand_layouts; + assert(operand_layouts.find(operand_ind) == operand_layouts.end()); + operand_layouts[operand_ind] = def_factor.layout(); + + // Copy the operand and insert it to the partial graph + auto new_operand = std::make_unique<ir::Operand>(operand); + new_operand->clearDefUse(); + operand.releaseData(); // Deref data of LoweredGraph + auto new_operand_ind = partial_graph.addOperand(operand_ind, std::move(new_operand)); + UNUSED_RELEASE(new_operand_ind); + assert(new_operand_ind == operand_ind); + }); + // Separate operations into partial graphs + 
whole_graph.operations().iterate( + [&](const ir::OperationIndex &op_ind, const ir::IOperation &operation) { + auto &op_li = lgraph.lower_info().operation; + auto backend = op_li.at(op_ind).backend(); + if (context_data_map.find(backend) == context_data_map.end()) + init_context_data(backend); + + auto &partial_graph = *context_data_map[backend].graph; + auto &external_operands = context_data_map[backend].external_operands; + auto &operand_layouts = context_data_map[backend].operand_layouts; + + { + // Add missing operands (externals) + auto io_list = (operation.getInputs() + operation.getOutputs()) | ir::Remove::DUPLICATED | + ir::Remove::UNDEFINED; + for (auto &&operand_ind : io_list) + { + if (partial_graph.operands().exist(operand_ind)) + continue; + + // Copy the operand and insert it to the partial graph + const auto &operand = whole_graph.operands().at(operand_ind); + auto new_operand = std::make_unique<ir::Operand>(operand); + new_operand->clearDefUse(); + auto new_operand_ind = partial_graph.addOperand(operand_ind, std::move(new_operand)); + UNUSED_RELEASE(new_operand_ind); + assert(new_operand_ind == operand_ind); + + auto layout = + lgraph.lower_info().operand.at(operand_ind).def_factors().getOnlyElement().layout(); + assert(operand_layouts.find(operand_ind) == operand_layouts.end()); + operand_layouts[operand_ind] = layout; + external_operands.add(operand_ind); + } + + auto new_op_ind = partial_graph.addOperation(op_ind, clone(operation)); + UNUSED_RELEASE(new_op_ind); + assert(new_op_ind == op_ind); + } + }); + + // Create contexts + auto whole_op_order = lgraph.graph().topolSortOperations(); + for (auto &&pair : context_data_map) { - auto s_tensor_manager = tensor_builder->releaseStaticTensorManager(); - if (s_tensor_manager != nullptr) - tensor_mgrs.insert(std::move(s_tensor_manager)); + auto backend = pair.first; + auto &data = pair.second; + // Handle graph input/outputs or external tensors + data.graph->operands().iterate([&](const 
ir::OperandIndex &ind, const ir::Operand &operand) { + if (whole_graph.getInputs().contains(ind) || whole_graph.getOutputs().contains(ind)) + data.external_operands.add(ind); + // Inputs are either "graph input" or "no def op and non-constant" + if (whole_graph.getInputs().contains(ind) || + (!operand.getDef().valid() && !operand.isConstant())) + // Outputs are either "graph output" or "no uses" + data.graph->addInput(ind); + if (whole_graph.getOutputs().contains(ind) || operand.getUses().size() == 0) + data.graph->addOutput(ind); + }); + VERBOSE(ExecutorFactory) << "createBackendContexts: partial graph for backend=" + << backend->config()->id() << std::endl; + dumper::text::dumpGraph(*data.graph); + + std::copy_if(whole_op_order.begin(), whole_op_order.end(), std::back_inserter(data.op_order), + [&](const auto &ind) { return data.graph->operations().exist(ind); }); + data.is_linear_executor = linear_executor; + data.custom_kernel_builder = custom_kernel_builder; + contexts.emplace(backend, backend->newContext(std::move(data))); + } + return contexts; +} - auto d_tensor_manager = tensor_builder->releaseDynamicTensorManager(); - if (d_tensor_manager != nullptr) - tensor_mgrs.insert(std::move(d_tensor_manager)); +template <typename Context> +std::deque<std::pair<const backend::Backend *, Context *>> orderBackendContext( + const std::unordered_map<const backend::Backend *, std::unique_ptr<Context>> &tbackend_contexts) +{ + std::deque<std::pair<const backend::Backend *, Context *>> ordered_contexts; + + for (auto &&pair : tbackend_contexts) + { + // NOTE builtin backend must be processed lastly. + // This is because of Permute layer's specialty which is the only operation that could have + // different ITensor objects for the input and the output. And it requires all other backends' + // tensors are ready to use. 
+ if (pair.first->config()->id() == "builtin") + ordered_contexts.emplace_back(pair.first, pair.second.get()); + else + ordered_contexts.emplace_front(pair.first, pair.second.get()); } - return tensor_mgrs; + + return ordered_contexts; } } // namespace @@ -106,412 +311,582 @@ ExecutorFactory::ExecutorFactory() } exec::IExecutor *ExecutorFactory::create(std::unique_ptr<compiler::LoweredGraph> lowered_graph, - const compiler::CompilerOptions &options, - const std::shared_ptr<exec::ExecutorMap> &executor_map) + const std::shared_ptr<exec::IExecutors> &executors, + const ExecutorFactoryArgs &args) { - return _map.at(options.executor)(std::move(lowered_graph), options, executor_map); + assert(args.options != nullptr); + return _map.at(args.options->executor)(std::move(lowered_graph), executors, args); } -void ExecutorFactory::initializeBackendContext(compiler::LoweredGraph *lowered_graph) +void ExecutorFactory::prepareMigrantTensors(compiler::ILoweredGraph &lowered_graph, + const backend::BackendContexts &backend_contexts) { - struct Entry - { - std::vector<backend::BackendContext::OperationInfo> operation_list; - std::vector<ir::OperandIndex> operand_list; - }; - std::unordered_map<const backend::Backend *, Entry> backend_assets; - - // Build lists for operations - lowered_graph->op_seqs().iterate( - [&](const ir::OpSequenceIndex &op_seq_index, const ir::OpSequence &op_seq) { - auto &op_seq_li = lowered_graph->getLowerInfo()->op_seq; - auto backend = op_seq_li.at(op_seq_index)->backend(); - for (auto &operation_idx : op_seq.operations()) + TensorRegistries tensor_regs{backend_contexts, true}; + + lowered_graph.graph().operations().iterate( + [&](const ir::OperationIndex &op_ind, const ir::IOperation &op) { + auto lower_info = lowered_graph.lower_info().operation.getRawPtr(op_ind); + auto &backend_ctx = backend_contexts.at(lower_info->backend()); + for (auto &&ind : + (op.getInputs() + op.getOutputs()) | ir::Remove::DUPLICATED | ir::Remove::UNDEFINED) + { + // If an 
Operation's input/output tensor does not have an own tensor object, + // it must be using migrant tensors, so find the tensor from other tensor registries and + // register it to the current tensor registry if it is portable + if (!backend_ctx->tensor_registry->getITensor(ind)) { - backend_assets[backend].operation_list.emplace_back(operation_idx, op_seq.getLayout()); + auto tensor = tensor_regs.getITensor(ind); + assert(tensor); // The tensor must have been registered + auto ptensor = dynamic_cast<backend::IPortableTensor *>(tensor); + if (ptensor) + backend_ctx->tensor_registry->setMigrantTensor(ind, ptensor); } - }); + } + }); +} - // Build lists for operands - lowered_graph->graph().operands().iterate([&](const ir::OperandIndex &ind, const ir::Operand &) { - const auto lower_info = lowered_graph->getLowerInfo(ind); - for (auto factor : lower_info->def_factors()) +void ExecutorFactory::prepareBuiltinBackend(const TensorRegistries &tensor_regs, + const std::shared_ptr<exec::IExecutors> &executors, + const backend::BackendContexts &backend_contexts, + const ir::ModelIndex &index) +{ + for (auto &&pair : backend_contexts) + { + auto builtin_context = dynamic_cast<backend::builtin::BackendContext *>(pair.second.get()); + if (builtin_context != nullptr) { - auto backend = factor.backend(); - backend_assets[backend].operand_list.emplace_back(ind); + auto builtin_kernel_gen = builtin_context->kernel_gen; + builtin_kernel_gen->setTensorRegistries(tensor_regs); + builtin_kernel_gen->setExecutors(executors); + builtin_kernel_gen->setModelIndex(index); } - }); + } +} - for (auto &pair : backend_assets) +std::deque<std::pair<const backend::Backend *, backend::BackendContext *>> +ExecutorFactory::orderBackendContext(const backend::BackendContexts &backend_contexts) +{ + std::deque<std::pair<const backend::Backend *, backend::BackendContext *>> ordered_contexts; + for (auto &&pair : backend_contexts) { - auto backend = pair.first; - auto &arg = pair.second; - 
lowered_graph->backend_contexts().at(backend)->initialize(arg.operation_list, arg.operand_list); + // NOTE builtin backend must be processed lastly. + // This is because of Permute layer's specialty which is the only operation that could have + // different ITensor objects for the input and the output. And it requires all other backends' + // tensors are ready to use. + if (pair.first->config()->id() == "builtin") + ordered_contexts.emplace_back(pair.first, pair.second.get()); + else + ordered_contexts.emplace_front(pair.first, pair.second.get()); } + return ordered_contexts; } -void ExecutorFactory::runTensorRegistration(compiler::LoweredGraph *lowered_graph, - const std::vector<ir::OpSequenceIndex> &order) +exec::IExecutor * +ExecutorFactory::createLinearExecutor(std::unique_ptr<compiler::LoweredGraph> lowered_graph, + const std::shared_ptr<exec::IExecutors> &executors, + const ExecutorFactoryArgs &args) { - for (const auto index : order) + const auto options = args.options; + const auto &model_index = args.model_index; + const auto tracing_ctx = args.tracing_ctx; + auto custom_kernel_builder = args.custom_kernel_builder; + auto &graph = lowered_graph->graph(); + + backend::BackendContexts backend_contexts = + createBackendContexts(*lowered_graph, options->executor == "Linear", custom_kernel_builder); + + TensorRegistries tensor_regs{backend_contexts, true}; + + initializeSubgraphIOTensors( + *lowered_graph, backend_contexts, + (lowered_graph->graph().getInputs() + lowered_graph->graph().getOutputs()) | + ir::Remove::DUPLICATED | ir::Remove::UNDEFINED); + + // linearize + auto order = Linear::linearize(*lowered_graph); + Linear::dump(*lowered_graph, order); + + for (auto &&pair : backend_contexts) + { + pair.second->genTensors(); + } + + prepareMigrantTensors(*lowered_graph, backend_contexts); + + // Give some runtime objects to builtin KernelGenerator + prepareBuiltinBackend(tensor_regs, executors, backend_contexts, model_index); + + ExecutionBuilder builder; + 
+ // Adjust the order of backends for the upcoming iteration + auto ordered_contexts = orderBackendContext(backend_contexts); + + // Simulate the execution for deallocation of tensors + std::unordered_map<ir::OperationIndex, DeallocList> dealloc_list_map; { - const auto &op_seq = lowered_graph->op_seqs().at(index); - const auto backend = lowered_graph->getLowerInfo(index)->backend(); - const auto tensor_register = lowered_graph->backend_contexts().at(backend)->tensor_register; - auto tensor_builder = lowered_graph->backend_contexts().at(backend)->tensor_builder; - auto model_io = lowered_graph->graph().getInputs() + lowered_graph->graph().getOutputs(); + ir::OperandIndexMap<uint32_t> uses_map; + ir::OperandIndexSequence constants; + + auto model_io = + (graph.getInputs() + graph.getOutputs()) | ir::Remove::UNDEFINED | ir::Remove::DUPLICATED; + + // Prepare scanning + graph.operands().iterate([&](const ir::OperandIndex &ind, const ir::Operand &obj) { + uses_map[ind] = obj.getUses().size(); + + if (obj.isConstant()) + constants.append(ind); + }); - if (tensor_register) + // A trick to consider constants as an execption + for (const auto &ind : constants) { - // Custom registration - tensor_register->registerTensors(op_seq, lowered_graph->getLowerInfo()); + uses_map[ind]++; } - else + + for (const auto &op_ind : order) { - // Default registration - for (const auto op_idx : op_seq) + const auto &op = graph.operations().at(op_ind); + auto op_inputs = op.getInputs() | ir::Remove::DUPLICATED | ir::Remove::UNDEFINED; + auto op_outputs = op.getOutputs() | ir::Remove::DUPLICATED | ir::Remove::UNDEFINED; + + for (const auto &ind : op_inputs) { - const auto &op = lowered_graph->graph().operations().at(op_idx); - for (const auto &index : (op.getInputs() | ir::Remove::UNDEFINED) + op.getOutputs()) + const auto &operand = graph.operands().at(ind); + assert(uses_map.find(ind) != uses_map.end()); + assert(uses_map[ind] > 0); + uses_map[ind]--; + if (uses_map[ind] == 0 && 
!operand.info().isVariable() && !model_io.contains(ind)) { - if (!tensor_builder->isRegistered(index) && !model_io.contains(index)) - { - const auto &operand_lower_info = - lowered_graph->getLowerInfo(index)->def_factors().getOnlyElement(); - - // E.g., permute (CPU) -> tensor A -> MaxPool2D(acl_cl) - // op.getOutputs() of permute (CPU) returns tensor A - // but tensor A belongs to the backend of acl_cl. - // So, we have to make this tensor NOT registered for CPU. - if (operand_lower_info.backend() != backend) - continue; - - const auto &obj = lowered_graph->graph().operands().at(index); - const auto frontend_layout = op_seq.getLayout(); - const auto backend_layout = operand_lower_info.layout(); - ir::OperandInfo backend_info{permuteShape(obj.shape(), frontend_layout, backend_layout), - obj.typeInfo(), obj.info().memAllocType(), - obj.isConstant()}; - tensor_builder->registerTensorInfo(index, backend_info, backend_layout); - } + dealloc_list_map[op_ind].emplace_back(tensor_regs.getITensor(ind)); } } } - } -} -std::vector<std::shared_ptr<backend::ITensor>> -ExecutorFactory::initializeModelIOTensors(compiler::LoweredGraph &lowered_graph, - const ir::OperandIndexSequence &indices) -{ - std::vector<std::shared_ptr<backend::ITensor>> ret; + // Dispose and validate + for (const auto &ind : constants) + { + --uses_map[ind]; + } + + assert( + std::all_of(uses_map.begin(), uses_map.end(), + [](std::pair<const ir::OperandIndex, uint32_t> it) { return it.second == 0; })); + } - // TODO Store controlflow backend in BackendContext - std::shared_ptr<backend::controlflow::TensorBuilder> cf_tensor_builder; - std::shared_ptr<backend::controlflow::TensorRegistry> cf_tensor_reg; - for (const auto &e : lowered_graph.backend_contexts()) + // Generate kernels + for (auto &&pair : ordered_contexts) { - auto backend = e.first; - auto &context = e.second; - if (backend->config()->id() == backend::controlflow::Config::ID) + auto codes = pair.second->genKernels(); + for (auto &&pair : codes) 
{ - cf_tensor_builder = - std::dynamic_pointer_cast<backend::controlflow::TensorBuilder>(context->tensor_builder); - cf_tensor_reg = - std::dynamic_pointer_cast<backend::controlflow::TensorRegistry>(context->tensor_registry); + auto &op_ind = pair.first; + auto &fn_seq = pair.second; + auto &op = lowered_graph->graph().operations().at(op_ind); + auto lower_info = lowered_graph->lower_info().operation.getRawPtr(op_ind); + if (options->he_profiling_mode) + fn_seq->wrap<SyncFunction>(lower_info->backend()->config()); + if (!dealloc_list_map[op_ind].empty()) + fn_seq->append(std::make_unique<DeallocFunction>(dealloc_list_map[op_ind])); + builder.append(op_ind, {op_ind, &op, lower_info, std::move(fn_seq)}); } } - assert(cf_tensor_builder); - assert(cf_tensor_reg); - for (auto ind : indices) + auto code_map = builder.releaseCodeMap(); + + auto exec = new exec::LinearExecutor{std::move(lowered_graph), + std::move(backend_contexts), + tensor_regs, + std::move(code_map), + order, + tracing_ctx}; + + if (!options->workspace_dir.empty()) { - const auto &operand = lowered_graph.graph().operands().at(ind); - auto tensor = std::make_shared<backend::controlflow::UserTensor>( - operand.info(), - ir::Layout::NHWC, /* FIXME find op_seq for this operand and use frontend_layout */ - cf_tensor_builder->dynamicTensorManager()); - - // Add tensor to controlflow TensorRegistry. 
- cf_tensor_reg->setNativeUserTensor(ind, tensor); - ret.push_back(tensor); + exec->addObserver( + std::make_unique<exec::TracingObserver>(options->workspace_dir, exec->graph(), tracing_ctx)); + exec->addObserver(std::make_unique<exec::MinMaxRecorder>(options->workspace_dir, exec->graph(), + exec->getBackendContexts())); } - return ret; -} -void ExecutorFactory::prepareExternalTensors(compiler::LoweredGraph &lowered_graph) -{ - TensorRegistries tensor_regs{lowered_graph.backend_contexts(), true}; - - lowered_graph.op_seqs().iterate( - [&](const ir::OpSequenceIndex &op_seq_index, const ir::OpSequence &op_seq) { - auto lower_info = lowered_graph.getLowerInfo(op_seq_index); - auto &backend_ctx = lowered_graph.backend_contexts().at(lower_info->backend()); - for (auto ind : (op_seq.getInputs() + op_seq.getOutputs()) | ir::Remove::DUPLICATED | - ir::Remove::UNDEFINED) - { - // If an OpSequence input/output tensor does not have a own tensor object, - // it must be using external tensors, so find the tensor from other tensor builders and - // set the tensor to this tensor builder if portable - if (!backend_ctx->tensor_registry->getITensor(ind)) - { - auto tensor = tensor_regs.getITensor(ind); - assert(tensor); // The tensor must have been registered - auto ptensor = std::dynamic_pointer_cast<backend::IPortableTensor>(tensor); - if (ptensor) - backend_ctx->tensor_registry->setMigrantTensor(ind, ptensor); - } - } - }); + return exec; } exec::IExecutor * -ExecutorFactory::createLinearExecutor(std::unique_ptr<compiler::LoweredGraph> lowered_graph, - const compiler::CompilerOptions &options, - const std::shared_ptr<exec::ExecutorMap> &executor_map) +ExecutorFactory::createDataflowExecutor(std::unique_ptr<compiler::LoweredGraph> lowered_graph, + const std::shared_ptr<exec::IExecutors> &executors, + const ExecutorFactoryArgs &args, bool parallel) { - const auto &backend_contexts = lowered_graph->backend_contexts(); + const auto options = args.options; + const auto &model_index = 
args.model_index; + const auto tracing_ctx = args.tracing_ctx; + auto custom_kernel_builder = args.custom_kernel_builder; - initializeBackendContext(lowered_graph.get()); + backend::BackendContexts backend_contexts = + createBackendContexts(*lowered_graph, options->executor == "Linear", custom_kernel_builder); - // linearize - assert(!lowered_graph->graph().isBuildingPhase()); + TensorRegistries tensor_regs{backend_contexts, true}; - /************************************************* - * Backend dependent analysis & optimization phase - *************************************************/ + initializeSubgraphIOTensors( + *lowered_graph, backend_contexts, + (lowered_graph->graph().getInputs() + lowered_graph->graph().getOutputs()) | + ir::Remove::DUPLICATED | ir::Remove::UNDEFINED); - for (auto &pair : backend_contexts) + for (auto &&pair : backend_contexts) { - auto &optimizer = pair.second->optimizer; - if (optimizer) - optimizer->optimize(); + pair.second->genTensors(); } - /********************************************************** - * Backend dependent analysis & optimization phase finished - **********************************************************/ + prepareMigrantTensors(*lowered_graph, backend_contexts); - /*********************** - * Code generation phase - ***********************/ + // Give some runtime objects to builtin KernelGenerator + prepareBuiltinBackend(tensor_regs, executors, backend_contexts, model_index); - auto order = Linear::linearize(*lowered_graph); - runTensorRegistration(lowered_graph.get(), order); + ExecutionBuilder builder; + + // Adjust the order of backends for the upcoming iteration + auto ordered_contexts = orderBackendContext(backend_contexts); - std::vector<std::shared_ptr<backend::ITensor>> input_tensors; - std::vector<std::shared_ptr<backend::ITensor>> output_tensors; - if (options.is_primary_subgraph) + // Generate kernels + for (auto &&pair : ordered_contexts) { - input_tensors = initializeModelIOTensors(*lowered_graph, 
lowered_graph->graph().getInputs()); - output_tensors = initializeModelIOTensors(*lowered_graph, lowered_graph->graph().getOutputs()); + auto codes = pair.second->genKernels(); + for (auto &&pair : codes) + { + auto &op_ind = pair.first; + auto &fn_seq = pair.second; + auto &op = lowered_graph->graph().operations().at(op_ind); + auto lower_info = lowered_graph->lower_info().operation.getRawPtr(op_ind); + if (options->he_profiling_mode) + fn_seq->wrap<SyncFunction>(lower_info->backend()->config()); + builder.append(op_ind, {op_ind, &op, lower_info, std::move(fn_seq)}); + } } - Linear::dump(*lowered_graph, order); - Linear::planTensors(*lowered_graph, order); + auto code_map = builder.releaseCodeMap(); - TensorBuilders tensor_builders{lowered_graph->backend_contexts(), true}; - TensorRegistries tensor_regs{lowered_graph->backend_contexts(), true}; + exec::ExecutorBase *exec = nullptr; + if (parallel) + { + exec = new exec::ParallelExecutor{std::move(lowered_graph), std::move(backend_contexts), + tensor_regs, std::move(code_map), tracing_ctx}; + } + else + { + auto dataflow_exec = + new exec::DataflowExecutor{std::move(lowered_graph), std::move(backend_contexts), tensor_regs, + std::move(code_map), tracing_ctx}; + if (options->he_profiling_mode) + { + std::vector<const backend::Backend *> backends; + for (const auto &pair : backend_contexts) + { + backends.push_back(pair.first); + } + auto et = std::make_shared<exec::ExecTime>(backends); + std::unique_ptr<exec::IExecutionObserver> obs = + std::make_unique<exec::ProfileObserver>(et, dataflow_exec->graph()); + dataflow_exec->addObserver(std::move(obs)); + } + exec = dataflow_exec; + } - for (auto &tensor_builder : tensor_builders) + if (!options->workspace_dir.empty()) { - tensor_builder->prepare(); + exec->addObserver( + std::make_unique<exec::TracingObserver>(options->workspace_dir, exec->graph(), tracing_ctx)); } - prepareExternalTensors(*lowered_graph); + return exec; +} + +exec::IExecutor * 
+ExecutorFactory::create(std::unique_ptr<compiler::train::LoweredTrainableGraph> lowered_graph, + const std::shared_ptr<exec::IExecutors> &executors, + const ExecutorFactoryArgs &args, + const ir::train::TrainingInfo &training_info) +{ + assert(args.options != nullptr); + + if (args.options->executor != "Linear") + throw std::runtime_error("ExecutorFactory: TrainableExecutor supports only 'Linear' now"); - ExecutionBuilder builder; + return createTrainableExecutor(std::move(lowered_graph), executors, args, training_info); +} - // Generate kernels - lowered_graph->iterateTopolOpSeqs([&](const ir::OpSequenceIndex &op_seq_index, - const ir::OpSequence &op_seq) { - auto lower_info = lowered_graph->getLowerInfo(op_seq_index); - auto kernel_gen = lowered_graph->backend_contexts().at(lower_info->backend())->kernel_gen; - // Set TensorBuilderSet and ExecutorMap to kernel_gen of control flow - auto cf_kernel_gen = dynamic_cast<backend::controlflow::KernelGenerator *>(kernel_gen.get()); - if (cf_kernel_gen != nullptr) +void ExecutorFactory::prepareMigrantTensors( + compiler::ILoweredGraph &lowered_graph, + const backend::train::TrainableBackendContexts &backend_contexts) +{ + train::TensorRegistries tensor_regs{backend_contexts, true}; + + lowered_graph.graph().operations().iterate( + [&](const ir::OperationIndex &op_ind, const ir::IOperation &op) { + auto lower_info = lowered_graph.lower_info().operation.getRawPtr(op_ind); + auto &backend_ctx = backend_contexts.at(lower_info->backend()); + for (auto &&ind : + (op.getInputs() + op.getOutputs()) | ir::Remove::DUPLICATED | ir::Remove::UNDEFINED) + { + // If an Operation's input/output tensor does not have an own tensor object, + // it must be using migrant tensors, so find the tensor from other tensor registries and + // register it to the current tensor registry if it is portable + if (!backend_ctx->tensor_registry()->getITensor(ind)) + { + auto tensor = tensor_regs.getITensor(ind); + assert(tensor); // The tensor must have 
been registered + auto ptensor = dynamic_cast<backend::IPortableTensor *>(tensor); + if (ptensor) + backend_ctx->tensor_registry()->setMigrantTensor(ind, ptensor); + } + } + }); +} + +exec::IExecutor *ExecutorFactory::createTrainableExecutor( + std::unique_ptr<compiler::train::LoweredTrainableGraph> lowered_graph, + const std::shared_ptr<exec::IExecutors> &, const ExecutorFactoryArgs &args, + const ir::train::TrainingInfo &training_info) +{ + const auto options = args.options; + const auto tracing_ctx = args.tracing_ctx; + auto custom_kernel_builder = args.custom_kernel_builder; + + auto &graph = lowered_graph->graph(); + + lowered_graph->trainable_graph().operations().iterate([](const onert::ir::OperationIndex &, + const onert::ir::IOperation &op) { + try { - cf_kernel_gen->setTensorRegistries(tensor_regs); - cf_kernel_gen->setExecutorMap(executor_map); + UNUSED_RELEASE(dynamic_cast<const ir::train::ITrainableOperation &>(op)); } - auto fn_seq = kernel_gen->generate(op_seq); - if (options.he_profiling_mode) + catch (std::bad_cast &) { - fn_seq->wrap<SyncFunction>(lower_info->backend()->config()); + throw std::runtime_error("ExecutorFactory: " + op.name() + " is not trainable operation yet"); } - builder.append(op_seq_index, {&op_seq, lower_info, std::move(fn_seq)}); }); - for (auto &tensor_builder : tensor_builders) - { - tensor_builder->allocate(); - } + // TODO Create context only once instead of replacing + backend::train::TrainableBackendContexts tbackend_contexts; + backend::BackendContexts base_backend_contexts = + createBackendContexts(*lowered_graph, true, custom_kernel_builder); - for (auto &pair : backend_contexts) + // Replace BackendContext with TrainbleBackendContext + for (auto &&pair : base_backend_contexts) { - pair.second->initConsts(); - } - - lowered_graph->graph().operands().iterate( - [](const ir::OperandIndex &, ir::Operand &obj) { obj.releaseData(); }); - - auto code_map = builder.releaseCodeMap(); - - for (auto &it : code_map) - { - auto 
op_seq_index = it.first; - auto &fn_seq = it.second.fn_seq; - - fn_seq->iterate([&](exec::IFunction &ifunc) { - ifunc.prepare(); - auto backend = lowered_graph->getLowerInfo(op_seq_index)->backend(); - auto tensor_builder = lowered_graph->backend_contexts().at(backend)->tensor_builder; - tensor_builder->postFunctionPrepare(); + auto ctx = pair.second.get(); + const auto &data = ctx->data(); + + // Create partial and trainable graphs + auto tgraph = std::make_unique<ir::train::TrainableGraph>(*data.graph); + data.graph->operations().iterate( + [&](const onert::ir::OperationIndex &op_index, const onert::ir::IOperation &) { + const auto &orig_tgraph = lowered_graph->trainable_graph(); + const auto &trainable_op = orig_tgraph.operation(op_index); + auto gen_index = tgraph->replaceOperation(op_index, trainable_op.clone()); + UNUSED_RELEASE(gen_index); + assert(gen_index == op_index); + }); + data.graph->operands().iterate([&](const ir::OperandIndex &index, const ir::Operand &) { + const auto &orig_tgraph = lowered_graph->trainable_graph(); + if (orig_tgraph.backward_operands().exist(index)) + { + const auto &bwd_operand = orig_tgraph.backward_operands().at(index); + auto new_bwd_operand = std::make_unique<ir::Operand>(bwd_operand); + auto gen_index = tgraph->addBackwardOperand(index, std::move(new_bwd_operand)); + UNUSED_RELEASE(gen_index); + assert(gen_index == index); + } }); - } - backend::TensorManagerSet tensor_mgrs = createTensorManagerSet(tensor_builders); - auto exec = new exec::LinearExecutor{ - std::move(lowered_graph), input_tensors, output_tensors, tensor_regs, - std::move(tensor_mgrs), std::move(code_map), order}; + // Remove outputs of whole graph from external_operands + auto external_operands = data.external_operands; + for (const auto &index : lowered_graph->trainable_graph().getOutputs()) + { + if (external_operands.contains(index)) + external_operands.remove(index); + } - if (!options.trace_filepath.empty()) - { - 
std::unique_ptr<exec::IExecutionObserver> ctp = - std::make_unique<exec::ChromeTracingObserver>(options.trace_filepath, exec->graph()); - exec->addObserver(std::move(ctp)); + // Set trainable context data + backend::train::TrainableContextData tdata; + tdata.tgraph = std::move(tgraph); + tdata.op_order = std::move(data.op_order); + tdata.external_operands = std::move(external_operands); + tdata.operand_layouts = std::move(data.operand_layouts); + tdata.custom_kernel_builder = std::move(data.custom_kernel_builder); + tdata.is_linear_executor = data.is_linear_executor; + tdata.optim_info = training_info.optimizerInfo(); + + // TODO Remove dynamic_cast + const auto backend = pair.first; + const auto tbackend = dynamic_cast<const backend::train::ITrainableBackend *>(backend); + if (!tbackend) + { + throw std::runtime_error("ExecutorFactory: Invalid backend - TrainableExecutor does not " + "support non-trainble backends"); + } + tbackend_contexts.emplace(backend, tbackend->newContext(std::move(tdata))); } + base_backend_contexts.clear(); - return exec; -} + train::TensorRegistries tensor_regs{tbackend_contexts, true}; -exec::IExecutor *ExecutorFactory::createDataflowExecutor( - std::unique_ptr<compiler::LoweredGraph> lowered_graph, const compiler::CompilerOptions &options, - const std::shared_ptr<exec::ExecutorMap> &executor_map, bool parallel) -{ - const auto &backend_contexts = lowered_graph->backend_contexts(); - - initializeBackendContext(lowered_graph.get()); + initializeSubgraphIOTensors( + *lowered_graph, tbackend_contexts, + (lowered_graph->graph().getInputs() + lowered_graph->graph().getOutputs()) | + ir::Remove::DUPLICATED | ir::Remove::UNDEFINED); + // linearize for forwarding auto order = Linear::linearize(*lowered_graph); - runTensorRegistration(lowered_graph.get(), order); + VERBOSE(ExecutorFactory) << "Linearize for forwarding order" << std::endl; + Linear::dump(*lowered_graph, order); + + // linearize for backwarding + auto backward_order = 
lowered_graph->trainable_graph().essentialBackwardOrder(); + VERBOSE(ExecutorFactory) << "Linearize for backwarding order" << std::endl; + Linear::dump(*lowered_graph, backward_order); - std::vector<std::shared_ptr<backend::ITensor>> input_tensors; - std::vector<std::shared_ptr<backend::ITensor>> output_tensors; - if (options.is_primary_subgraph) + for (auto &&pair : tbackend_contexts) { - input_tensors = initializeModelIOTensors(*lowered_graph, lowered_graph->graph().getInputs()); - output_tensors = initializeModelIOTensors(*lowered_graph, lowered_graph->graph().getOutputs()); + pair.second->genTensors(); } - TensorBuilders tensor_builders{lowered_graph->backend_contexts(), true}; - TensorRegistries tensor_regs{lowered_graph->backend_contexts(), true}; - - // To make tensors never be deallocated, this is a workaround to use static memory planner - for (auto &tensor_builder : tensor_builders) + for (auto &&pair : tbackend_contexts) { - lowered_graph->graph().operands().iterate( - [&](const ir::OperandIndex &ind, const ir::Operand &) { - if (tensor_builder->isRegistered(ind)) - { - tensor_builder->notifyFirstUse(ind); - } - }); + auto tctx = pair.second.get(); + tctx->genTrainingTensors(); } - for (auto &tensor_builder : tensor_builders) + prepareMigrantTensors(*lowered_graph, tbackend_contexts); + + // Give some runtime objects to builtin KernelGenerator + for (auto &&pair : tbackend_contexts) { - tensor_builder->prepare(); + auto builtin_context = + dynamic_cast<backend::builtin::train::BackendContext *>(pair.second.get()); + if (builtin_context != nullptr) + { + auto builtin_kernel_gen = builtin_context->kernel_gen; + builtin_kernel_gen->setTensorRegistries(tensor_regs); + builtin_kernel_gen->setWholeGraphOutputs(lowered_graph->trainable_graph().getOutputs()); + } } - prepareExternalTensors(*lowered_graph); + // Adjust the order of backends for the upcoming iteration + auto ordered_contexts = + 
onert::orderBackendContext<backend::train::TrainableBackendContext>(tbackend_contexts); - ExecutionBuilder builder; + // TODO Remove this simulation + // Simulate the execution for deallocation of tensors + std::unordered_map<ir::OperationIndex, DeallocList> dealloc_list_map; + { + ir::OperandIndexMap<uint32_t> uses_map; + ir::OperandIndexSequence constants; - // Generate kernels - lowered_graph->iterateTopolOpSeqs([&](const ir::OpSequenceIndex &op_seq_index, - const ir::OpSequence &op_seq) { - auto lower_info = lowered_graph->getLowerInfo(op_seq_index); - auto kernel_gen = lowered_graph->backend_contexts().at(lower_info->backend())->kernel_gen; - // Set TensorBuilderSet and ExecutorMap to kernel_gen of control flow - auto cf_kernel_gen = dynamic_cast<backend::controlflow::KernelGenerator *>(kernel_gen.get()); - if (cf_kernel_gen != nullptr) + auto model_io = + (graph.getInputs() + graph.getOutputs()) | ir::Remove::UNDEFINED | ir::Remove::DUPLICATED; + + // Prepare scanning + graph.operands().iterate([&](const ir::OperandIndex &ind, const ir::Operand &obj) { + uses_map[ind] = obj.getUses().size(); + + if (obj.isConstant()) + constants.append(ind); + }); + + // A trick to consider constants as an execption + for (const auto &ind : constants) { - assert(cf_kernel_gen != nullptr); - cf_kernel_gen->setTensorRegistries(tensor_regs); - cf_kernel_gen->setExecutorMap(executor_map); + uses_map[ind]++; } - auto fn_seq = kernel_gen->generate(op_seq); - if (options.he_profiling_mode) + + for (const auto &op_ind : order) { - fn_seq->wrap<SyncFunction>(lower_info->backend()->config()); + const auto &op = graph.operations().at(op_ind); + auto op_inputs = op.getInputs() | ir::Remove::DUPLICATED | ir::Remove::UNDEFINED; + auto op_outputs = op.getOutputs() | ir::Remove::DUPLICATED | ir::Remove::UNDEFINED; + + for (const auto &ind : op_inputs) + { + const auto &operand = graph.operands().at(ind); + assert(uses_map.find(ind) != uses_map.end()); + assert(uses_map[ind] > 0); + 
uses_map[ind]--; + if (uses_map[ind] == 0 && !operand.info().isVariable() && !model_io.contains(ind)) + { + dealloc_list_map[op_ind].emplace_back(tensor_regs.getITensor(ind)); + } + } } - builder.append(op_seq_index, {&op_seq, lower_info, std::move(fn_seq)}); - }); - for (const auto &tensor_builder : tensor_builders) - { - tensor_builder->allocate(); - } + // Dispose and validate + for (const auto &ind : constants) + { + --uses_map[ind]; + } - for (auto &pair : backend_contexts) - { - pair.second->initConsts(); + assert( + std::all_of(uses_map.begin(), uses_map.end(), + [](std::pair<const ir::OperandIndex, uint32_t> it) { return it.second == 0; })); } - lowered_graph->graph().operands().iterate( - [](const ir::OperandIndex &, ir::Operand &obj) { obj.releaseData(); }); - - auto code_map = builder.releaseCodeMap(); - - for (auto &it : code_map) + // Check back propagation tensors { - auto op_seq_index = it.first; - auto &fn_seq = it.second.fn_seq; - - fn_seq->iterate([&](exec::IFunction &ifunc) { - ifunc.prepare(); - auto backend = lowered_graph->getLowerInfo(op_seq_index)->backend(); - auto tensor_builder = lowered_graph->backend_contexts().at(backend)->tensor_builder; - tensor_builder->postFunctionPrepare(); - }); + // TODO Support multiple subgraphs + // Check if the back propagation tensors corresponding to inputs of model are nullptr + // NOTE The back propagation tensors corresponding to inputs of model are for inputs of + // PermuteLayers + // and they are nullptr and because they are meaningless. 
+ assert(std::all_of( + lowered_graph->trainable_graph().getInputs().begin(), + lowered_graph->trainable_graph().getInputs().end(), + [&](const auto &input_idx) { return tensor_regs.getBackPropITensor(input_idx) == nullptr; })); + + // Check if the back propagation tensors corresponding to outputs of model exist + assert(std::all_of(lowered_graph->trainable_graph().getOutputs().begin(), + lowered_graph->trainable_graph().getOutputs().end(), + [&](const auto &output_idx) { + return tensor_regs.getBackPropITensor(output_idx) == nullptr; + })); } - backend::TensorManagerSet tensor_mgrs = createTensorManagerSet(tensor_builders); - - exec::ExecutorBase *exec = nullptr; - if (parallel) - { - exec = new exec::ParallelExecutor{std::move(lowered_graph), input_tensors, - output_tensors, tensor_regs, - std::move(tensor_mgrs), std::move(code_map)}; - } - else + train::TrainableCodeMap code_map; + // Generate kernels + for (auto &&pair : ordered_contexts) { - auto dataflow_exec = new exec::DataflowExecutor{std::move(lowered_graph), input_tensors, - output_tensors, tensor_regs, - std::move(tensor_mgrs), std::move(code_map)}; - if (options.he_profiling_mode) + auto codes = pair.second->genKernels(); + for (auto &&pair : codes) { - std::vector<const backend::Backend *> backends; - for (const auto &pair : backend_contexts) - { - backends.push_back(pair.first); - } - auto et = std::make_shared<exec::ExecTime>(backends); - std::unique_ptr<exec::IExecutionObserver> obs = - std::make_unique<exec::ProfileObserver>(et, dataflow_exec->graph()); - dataflow_exec->addObserver(std::move(obs)); + auto &op_ind = pair.first; + auto &tn_seq = pair.second; + auto &op = lowered_graph->trainable_graph().operation(op_ind); + auto lower_info = lowered_graph->lower_info().operation.getRawPtr(op_ind); + + assert(code_map.find(op_ind) == code_map.end()); + code_map.insert( + {op_ind, train::TrainableCodeAndInfo{op_ind, &op, lower_info, std::move(tn_seq)}}); } - exec = dataflow_exec; } - if 
(!options.trace_filepath.empty()) + if (order.size() != code_map.size()) + { + throw std::runtime_error("ExecutorFactory: Some kernels are not generated"); + } + + auto exec = new exec::train::TrainableExecutor{std::move(lowered_graph), + std::move(tbackend_contexts), + tensor_regs, + std::move(code_map), + order, + backward_order, + tracing_ctx, + training_info.lossInfo()}; + + if (!options->workspace_dir.empty()) { - std::unique_ptr<exec::IExecutionObserver> ctp = - std::make_unique<exec::ChromeTracingObserver>(options.trace_filepath, exec->graph()); - exec->addObserver(std::move(ctp)); + exec->addObserver( + std::make_unique<exec::TracingObserver>(options->workspace_dir, exec->graph(), tracing_ctx)); } + // TODO Support MINMAX_H5DUMPER return exec; } diff --git a/runtime/onert/core/src/compiler/ExecutorFactory.h b/runtime/onert/core/src/compiler/ExecutorFactory.h index b8893c03b..1b9bd4ab6 100644 --- a/runtime/onert/core/src/compiler/ExecutorFactory.h +++ b/runtime/onert/core/src/compiler/ExecutorFactory.h @@ -17,18 +17,32 @@ #ifndef __ONERT_COMPILER_EXECUTOR_FACTORY_H__ #define __ONERT_COMPILER_EXECUTOR_FACTORY_H__ -#include <unordered_map> +#include "TensorRegistries.h" #include "backend/ITensor.h" -#include "exec/IExecutor.h" +#include "backend/train/TrainableBackendContext.h" #include "compiler/LoweredGraph.h" -#include "TensorRegistries.h" +#include "compiler/train/LoweredTrainableGraph.h" +#include "exec/IExecutors.h" +#include "ir/train/TrainingInfo.h" + +#include <deque> +#include <unordered_map> namespace onert { namespace compiler { +// TODO Change to a better name +struct ExecutorFactoryArgs +{ + const util::TracingCtx *tracing_ctx; + const compiler::CompilerOptions *options; + ir::ModelIndex model_index; + std::shared_ptr<backend::custom::IKernelBuilder> custom_kernel_builder; +}; + class ExecutorFactory { public: @@ -36,35 +50,52 @@ public: public: exec::IExecutor *create(std::unique_ptr<compiler::LoweredGraph> lowered_graph, - const 
compiler::CompilerOptions &options, - const std::shared_ptr<exec::ExecutorMap> &executor_map); + const std::shared_ptr<exec::IExecutors> &executors, + const ExecutorFactoryArgs &args); + + // TODO Unify create() + exec::IExecutor *create(std::unique_ptr<compiler::train::LoweredTrainableGraph> lowered_graph, + const std::shared_ptr<exec::IExecutors> &executors, + const ExecutorFactoryArgs &args, + const ir::train::TrainingInfo &training_info); private: ExecutorFactory(); private: - static void initializeBackendContext(compiler::LoweredGraph *lowered_graph); - static void runTensorRegistration(compiler::LoweredGraph *lowered_graph, - const std::vector<ir::OpSequenceIndex> &order); - static std::vector<std::shared_ptr<backend::ITensor>> - initializeModelIOTensors(compiler::LoweredGraph &lowered_graph, - const ir::OperandIndexSequence &indices); - static void prepareExternalTensors(compiler::LoweredGraph &lowered_graph); + static void prepareMigrantTensors(compiler::ILoweredGraph &lowered_graph, + const backend::BackendContexts &backend_contexts); + static void prepareBuiltinBackend(const TensorRegistries &tensor_regs, + const std::shared_ptr<exec::IExecutors> &executors, + const backend::BackendContexts &backend_contexts, + const ir::ModelIndex &index); + static std::deque<std::pair<const backend::Backend *, backend::BackendContext *>> + orderBackendContext(const backend::BackendContexts &backend_contexts); + static exec::IExecutor * createLinearExecutor(std::unique_ptr<compiler::LoweredGraph> lowered_graph, - const compiler::CompilerOptions &options, - const std::shared_ptr<exec::ExecutorMap> &executor_map); + const std::shared_ptr<exec::IExecutors> &executors, + const ExecutorFactoryArgs &args); static exec::IExecutor * createDataflowExecutor(std::unique_ptr<compiler::LoweredGraph> lowered_graph, - const compiler::CompilerOptions &options, - const std::shared_ptr<exec::ExecutorMap> &executor_map, bool parallel); + const std::shared_ptr<exec::IExecutors> &executors, 
+ const ExecutorFactoryArgs &args, bool parallel); + // TODO Unify prepareMigrantTensors + static void + prepareMigrantTensors(compiler::ILoweredGraph &lowered_graph, + const backend::train::TrainableBackendContexts &backend_contexts); + static exec::IExecutor * + createTrainableExecutor(std::unique_ptr<compiler::train::LoweredTrainableGraph> lowered_graph, + const std::shared_ptr<exec::IExecutors> &executors, + const ExecutorFactoryArgs &args, + const ir::train::TrainingInfo &training_info); private: - std::unordered_map<std::string, std::function<exec::IExecutor *( - std::unique_ptr<compiler::LoweredGraph>, - const compiler::CompilerOptions &options, - const std::shared_ptr<exec::ExecutorMap> &executor_map)>> - _map; + std::unordered_map< + std::string, std::function<exec::IExecutor *(std::unique_ptr<compiler::LoweredGraph>, + const std::shared_ptr<exec::IExecutors> &executors, + const ExecutorFactoryArgs &args)>> + _map; }; } // namespace compiler diff --git a/runtime/onert/core/src/compiler/Fp32ToFp16Converter.cc b/runtime/onert/core/src/compiler/Fp32ToFp16Converter.cc index 23a6a253d..ce9b09c2d 100644 --- a/runtime/onert/core/src/compiler/Fp32ToFp16Converter.cc +++ b/runtime/onert/core/src/compiler/Fp32ToFp16Converter.cc @@ -14,6 +14,8 @@ * limitations under the License. 
*/ +#if 0 // This file is temporarily unused + #include "Fp32ToFp16Converter.h" #include "ir/operation/ConvertFp32ToFp16.h" #include "ir/operation/ConvertFp16ToFp32.h" @@ -45,7 +47,7 @@ namespace compiler { Fp32ToFp16Converter::Fp32ToFp16Converter(compiler::LoweredGraph &lowered_graph) - : _lowered_graph{lowered_graph} + : _lowered_graph{lowered_graph} { VERBOSE(Fp32ToFp16Converter) << "Fp16 Enable on" << std::endl; } @@ -177,26 +179,26 @@ void Fp32ToFp16Converter::run() void Fp32ToFp16Converter::appendOpSequences() { _lowered_graph.op_seqs().iterate( - [&](const ir::OpSequenceIndex &op_seq_ind, ir::OpSequence &op_seq) { - const auto lower_info = _lowered_graph.getLowerInfo(op_seq_ind); - assert(lower_info != nullptr); - - // For now, the only acl_cl supports fully fp16 type - // TODO Support fp16 on acl_neon. Current acl_neon supports the only reshape and concat - // operations. - // To do this, we could check the support by `operation by operation`. After that, we - // would partition an op_seq if it contains unsupported operations. - if (lower_info->backend()->config()->id() != kAclClBackendConfigId) - return; - - // OpSeq's input set should be included in the first operation's input set or - // OpSeq's output set should be included in the last operation's output set - assert(checkOperandsOfOpSequence(op_seq)); - - // Append converting OpSequence for fp16 but all operands' types are not fp16 still. - appendNewOpSeqForConvertFp32ToFp16(op_seq_ind, op_seq); - appendNewOpSeqForConvertFp16ToFp32(op_seq_ind, op_seq); - }); + [&](const ir::OpSequenceIndex &op_seq_ind, ir::OpSequence &op_seq) { + const auto &lower_info = _lowered_graph.getLowerInfo(op_seq_ind); + assert(lower_info != nullptr); + + // For now, the only acl_cl supports fully fp16 type + // TODO Support fp16 on acl_neon. Current acl_neon supports the only reshape and concat + // operations. + // To do this, we could check the support by `operation by operation`. 
After that, we + // would partition an op_seq if it contains unsupported operations. + if (lower_info->backend()->config()->id() != kAclClBackendConfigId) + return; + + // OpSeq's input set should be included in the first operation's input set or + // OpSeq's output set should be included in the last operation's output set + assert(checkOperandsOfOpSequence(op_seq)); + + // Append converting OpSequence for fp16 but all operands' types are not fp16 still. + appendNewOpSeqForConvertFp32ToFp16(op_seq_ind, op_seq); + appendNewOpSeqForConvertFp16ToFp32(op_seq_ind, op_seq); + }); } // @@ -253,7 +255,7 @@ void Fp32ToFp16Converter::appendNewOpSeqForConvertFp32ToFp16(const ir::OpSequenc const auto new_op_seq_ind = newOpSequence(op_seq_ind, new_node_ind); // set new lower_info for op_seq - setNewOpSequenceLowerInfo(op_seq_ind, new_op_seq_ind); + setNewOperationLowerInfo(op_seq_ind, new_op_seq_ind); _list_fp32_to_fp16.insert(new_op_seq_ind); @@ -326,7 +328,7 @@ void Fp32ToFp16Converter::appendNewOpSeqForConvertFp16ToFp32(const ir::OpSequenc auto new_op_seq_ind = newOpSequence(op_seq_ind, new_node_ind); // set new lower_info for op_seq - setNewOpSequenceLowerInfo(op_seq_ind, new_op_seq_ind); + setNewOperationLowerInfo(op_seq_ind, new_op_seq_ind); _list_fp16_to_fp32.insert(new_op_seq_ind); @@ -372,16 +374,16 @@ void Fp32ToFp16Converter::optimize() void Fp32ToFp16Converter::convertOperands() { _lowered_graph.op_seqs().iterate( - [&](const ir::OpSequenceIndex &op_seq_ind, ir::OpSequence &op_seq) { - const auto lower_info = _lowered_graph.getLowerInfo(op_seq_ind); - assert(lower_info != nullptr); - // For now, the only acl_cl supports fully fp16 - if (lower_info->backend()->config()->id() != kAclClBackendConfigId) - return; - - // Convert input,output operands' type to fp16 - convertOperandsOfOpSequence(op_seq); - }); + [&](const ir::OpSequenceIndex &op_seq_ind, ir::OpSequence &op_seq) { + const auto &lower_info = _lowered_graph.getLowerInfo(op_seq_ind); + assert(lower_info != 
nullptr); + // For now, the only acl_cl supports fully fp16 + if (lower_info->backend()->config()->id() != kAclClBackendConfigId) + return; + + // Convert input,output operands' type to fp16 + convertOperandsOfOpSequence(op_seq); + }); } void Fp32ToFp16Converter::convertOperandsOfOpSequence(ir::OpSequence &op_seq) @@ -391,10 +393,10 @@ void Fp32ToFp16Converter::convertOperandsOfOpSequence(ir::OpSequence &op_seq) const auto &op_seq_inputs = _lowered_graph.graph().getInputs(); const auto &op_seq_outputs = _lowered_graph.graph().getOutputs(); - for (auto &op_idx : op_seq) + for (const auto &op_idx : op_seq) { const auto &node = operations.at(op_idx); - for (auto &ind : node.getInputs() | ir::Remove::UNDEFINED) + for (const auto &ind : node.getInputs() | ir::Remove::UNDEFINED) { if (node.opcode() == ir::OpCode::ConvertFp32ToFp16 || op_seq_inputs.contains(ind)) continue; @@ -405,10 +407,10 @@ void Fp32ToFp16Converter::convertOperandsOfOpSequence(ir::OpSequence &op_seq) obj.type(ir::DataType::FLOAT16); - VERBOSE(Fp32ToFp16Converter) << "Input Operand #" << ind.value() << ": fp16" << std::endl; + VERBOSE(Fp32ToFp16Converter) << "Input Operand " << ind << ": fp16" << std::endl; } - for (auto &ind : node.getOutputs()) + for (const auto &ind : node.getOutputs()) { if (node.opcode() == ir::OpCode::ConvertFp16ToFp32 || op_seq_outputs.contains(ind)) continue; @@ -419,7 +421,7 @@ void Fp32ToFp16Converter::convertOperandsOfOpSequence(ir::OpSequence &op_seq) obj.type(ir::DataType::FLOAT16); - VERBOSE(Fp32ToFp16Converter) << "Output Operand #" << ind.value() << ": fp16" << std::endl; + VERBOSE(Fp32ToFp16Converter) << "Output Operand " << ind << ": fp16" << std::endl; } } } @@ -444,7 +446,7 @@ void Fp32ToFp16Converter::convertDatas() obj.data(std::move(new_data)); obj.type(ir::DataType::FLOAT16); - VERBOSE(Fp32ToFp16Converter) << "Constant Operand #" << ind.value() << ": fp16" << std::endl; + VERBOSE(Fp32ToFp16Converter) << "Constant Operand " << ind << ": fp16" << std::endl; } }); 
} @@ -513,23 +515,23 @@ ir::OperandIndex Fp32ToFp16Converter::newCopiedOperand(const ir::OperandIndex &o void Fp32ToFp16Converter::setNewOperandLowerInfo(const ir::OpSequenceIndex &op_seq_ind, const ir::OperandIndex &new_op_ind) { - const auto lower_info = _lowered_graph.getLowerInfo(op_seq_ind); + const auto &lower_info = _lowered_graph.getLowerInfo(op_seq_ind); assert(lower_info != nullptr); - auto new_lower_info = std::make_unique<ir::operand::LowerInfo>(); - auto permute_factor = ir::operand::PermuteFactor(lower_info->backend(), lower_info->layout()); + auto new_lower_info = std::make_unique<compiler::OperandLowerInfo>(); + auto permute_factor = compiler::PermuteFactor(lower_info->backend(), lower_info->layout()); new_lower_info->addDefPermuteFactor(permute_factor); new_lower_info->addUsePermuteFactor(permute_factor); _lowered_graph.setLowerInfo(new_op_ind, std::move(new_lower_info)); } -void Fp32ToFp16Converter::setNewOpSequenceLowerInfo(const ir::OpSequenceIndex &op_seq_ind, - const ir::OpSequenceIndex &new_op_seq_ind) +void Fp32ToFp16Converter::setNewOperationLowerInfo(const ir::OpSequenceIndex &op_seq_ind, + const ir::OpSequenceIndex &new_op_seq_ind) { - const auto lower_info = _lowered_graph.getLowerInfo(op_seq_ind); + const auto &lower_info = _lowered_graph.getLowerInfo(op_seq_ind); assert(lower_info != nullptr); auto new_lower_info = - std::make_unique<ir::operation::LowerInfo>(lower_info->backend(), lower_info->layout()); + std::make_unique<compiler::OperationLowerInfo>(lower_info->backend(), lower_info->layout()); _lowered_graph.setLowerInfo(new_op_seq_ind, std::move(new_lower_info)); } @@ -600,7 +602,7 @@ Fp32ToFp16Converter::newOperationConvertFp32ToFp16(const ir::OperandIndex &op_se auto &new_op_obj = operands.at(new_op_ind); std::unique_ptr<ir::Operation> new_node( - new ir::operation::ConvertFp32ToFp16({op_seq_input_ind}, {new_op_ind})); + new ir::operation::ConvertFp32ToFp16({op_seq_input_ind}, {new_op_ind})); const auto new_node_ind = 
operations.push(std::move(new_node)); input_obj.insertUse(new_node_ind); @@ -620,7 +622,7 @@ Fp32ToFp16Converter::newOperationConvertFp16ToFp32(const ir::OperandIndex &op_se auto &new_op_obj = operands.at(new_op_ind); std::unique_ptr<ir::Operation> new_node( - new ir::operation::ConvertFp16ToFp32({new_op_ind}, {op_seq_output_ind})); + new ir::operation::ConvertFp16ToFp32({new_op_ind}, {op_seq_output_ind})); const auto new_node_ind = operations.push(std::move(new_node)); new_op_obj.insertUse(new_node_ind); @@ -633,7 +635,7 @@ ir::OpSequenceIndex Fp32ToFp16Converter::newOpSequence(const ir::OpSequenceIndex const ir::OperationIndex &node_index) { auto &node = _lowered_graph.graph().operations().at(node_index); - const auto lower_info = _lowered_graph.getLowerInfo(op_seq_ind); + const auto &lower_info = _lowered_graph.getLowerInfo(op_seq_ind); assert(lower_info != nullptr); auto layout = lower_info->layout(); @@ -745,7 +747,7 @@ Fp32ToFp16Converter::findOpSequencesContiguous(const InputToOpSeqs &input_to_op_ // | | // [OPERATION] [OPERATION] // - for (auto &op_seq_ind : found_input_in_op_seqs->second) + for (const auto &op_seq_ind : found_input_in_op_seqs->second) { auto found_in_fp32_to_fp16 = _list_fp32_to_fp16.find(op_seq_ind); if (found_in_fp32_to_fp16 != _list_fp32_to_fp16.end()) @@ -759,9 +761,8 @@ Fp32ToFp16Converter::findOpSequencesContiguous(const InputToOpSeqs &input_to_op_ opseq_map_to_delete[op_seq_ind_fp16_to_fp32].insert(op_seq_ind); } - VERBOSE(Fp32ToFp16Converter) - << "Contiguous from OpSeq#" << op_seq_ind_fp16_to_fp32.value() << "(ToFp32)" - << " to OpSeq#" << op_seq_ind.value() << "(ToFp16)" << std::endl; + VERBOSE(Fp32ToFp16Converter) << "Contiguous from " << op_seq_ind_fp16_to_fp32 << "(ToFp32)" + << " to " << op_seq_ind << "(ToFp16)" << std::endl; } } } @@ -775,7 +776,7 @@ Fp32ToFp16Converter::InputToOpSeqs Fp32ToFp16Converter::prepareInputToOpSeqs() c InputToOpSeqs input_to_op_seqs; op_seqs.iterate([&](const ir::OpSequenceIndex &op_seq_idx, const 
ir::OpSequence &op_seq) { - for (auto input : op_seq.getInputs() | ir::Remove::UNDEFINED) + for (auto &&input : op_seq.getInputs() | ir::Remove::UNDEFINED) { auto it = input_to_op_seqs.find(input); if (it == input_to_op_seqs.end()) @@ -798,13 +799,13 @@ Fp32ToFp16Converter::getListOpSequences(const OpSeqIndexToOpSeqIndexList &opseq_ OpSeqIndexList list; for (const auto &it : opseq_map_to_delete) { - auto &opseq_ind_fp16_to_fp32 = it.first; + const auto &opseq_ind_fp16_to_fp32 = it.first; if (list.find(opseq_ind_fp16_to_fp32) == list.end()) { list.emplace(opseq_ind_fp16_to_fp32); } - for (auto &opseq_ind_fp32_to_fp16 : it.second) + for (const auto &opseq_ind_fp32_to_fp16 : it.second) { if (list.find(opseq_ind_fp32_to_fp16) == list.end()) { @@ -842,7 +843,7 @@ Fp32ToFp16Converter::findOperationsToDelete(const OpSeqIndexList &list_to_delete } void Fp32ToFp16Converter::manipulateContiguousOpSequences( - const InputToOpSeqs &input_to_op_seqs, const OpSeqIndexToOpSeqIndexList &opseq_map_to_delete) + const InputToOpSeqs &input_to_op_seqs, const OpSeqIndexToOpSeqIndexList &opseq_map_to_delete) { auto &op_seqs = _lowered_graph.op_seqs(); @@ -861,14 +862,14 @@ void Fp32ToFp16Converter::manipulateContiguousOpSequences( // | // [OPERATION] // op_seq_ind_next_to_fp16 // - for (auto it : opseq_map_to_delete) + for (auto &&it : opseq_map_to_delete) { // fp16_to_fp32's input/output num is always 1 auto &op_seq_ind_fp16_to_fp32 = it.first; auto &op_seq_fp16_to_fp32 = op_seqs.at(op_seq_ind_fp16_to_fp32); auto &input_ind_fp16_to_fp32 = op_seq_fp16_to_fp32.getInputs().at(0); - for (auto &op_seq_ind_fp32_to_fp16 : it.second) + for (const auto &op_seq_ind_fp32_to_fp16 : it.second) { auto &op_seq_fp32_to_fp16 = op_seqs.at(op_seq_ind_fp32_to_fp16); assert(op_seq_fp32_to_fp16.size() == 1); @@ -878,7 +879,7 @@ void Fp32ToFp16Converter::manipulateContiguousOpSequences( auto found_next_to_fp16 = input_to_op_seqs.find(output_ind_fp32_to_fp16); assert(found_next_to_fp16 != 
input_to_op_seqs.end()); - for (auto &op_seq_ind_next_to_fp16 : found_next_to_fp16->second) + for (const auto &op_seq_ind_next_to_fp16 : found_next_to_fp16->second) { manipulateInput(op_seq_ind_next_to_fp16, output_ind_fp32_to_fp16, input_ind_fp16_to_fp32); } @@ -894,61 +895,62 @@ void Fp32ToFp16Converter::manipulateContiguousOpSequences( } void Fp32ToFp16Converter::deleteContiguousOpSequences( - const OpSeqIndexList &list_to_delete_op_seqs, - const ir::OperandIndexSequence &list_to_delete_ops) + const OpSeqIndexList &list_to_delete_op_seqs, const ir::OperandIndexSequence &list_to_delete_ops) { auto &operands = _lowered_graph.graph().operands(); auto &operations = _lowered_graph.graph().operations(); auto &op_seqs = _lowered_graph.op_seqs(); - for (auto &op_seq_ind : list_to_delete_op_seqs) + for (const auto &op_seq_ind : list_to_delete_op_seqs) { auto &op_seq = op_seqs.at(op_seq_ind); assert(op_seq.size() == 1); - VERBOSE(Fp32ToFp16Converter) << "Delete OpSeq #" << op_seq_ind.value() << std::endl; + VERBOSE(Fp32ToFp16Converter) << "Delete OpSeq " << op_seq_ind << std::endl; auto &first_node_ind = op_seq.operations().at(0); auto &first_node = operations.at(first_node_ind); assert(first_node.opcode() == ir::OpCode::ConvertFp32ToFp16 || first_node.opcode() == ir::OpCode::ConvertFp16ToFp32); - VERBOSE(Fp32ToFp16Converter) << "Delete Node #" << first_node_ind.value() << std::endl; + VERBOSE(Fp32ToFp16Converter) << "Delete Node " << first_node_ind << std::endl; // Uses - for (auto &ind : first_node.getInputs() | ir::Remove::DUPLICATED | ir::Remove::UNDEFINED) + for (const auto &ind : first_node.getInputs() | ir::Remove::DUPLICATED | ir::Remove::UNDEFINED) { auto &obj = operands.at(ind); obj.removeUse(first_node_ind); - VERBOSE(Fp32ToFp16Converter) << "Operand #" << ind.value() << "'s Use(Node#" - << first_node_ind.value() << ") is removed" << std::endl; + VERBOSE(Fp32ToFp16Converter) + << "Operand " << ind << "'s Use(Node" << first_node_ind << ") is removed" << 
std::endl; } // Def - for (auto &ind : first_node.getOutputs() | ir::Remove::DUPLICATED | ir::Remove::UNDEFINED) + for (const auto &ind : first_node.getOutputs() | ir::Remove::DUPLICATED | ir::Remove::UNDEFINED) { auto &obj = operands.at(ind); assert(obj.getDef() == first_node_ind); obj.unsetDef(); - VERBOSE(Fp32ToFp16Converter) << "Operand #" << ind.value() << "'s Def(Node#" - << first_node_ind.value() << ") is removed" << std::endl; + VERBOSE(Fp32ToFp16Converter) + << "Operand " << ind << "'s Def(Node" << first_node_ind << ") is removed" << std::endl; } // Operation operations.remove(first_node_ind); - VERBOSE(Fp32ToFp16Converter) << "Node#" << first_node_ind.value() << " is removed" << std::endl; + VERBOSE(Fp32ToFp16Converter) << "Node" << first_node_ind << " is removed" << std::endl; // OpSequence op_seqs.remove(op_seq_ind); - VERBOSE(Fp32ToFp16Converter) << "OpSeq#" << op_seq_ind.value() << " is removed" << std::endl; + VERBOSE(Fp32ToFp16Converter) << "OpSeq" << op_seq_ind << " is removed" << std::endl; } // Operand - for (auto &ind : list_to_delete_ops) + for (const auto &ind : list_to_delete_ops) { operands.remove(ind); - VERBOSE(Fp32ToFp16Converter) << "Operand #" << ind.value() << " is removed" << std::endl; + VERBOSE(Fp32ToFp16Converter) << "Operand " << ind << " is removed" << std::endl; } } } // namespace compiler } // namespace onert + +#endif diff --git a/runtime/onert/core/src/compiler/Fp32ToFp16Converter.h b/runtime/onert/core/src/compiler/Fp32ToFp16Converter.h index eeecb9846..87751ceb4 100644 --- a/runtime/onert/core/src/compiler/Fp32ToFp16Converter.h +++ b/runtime/onert/core/src/compiler/Fp32ToFp16Converter.h @@ -14,6 +14,8 @@ * limitations under the License. 
*/ +#if 0 // This file is temporarily unused + #ifndef __ONERT_COMPILER_FP32_TO_FP16_CONVERTER_H__ #define __ONERT_COMPILER_FP32_TO_FP16_CONVERTER_H__ @@ -64,8 +66,8 @@ private: void setNewOperandLowerInfo(const ir::OpSequenceIndex &op_seq_ind, const ir::OperandIndex &new_op_ind); - void setNewOpSequenceLowerInfo(const ir::OpSequenceIndex &op_seq_ind, - const ir::OpSequenceIndex &new_op_seq_ind); + void setNewOperationLowerInfo(const ir::OpSequenceIndex &op_seq_ind, + const ir::OpSequenceIndex &new_op_seq_ind); void manipulateInput(const ir::OpSequenceIndex &op_seq_ind, const ir::OperandIndex &op_seq_input_ind, @@ -99,3 +101,5 @@ private: } // namespace onert #endif // __ONERT_COMPILER_FP32_TO_FP16_CONVERTER_H__ + +#endif diff --git a/runtime/onert/core/src/compiler/HEScheduler.cc b/runtime/onert/core/src/compiler/HEScheduler.cc index 5653b090e..2d04d42ce 100644 --- a/runtime/onert/core/src/compiler/HEScheduler.cc +++ b/runtime/onert/core/src/compiler/HEScheduler.cc @@ -14,34 +14,32 @@ * limitations under the License. 
*/ -#include "ir/Operand.h" -#include "compiler/HEScheduler.h" -#include "ir/Graph.h" -#include "util/ConfigSource.h" +#include "HEScheduler.h" + #include "compiler/BackendResolver.h" +#include "ir/Graph.h" #include "util/logging.h" -#include "util/Utils.h" -#include "exec/FunctionSequence.h" + #include <cassert> #include <cmath> -#include <chrono> -namespace onert +namespace { -namespace compiler -{ -static uint32_t getOperationsFlattenedIOSize(const ir::Graph &graph, const ir::Operation &node) +using namespace onert; + +uint32_t getOperationsFlattenedIOSize(const ir::Graph &graph, const ir::IOperation &node) { uint32_t size = 0; - for (const auto &ind : (node.getInputs() | ir::Remove::UNDEFINED) + node.getOutputs()) + for (const auto &ind : + (node.getInputs() | ir::Remove::UNDEFINED) + (node.getOutputs() | ir::Remove::UNDEFINED)) { size += graph.operands().at(ind).info().total_size(); } return size; } -static bool isQuant(const ir::Graph &graph, const ir::Operation &node) +bool isQuant(const ir::Graph &graph, const ir::IOperation &node) { for (const auto &input : node.getInputs() | ir::Remove::UNDEFINED) { @@ -54,18 +52,11 @@ static bool isQuant(const ir::Graph &graph, const ir::Operation &node) return false; } -static bool isWorkaroundSkip(const ir::Graph &, const backend::Backend *, const ir::Operation &, - bool) -{ - // Now, there is no workaround - return false; -} - // if a node can be merged into op_seq -static bool isMergeable(const ir::Graph &graph, const ir::Operation &node) +bool isMergeable(const ir::Graph &graph, const ir::IOperation &node) { size_t prev_op_cnt = 0; - for (const auto &input : node.getInputs()) + for (const auto &input : node.getInputs() | ir::Remove::UNDEFINED) { // only valid_inputs const auto &operand = graph.operands().at(input); @@ -85,15 +76,23 @@ static bool isMergeable(const ir::Graph &graph, const ir::Operation &node) return true; } +} // namespace + +namespace onert +{ + +namespace compiler +{ + void 
HEScheduler::scheduleShufflingBackends() { VERBOSE(HEScheduler::schedule) - << "Started task scheduling: uses all backends to get more metrics for data transfer" - << std::endl; + << "Started task scheduling: uses all backends to get more metrics for data transfer" + << std::endl; size_t backend_ind = 0; for (const auto &rank : _rank_to_op) { - VERBOSE(HEScheduler::schedule) << "scheduling (" << rank.second.value() << ")" << std::endl; + VERBOSE(HEScheduler::schedule) << "scheduling (" << rank.second << ")" << std::endl; const auto &node = _graph->operations().at(rank.second); const bool quant = isQuant(*_graph, node); const auto size = getOperationsFlattenedIOSize(*_graph, node); @@ -109,13 +108,8 @@ void HEScheduler::scheduleShufflingBackends() { backend_ind = 0; } - if (isWorkaroundSkip(*_graph, _all_backends[backend_ind], node, quant)) - { - ++backend_ind; - continue; - } const auto exec_time = - _exec_time->getOperationExecTime(_all_backends[backend_ind], node.name(), quant, size); + _exec_time->getOperationExecTime(_all_backends[backend_ind], node.name(), quant, size); // Scheduling to measure data transfer must be done after measuring all backends separately assert(exec_time != _exec_time->NOT_FOUND); if (exec_time == _exec_time->getMax()) @@ -132,7 +126,7 @@ void HEScheduler::scheduleShufflingBackends() } } -bool HEScheduler::isNodeProfiled(const ir::Operation &node) +bool HEScheduler::isNodeProfiled(const ir::IOperation &node) { const bool quant = isQuant(*_graph, node); const auto size = getOperationsFlattenedIOSize(*_graph, node); @@ -202,7 +196,7 @@ std::unique_ptr<compiler::BackendResolver> HEScheduler::schedule(const ir::Graph { // Check if profiling info about all backend/node pairs already exists bool all_nodes_are_profiled = true; - _graph->operations().iterate([&](const ir::OperationIndex &, const ir::Operation &op) { + _graph->operations().iterate([&](const ir::OperationIndex &, const ir::IOperation &op) { if (all_nodes_are_profiled) 
all_nodes_are_profiled = isNodeProfiled(op); }); @@ -219,7 +213,7 @@ std::unique_ptr<compiler::BackendResolver> HEScheduler::schedule(const ir::Graph ir::OperationIndexMap<bool> visited; graph.operations().iterate( - [&](const ir::OperationIndex &index, const ir::Operation &) { visited[index] = false; }); + [&](const ir::OperationIndex &index, const ir::IOperation &) { visited[index] = false; }); // for each task select the backend with the smallest earliest finishing time(eft) for (const auto &rank : _rank_to_op) { @@ -248,19 +242,20 @@ int64_t HEScheduler::getPermuteTime(const backend::Backend *src_backend, if (time != _exec_time->NOT_FOUND) return time; + // FIXME permute time is not recorded so the control reaches here always // Makes the scheduler prefer keeping computations on one backend - return size / 200; + return size / 400; } -int64_t HEScheduler::tryBackend(const ir::Operation &node, const backend::Backend *backend) +int64_t HEScheduler::tryBackend(const ir::IOperation &node, const backend::Backend *backend) { // if there is no profiling info don't use this backend during scheduling if (!_is_profiling_mode) { VERBOSE(HEScheduler::tryBackend) - << "Trying to HE schedule while there is no profiling info for " << node.name() - << " on backend " << backend->config()->id() << ". So this backend won't be used. " - << std::endl; + << "Trying to HE schedule while there is no profiling info for " << node.name() + << " on backend " << backend->config()->id() << ". So this backend won't be used. 
" + << std::endl; _is_supported[backend][node.name()] = false; return _exec_time->getMax(); } @@ -291,10 +286,10 @@ void HEScheduler::makeRank() VERBOSE(HEScheduler::makeRank) << "task prioritizing" << std::endl; _graph->operations().iterate( - [&](const ir::OperationIndex &index, const ir::Operation &) { DFSMaxRank(index); }); + [&](const ir::OperationIndex &index, const ir::IOperation &) { DFSMaxRank(index); }); // Check that ranks are calculated for all operations(nodes) - _graph->operations().iterate([&](const ir::OperationIndex &index, const ir::Operation &) { + _graph->operations().iterate([&](const ir::OperationIndex &index, const ir::IOperation &) { UNUSED_RELEASE(index); assert(_op_to_rank->find(index) != _op_to_rank->end()); }); @@ -360,8 +355,8 @@ int64_t HEScheduler::DFSMaxRank(const ir::OperationIndex &index) assert(rank >= 0); _rank_to_op.emplace(rank, index); _op_to_rank->emplace(index, rank); - VERBOSE(HEScheduler::DFSMaxRank) << "rank of operation (" << index.value() << ")" << node.name() - << " is " << rank << std::endl; + VERBOSE(HEScheduler::DFSMaxRank) + << "rank of operation (" << index << ")" << node.name() << " is " << rank << std::endl; return rank; } @@ -370,7 +365,7 @@ int64_t HEScheduler::DFSChildrenMaxRank(const ir::OperationIndex &index) { const auto &node = _graph->operations().at(index); int64_t max_child_rank = 0; - for (const auto &output : node.getOutputs()) + for (const auto &output : node.getOutputs() | ir::Remove::UNDEFINED) { const auto &operand = _graph->operands().at(output); const bool quant = operand.typeInfo().type() == ir::DataType::QUANT_UINT8_ASYMM; @@ -384,9 +379,9 @@ int64_t HEScheduler::DFSChildrenMaxRank(const ir::OperationIndex &index) { continue; } - // TODO Change it to controlflow backend + // TODO Change it to builtin backend auto transfer_cost = - getPermuteTime(backend, other_backend, quant, operand.info().total_size()); + getPermuteTime(backend, other_backend, quant, operand.info().total_size()); 
avg_transfer_cost += transfer_cost; } } @@ -403,7 +398,7 @@ int64_t HEScheduler::DFSChildrenMaxRank(const ir::OperationIndex &index) int64_t HEScheduler::backendAvailableTime(const backend::Backend *backend, const int64_t &starting_time, const int64_t &time_amount) { - const auto backend_times = _backends_avail_time.at(backend); + const auto &backend_times = _backends_avail_time.at(backend); // finishing and starting times of an op, that will come after current op auto next_op_fst = backend_times.upper_bound(starting_time); // finishing time of an op, that will come before current op @@ -419,7 +414,7 @@ int64_t HEScheduler::backendAvailableTime(const backend::Backend *backend, bool HEScheduler::schedule(const ir::OperationIndex &index, const backend::Backend *parent_backend) { - VERBOSE(HEScheduler::schedule) << "scheduling (" << index.value() << ")" << std::endl; + VERBOSE(HEScheduler::schedule) << "scheduling (" << index << ")" << std::endl; int64_t eft = std::numeric_limits<int64_t>::max(), selected_exec_time = 0; const auto &node = _graph->operations().at(index); @@ -487,10 +482,6 @@ HEScheduler::ESTAndExecTime(const backend::Backend *backend, const ir::Operation { permute_fine *= 2; } - if (isWorkaroundSkip(*_graph, backend, node, quant)) - { - return {_exec_time->getMax(), _exec_time->getMax()}; - } // get average exec time of the op on this backend auto exec_time = getOpTime(backend, node.name(), quant, size); if (backend->config()->id() == "cpu" && _is_parallel_exec) @@ -506,7 +497,7 @@ HEScheduler::ESTAndExecTime(const backend::Backend *backend, const ir::Operation // Find free time for data transferring and insert it into backend taskset. This is needed: // 1. Time for multiple permutations for this node's input is found correctly // 2. 
If backend==cpu, then free time for this node must come after permutations - for (auto &it : transfer_st_exec_time) + for (auto &&it : transfer_st_exec_time) { if (_is_parallel_exec) { @@ -542,27 +533,27 @@ HEScheduler::ESTAndExecTime(const backend::Backend *backend, const ir::Operation if (!_is_parallel_exec) { VERBOSE(HEScheduler::ESTAndExecTime) - << "exec_time of (" << index.value() << ") " << node.name() << " quant==" << quant << " on " - << backend->config()->id() << " is " << exec_time - << " microseconds. Data transfer cost: " << total_transfer_cost << std::endl; + << "exec_time of (" << index << ") " << node.name() << " quant==" << quant << " on " + << backend->config()->id() << " is " << exec_time + << " microseconds. Data transfer cost: " << total_transfer_cost << std::endl; return {total_transfer_cost, exec_time}; } VERBOSE(HEScheduler::ESTAndExecTime) - << "exec_time of (" << index.value() << ") " << node.name() << " quant==" << quant << " on " - << backend->config()->id() << ": " << exec_time - << " microseconds. Backend available time: " << prev_op_ft - << " Parent's max eft: " << max_pred_eft - total_transfer_cost - << " data transfer cost: " << total_transfer_cost << std::endl; + << "exec_time of (" << index << ") " << node.name() << " quant==" << quant << " on " + << backend->config()->id() << ": " << exec_time + << " microseconds. 
Backend available time: " << prev_op_ft + << " Parent's max eft: " << max_pred_eft - total_transfer_cost + << " data transfer cost: " << total_transfer_cost << std::endl; return {prev_op_ft, exec_time}; } -int64_t HEScheduler::predMaxEFT(const backend::Backend *backend, const ir::Operation &node, +int64_t HEScheduler::predMaxEFT(const backend::Backend *backend, const ir::IOperation &node, std::multimap<int64_t, int64_t> &transfer_st_exec_time) { int64_t max_pred_eft = 0; - for (const auto &input_operand_idx : node.getInputs()) + for (const auto &input_operand_idx : node.getInputs() | ir::Remove::UNDEFINED) { const auto &input_operand = _graph->operands().at(input_operand_idx); const bool quant = input_operand.typeInfo().type() == ir::DataType::QUANT_UINT8_ASYMM; @@ -578,7 +569,7 @@ int64_t HEScheduler::predMaxEFT(const backend::Backend *backend, const ir::Opera { // Multiply operand size by 2 because size must describe input+output size int64_t transfer_cost = - getPermuteTime(parent_backend, backend, quant, input_operand.info().total_size() * 2); + getPermuteTime(parent_backend, backend, quant, input_operand.info().total_size() * 2); transfer_st_exec_time.emplace(_ops_eft.at(input_node_idx), transfer_cost); } } diff --git a/runtime/onert/core/src/compiler/HEScheduler.h b/runtime/onert/core/src/compiler/HEScheduler.h index b9cee5881..df6c07926 100644 --- a/runtime/onert/core/src/compiler/HEScheduler.h +++ b/runtime/onert/core/src/compiler/HEScheduler.h @@ -23,14 +23,16 @@ #ifndef __ONERT_COMPILER_H_E_SCHEDULER_H_ #define __ONERT_COMPILER_H_E_SCHEDULER_H_ -#include "compiler/IScheduler.h" -#include "compiler/BackendManager.h" -#include "compiler/Compiler.h" -#include "ir/Graph.h" -#include "exec/ExecTime.h" -#include "backend/Backend.h" -#include <memory> -#include "ir/OperationIndexMap.h" +#include "IScheduler.h" +#include "../backend/builtin/Config.h" +#include "../exec/ExecTime.h" + +#include <backend/Backend.h> +#include <compiler/BackendManager.h> +#include 
<compiler/Compiler.h> +#include <ir/Graph.h> +#include <ir/OperationIndexMap.h> + #include <map> #include <memory> @@ -50,26 +52,26 @@ public: * @param[in] model Graph model * @param[in] backend_resolver backend resolver */ - HEScheduler(const backend::BackendContexts &backend_contexts, const CompilerOptions &options) - : _is_supported{}, _backends_avail_time{}, _ops_eft{}, - _op_to_rank{std::make_shared<ir::OperationIndexMap<int64_t>>()}, - _is_profiling_mode{options.he_profiling_mode}, - _is_linear_exec{options.executor == "Linear"}, - _is_parallel_exec{options.executor == "Parallel"} + HEScheduler(const std::vector<const backend::Backend *> &backends, const CompilerOptions &options) + : _is_supported{}, _backends_avail_time{}, _ops_eft{}, + _op_to_rank{std::make_shared<ir::OperationIndexMap<int64_t>>()}, + _is_profiling_mode{options.he_profiling_mode}, _is_linear_exec{options.executor == "Linear"}, + _is_parallel_exec{options.executor == "Parallel"} { - for (auto &entry : backend_contexts) + for (auto &&entry : backends) { - if (entry.first->config()->id() == backend::controlflow::Config::ID) + if (entry->config()->id() == backend::builtin::Config::ID) continue; - _all_backends.push_back(entry.first); + _all_backends.push_back(entry); } _backend_resolver = std::make_unique<compiler::BackendResolver>(); _exec_time = std::make_unique<exec::ExecTime>(_all_backends); // Find cpu backend - auto cpu_backend_it = std::find_if( - _all_backends.begin(), _all_backends.end(), - [](const backend::Backend *backend) { return backend->config()->id() == "cpu"; }); + auto cpu_backend_it = + std::find_if(_all_backends.begin(), _all_backends.end(), [](const backend::Backend *backend) { + return backend->config()->id() == "cpu"; + }); if (cpu_backend_it == _all_backends.end()) throw std::runtime_error("HEScheduler could be used only if 'cpu' backend is available"); _cpu_backend = *cpu_backend_it; @@ -86,7 +88,7 @@ public: std::shared_ptr<ir::OperationIndexMap<int64_t>> 
getIndexedRanks() { return _op_to_rank; } private: - bool isNodeProfiled(const ir::Operation &); + bool isNodeProfiled(const ir::IOperation &); bool schedule(const ir::OperationIndex &, const backend::Backend *parent_backend); /** @@ -113,7 +115,7 @@ private: * * @return earliest finishing time of parent nodes */ - int64_t predMaxEFT(const backend::Backend *backend, const ir::Operation &node, + int64_t predMaxEFT(const backend::Backend *backend, const ir::IOperation &node, std::multimap<int64_t, int64_t> &transfer_st_exec_time); void makeRank(); @@ -144,7 +146,7 @@ private: void scheduleShufflingBackends(); - int64_t tryBackend(const ir::Operation &node, const backend::Backend *backend); + int64_t tryBackend(const ir::IOperation &node, const backend::Backend *backend); /** * @brief Schedule a node and its successor until: @@ -173,7 +175,7 @@ private: std::unique_ptr<exec::ExecTime> _exec_time; const ir::Graph *_graph{nullptr}; std::vector<const backend::Backend *> _all_backends; - const backend::Backend *_cpu_backend{nullptr}; // TODO Change this to controlflow_backend + const backend::Backend *_cpu_backend{nullptr}; // TODO Change this to _builtin_backend bool _is_profiling_mode; bool _is_linear_exec; bool _is_parallel_exec; diff --git a/runtime/onert/core/src/compiler/HEScheduler.test.cc b/runtime/onert/core/src/compiler/HEScheduler.test.cc new file mode 100644 index 000000000..505fbbb48 --- /dev/null +++ b/runtime/onert/core/src/compiler/HEScheduler.test.cc @@ -0,0 +1,572 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "HEScheduler.h" +#include "../exec/ExecTime.h" + +#include <ir/DataType.h> +#include <ir/InternalType.h> +#include <ir/Shape.h> +#include <ir/TypeInfo.h> +#include <ir/operation/BinaryArithmetic.h> +#include <ir/operation/FullyConnected.h> + +#include <gtest/gtest.h> + +namespace +{ +using namespace onert; +using namespace ir; +using namespace backend; +using namespace operation; +using namespace exec; + +// +// Mock backends classes +// + +struct MockConfigCPU : public IConfig +{ + std::string id() override { return "cpu"; } + bool initialize() override { return true; }; + bool supportPermutation() override { return false; } + Layout supportLayout(const IOperation &, Layout) override { return Layout::UNKNOWN; } + bool supportDynamicTensor() override { return false; } + bool supportFP16() override { return false; } +}; + +class MockBackendContext : public BackendContext +{ +public: + using BackendContext::BackendContext; + ITensorRegistry *genTensors() override { return nullptr; } + FunctionMap genKernels() override { return {}; } +}; + +struct MockBackendCPU : public Backend +{ + std::shared_ptr<IConfig> config() const override { return std::make_shared<MockConfigCPU>(); } + std::unique_ptr<BackendContext> newContext(ContextData &&data) const override + { + return std::make_unique<MockBackendContext>(this, std::move(data), nullptr); + } +}; + +struct MockConfigGPU : public IConfig +{ + std::string id() override { return "gpu"; } + bool initialize() override { return true; }; + bool supportPermutation() override { return false; } + 
ir::Layout supportLayout(const ir::IOperation &, ir::Layout) override + { + return ir::Layout::UNKNOWN; + } + bool supportDynamicTensor() override { return false; } + bool supportFP16() override { return false; } +}; + +struct MockBackendGPU : public Backend +{ + std::shared_ptr<IConfig> config() const override { return std::make_shared<MockConfigGPU>(); } + std::unique_ptr<BackendContext> newContext(ContextData &&data) const override + { + return std::make_unique<MockBackendContext>(this, std::move(data), nullptr); + } +}; + +struct MockConfigNPU : public IConfig +{ + std::string id() override { return "npu"; } + bool initialize() override { return true; }; + bool supportPermutation() override { return false; } + ir::Layout supportLayout(const ir::IOperation &, ir::Layout) override + { + return ir::Layout::UNKNOWN; + } + bool supportDynamicTensor() override { return false; } + bool supportFP16() override { return false; } +}; + +struct MockBackendNPU : public Backend +{ + std::shared_ptr<IConfig> config() const override { return std::make_shared<MockConfigNPU>(); } + std::unique_ptr<BackendContext> newContext(ContextData &&data) const override + { + return std::make_unique<MockBackendContext>(this, std::move(data), nullptr); + } +}; + +// +// Constants +// + +const int OPERAND_ELEMS = 268203; +const int OPERAND_SIZE = OPERAND_ELEMS * 4; +const int OPERATION_SIZE = OPERAND_SIZE * 3; + +const std::string LINEAR("Linear"); +const std::string DATAFLOW("Dataflow"); +const std::string PARALLEL("Parallel"); + +// +// Helper functions +// + +// Set executor through environment variable +void setExecutor(const std::string &executor) { setenv("EXECUTOR", executor.c_str(), true); } + +// Set profiling mode through environment variable +void setProfilingMode(const bool value) { setenv("PROFILING_MODE", value ? 
"1" : "0", true); } + +// Calculate operation size by addition sizes of all input and output operands +uint32_t calcOpSize(const std::shared_ptr<Graph> &graph, const OperationIndex &op_idx) +{ + uint32_t size = 0; + const auto &op = graph->operations().at(op_idx); + for (const auto &ind : op.getInputs() + op.getOutputs()) + size += graph->operands().at(ind).info().total_size(); + return size; +} + +// Set execution operation time. This method is needed since ExecutionTime has only +// 'updateOperationExecTime' method. +void setOperationExecTime(ExecTime &et, const Backend *backend, const std::string &operation, + bool quant, uint32_t op_size, int64_t time) +{ + // You shouldn't set negative time with this method since nnfw JSON deserializer can't read it + assert(time > 0); + int64_t prev_time = et.getOperationExecTime(backend, operation, quant, op_size); + int64_t time_to_set = prev_time == ExecTime::NOT_FOUND ? time : 2 * time - prev_time; + et.updateOperationExecTime(backend, operation, quant, op_size, time_to_set); + assert(et.getOperationExecTime(backend, operation, quant, op_size) == time); +} + +// Set same execution time for all given backends/operations +void setOperationsExecutionTime(const std::vector<const Backend *> &backends, + const std::vector<std::string> &op_names, + const std::vector<uint32_t> &op_sizes, int64_t exec_time) +{ + assert(op_names.size() == op_sizes.size()); + ExecTime et(backends); + for (int i = 0; i < op_names.size(); ++i) + { + for (const auto backend : backends) + setOperationExecTime(et, backend, op_names[i], false, op_sizes[i], exec_time); + } + et.storeOperationsExecTime(); +} + +// Set permute time from one backend to another. This method is needed since ExecutionTime has only +// 'updatePermuteTime' method. 
+void setPermutationTime(ExecTime &et, const Backend *from_backend, const Backend *to_backend, + bool quant, uint32_t op_size, int64_t time) +{ + // You shouldn't set negative time with this method since nnfw JSON deserializer can't read it + assert(time > 0); + int64_t prev_time = et.getPermuteTime(from_backend, to_backend, quant, op_size); + int64_t time_to_set = prev_time == ExecTime::NOT_FOUND ? time : 2 * time - prev_time; + et.updatePermuteTime(from_backend, to_backend, quant, op_size, time_to_set); + assert(et.getPermuteTime(from_backend, to_backend, quant, op_size) == time); +} + +// Set same permutation time between all given backends +void setPermutationsExecutionTime(const std::vector<const Backend *> &backends, + const int operand_size, const int64_t exec_time) +{ + ExecTime et(backends); + for (const auto &backend : backends) + { + for (const auto other_backend : backends) + { + if (backend == other_backend) + continue; + setPermutationTime(et, backend, other_backend, false, operand_size, exec_time); + } + } + et.storeOperationsExecTime(); +} + +// +// Functions for creating graphs +// + +using OIS = OperandIndexSequence; + +template <typename NodeT, typename... 
Types> +OperationIndex create(std::shared_ptr<Graph> graph, Types &&...args) +{ + auto op = std::make_unique<NodeT>(std::forward<Types>(args)...); + auto op_idx = graph->addOperation(std::move(op)); + // For now in scheduler test all operations in tested graphs has same size (for simplicity) + assert(calcOpSize(graph, op_idx) == OPERATION_SIZE); + return op_idx; +} + +// Create straight graph: Add->Sub->Mul +std::shared_ptr<Graph> createStraightGraph() +{ + auto graph = std::make_shared<Graph>(); + const TypeInfo float_op(DataType::FLOAT32); + + // Create add node + auto add_lhs_idx = graph->addOperand(ir::Shape{OPERAND_ELEMS}, float_op); + auto add_rhs_idx = graph->addOperand(ir::Shape{OPERAND_ELEMS}, float_op); + auto add_out_idx = graph->addOperand(ir::Shape{OPERAND_ELEMS}, float_op); + BinaryArithmetic::Param add_op_params{BinaryArithmetic::ArithmeticType::ADD, Activation::NONE}; + create<BinaryArithmetic>(graph, OIS{add_lhs_idx, add_rhs_idx}, OIS{add_out_idx}, add_op_params); + + // Create sub node + auto sub_const_idx = graph->addOperand(ir::Shape{OPERAND_ELEMS}, float_op); + auto sub_out_idx = graph->addOperand(ir::Shape{OPERAND_ELEMS}, float_op); + BinaryArithmetic::Param sub_op_params{BinaryArithmetic::ArithmeticType::SUB, Activation::NONE}; + create<BinaryArithmetic>(graph, OIS{add_out_idx, sub_const_idx}, OIS{sub_out_idx}, sub_op_params); + + // Create mul node + auto mul_const_idx = graph->addOperand(ir::Shape{OPERAND_ELEMS}, float_op); + auto mul_out_idx = graph->addOperand(ir::Shape{OPERAND_ELEMS}, float_op); + BinaryArithmetic::Param mul_op_params{BinaryArithmetic::ArithmeticType::MUL, Activation::NONE}; + create<BinaryArithmetic>(graph, OIS{sub_out_idx, mul_const_idx}, OIS{mul_out_idx}, mul_op_params); + + graph->verify(); + return graph; +} + +/* Create branched graph: + * [Add] + * // \\ + * [Mul1] [FC2] + * || || + * [Mul2] [FC2] + * \\ // + * [Sub] + */ +std::shared_ptr<Graph> createBranchedGraph() +{ + auto graph = std::make_shared<Graph>(); + 
const TypeInfo float_op(DataType::FLOAT32); + + // Create add node + auto add_lhs_idx = graph->addOperand(ir::Shape{OPERAND_ELEMS}, float_op); + auto add_rhs_idx = graph->addOperand(ir::Shape{OPERAND_ELEMS}, float_op); + auto add_out_idx = graph->addOperand(ir::Shape{OPERAND_ELEMS}, float_op); + BinaryArithmetic::Param add_op_params{BinaryArithmetic::ArithmeticType::ADD, Activation::NONE}; + create<BinaryArithmetic>(graph, OIS{add_lhs_idx, add_rhs_idx}, OIS{add_out_idx}, add_op_params); + + // Create mul1 node + auto mul1_const_idx = graph->addOperand(ir::Shape{OPERAND_ELEMS}, float_op); + auto mul1_out_idx = graph->addOperand(ir::Shape{OPERAND_ELEMS}, float_op); + BinaryArithmetic::Param mul1_op_params{BinaryArithmetic::ArithmeticType::MUL, Activation::NONE}; + create<BinaryArithmetic>(graph, OIS{add_out_idx, mul1_const_idx}, OIS{mul1_out_idx}, + mul1_op_params); + + // Create mul2 node + auto mul2_const_idx = graph->addOperand(ir::Shape{OPERAND_ELEMS}, float_op); + auto mul2_out_idx = graph->addOperand(ir::Shape{OPERAND_ELEMS}, float_op); + BinaryArithmetic::Param mul2_op_params{BinaryArithmetic::ArithmeticType::MUL, Activation::NONE}; + create<BinaryArithmetic>(graph, OIS{mul1_out_idx, mul2_const_idx}, OIS{mul2_out_idx}, + mul2_op_params); + + // Create fc1 node + auto fc1_const_idx = graph->addOperand(ir::Shape{OPERAND_ELEMS}, float_op); + auto fc1_out_idx = graph->addOperand(ir::Shape{OPERAND_ELEMS}, float_op); + FullyConnected::Param fc1_op_params{Activation::NONE}; + create<FullyConnected>(graph, OIS{add_out_idx, fc1_const_idx}, OIS{fc1_out_idx}, fc1_op_params); + + // Create fc2 node + auto fc2_const_idx = graph->addOperand(ir::Shape{OPERAND_ELEMS}, float_op); + auto fc2_out_idx = graph->addOperand(ir::Shape{OPERAND_ELEMS}, float_op); + FullyConnected::Param fc2_op_params{Activation::NONE}; + create<FullyConnected>(graph, OIS{fc1_out_idx, fc2_const_idx}, OIS{fc2_out_idx}, fc2_op_params); + + // Create sub node + auto sub_out_idx = 
graph->addOperand(ir::Shape{OPERAND_ELEMS}, float_op); + BinaryArithmetic::Param sub_op_params{BinaryArithmetic::ArithmeticType::SUB, Activation::NONE}; + create<BinaryArithmetic>(graph, OIS{mul2_out_idx, fc2_out_idx}, OIS{sub_out_idx}, sub_op_params); + + graph->verify(); + return graph; +} + +// +// Tests setup/teardown +// + +// SetUp/TearDown methods runs before/after each test and performs actions common for each test +class HESchedulerTest : public ::testing::Test +{ +protected: + void SetUp() override + { + // Initialize mock backends + _cpu_backend = new MockBackendCPU(); + _gpu_backend = new MockBackendGPU(); + _npu_backend = new MockBackendNPU(); + _mock_backends = {_cpu_backend, _gpu_backend, _npu_backend}; + + // Remove previous profile data if it exists + if (!remove("exec_time.json")) + { + // DO NOTHING (no profile data) + } + + // Remember original value of 'EXECUTOR' environment variable + char *executor = std::getenv("EXECUTOR"); + _original_executor = executor == nullptr ? "" : executor; + + // Remember original value of 'PROFILING_MODE' environment variable + char *profiling_mode = std::getenv("PROFILING_MODE"); + _original_profiling_mode = profiling_mode == nullptr ? 
"" : profiling_mode; + } + + void TearDown() override + { + delete _cpu_backend; + delete _gpu_backend; + delete _npu_backend; + EXPECT_EQ(remove("exec_time.json"), 0); + setenv("EXECUTOR", _original_executor.c_str(), true); + setenv("PROFILING_MODE", _original_profiling_mode.c_str(), true); + } + + const MockBackendCPU *_cpu_backend{nullptr}; + const MockBackendGPU *_gpu_backend{nullptr}; + const MockBackendNPU *_npu_backend{nullptr}; + std::vector<const Backend *> _mock_backends; + + std::string _original_executor; + std::string _original_profiling_mode; +}; + +// +// HEScheduler tests +// + +class HESchedulerTestWithExecutorParam : public HESchedulerTest, + public testing::WithParamInterface<std::string> +{ +}; + +// SchedulerTestWithExecutorParam tests are parameterized with executor name and runs three times - +// one time for each executor +INSTANTIATE_TEST_SUITE_P(AllExecutors, HESchedulerTestWithExecutorParam, + testing::Values(LINEAR, DATAFLOW, PARALLEL)); + +// Test scheduler behavior for straight graph with known execution time of all nodes and permutes. 
+TEST_P(HESchedulerTestWithExecutorParam, straight_graph_known_exec_time) +{ + setExecutor(GetParam()); + + // Prepare graph + ir::Model model; + auto graph(createStraightGraph()); + model.push(ir::SubgraphIndex{0}, graph); + OperationIndex add_op_idx(0), sub_op_idx(1), mul_op_idx(2); + + // Set default execution and transfer time + setPermutationsExecutionTime(_mock_backends, OPERAND_SIZE, 1); + setOperationsExecutionTime(_mock_backends, {"Add", "Sub", "Mul"}, + {OPERATION_SIZE, OPERATION_SIZE, OPERATION_SIZE}, 1e4); + + // Test 1 + // Expected behaviour: scheduler assigns different backend to each node + { + // For each backend reduce execution time of one node + ExecTime et(_mock_backends); + setOperationExecTime(et, _cpu_backend, "Add", false, OPERATION_SIZE, 1); + setOperationExecTime(et, _gpu_backend, "Sub", false, OPERATION_SIZE, 1); + setOperationExecTime(et, _npu_backend, "Mul", false, OPERATION_SIZE, 1); + et.storeOperationsExecTime(); + + // Test scheduler + auto coptions = *onert::compiler::CompilerOptions::fromGlobalConfig(); + auto scheduler = compiler::HEScheduler(_mock_backends, coptions); + const auto br = scheduler.schedule(*graph); + ASSERT_EQ(br->getBackend(add_op_idx)->config()->id(), "cpu"); + ASSERT_EQ(br->getBackend(sub_op_idx)->config()->id(), "gpu"); + ASSERT_EQ(br->getBackend(mul_op_idx)->config()->id(), "npu"); + } + + // Test 2 + // Expected behaviour: scheduler assigns single backend to all nodes because of big transfer time + { + // Increase transfer time + setPermutationsExecutionTime(_mock_backends, OPERAND_SIZE, 1e5); + + // Test scheduler + auto coptions = *onert::compiler::CompilerOptions::fromGlobalConfig(); + auto scheduler = compiler::HEScheduler(_mock_backends, coptions); + const auto br = scheduler.schedule(*graph); + ASSERT_EQ(br->getBackend(add_op_idx)->config()->id(), "cpu"); + ASSERT_EQ(br->getBackend(sub_op_idx)->config()->id(), "cpu"); + ASSERT_EQ(br->getBackend(mul_op_idx)->config()->id(), "cpu"); + } +} + +// Test 
scheduler behavior for branched graph with known execution time of all nodes and permutes +TEST_P(HESchedulerTestWithExecutorParam, branched_graph_known_exec_time) +{ + const int64_t NPU_ET = 5000; + setExecutor(GetParam()); + + // Prepare graph + ir::Model model; + auto graph(createBranchedGraph()); + model.push(ir::SubgraphIndex{0}, graph); + OperationIndex add_op_idx(0), mul1_op_idx(1), mul2_op_idx(2), fc1_op_idx(3), fc2_op_idx(4), + sub_op_idx(5); + + // Set default execution and transfer time + setPermutationsExecutionTime(_mock_backends, OPERAND_SIZE, 1000); + setOperationsExecutionTime(_mock_backends, {"Add", "Sub", "Mul", "FullyConnected"}, + {OPERATION_SIZE, OPERATION_SIZE, OPERATION_SIZE, OPERATION_SIZE}, 1e4); + + // Test 1 + // Expected behaviour: for dataflow and linear executors scheduler assigns fastest backend to all + // nodes, in case of parallel executor scheduler assigns different backends to branches. + { + // Reduce execution time + ExecTime et(_mock_backends); + setOperationExecTime(et, _npu_backend, "Add", false, OPERATION_SIZE, NPU_ET); + setOperationExecTime(et, _npu_backend, "Mul", false, OPERATION_SIZE, NPU_ET); + setOperationExecTime(et, _npu_backend, "Sub", false, OPERATION_SIZE, NPU_ET); + setOperationExecTime(et, _npu_backend, "FullyConnected", false, OPERATION_SIZE, NPU_ET); + setOperationExecTime(et, _gpu_backend, "Mul", false, OPERATION_SIZE, NPU_ET + 1000); + setOperationExecTime(et, _gpu_backend, "FullyConnected", false, OPERATION_SIZE, NPU_ET + 1000); + et.storeOperationsExecTime(); + + // Test scheduler + auto coptions = *onert::compiler::CompilerOptions::fromGlobalConfig(); + auto scheduler = compiler::HEScheduler(_mock_backends, coptions); + const auto br = scheduler.schedule(*graph); + + std::string branch1_expected_backend("npu"), branch2_expected_backend("npu"); + if (GetParam() == PARALLEL) + { + branch1_expected_backend = + br->getBackend(mul1_op_idx)->config()->id() == "npu" ? 
"npu" : "gpu"; + branch2_expected_backend = branch1_expected_backend == "npu" ? "gpu" : "npu"; + } + + ASSERT_EQ(br->getBackend(add_op_idx)->config()->id(), "npu"); + ASSERT_EQ(br->getBackend(mul1_op_idx)->config()->id(), branch1_expected_backend); + ASSERT_EQ(br->getBackend(mul2_op_idx)->config()->id(), branch1_expected_backend); + ASSERT_EQ(br->getBackend(fc1_op_idx)->config()->id(), branch2_expected_backend); + ASSERT_EQ(br->getBackend(fc2_op_idx)->config()->id(), branch2_expected_backend); + ASSERT_EQ(br->getBackend(sub_op_idx)->config()->id(), "npu"); + } + + // Test 2 + // Expected behaviour: scheduler assigns single backend to all nodes + { + // Increase execution time for GPU backend + ExecTime et(_mock_backends); + /* for parallel executor: set a time that is larger than sum_of_other_branches_nodes_cnt * + * npu_exec_time so that npu is preferred: the ith branch will wait for npu until it finishes the + * [0;i-1] branches' nodes in DFS order. In each branch it goes deep until it encounters + * branching or scheduler assigns another backend to a node*/ + setOperationExecTime(et, _gpu_backend, "Mul", false, OPERATION_SIZE, NPU_ET * 3 + 1); + setOperationExecTime(et, _gpu_backend, "FullyConnected", false, OPERATION_SIZE, NPU_ET * 3 + 1); + et.storeOperationsExecTime(); + + // Test scheduler + auto coptions = *onert::compiler::CompilerOptions::fromGlobalConfig(); + auto scheduler = compiler::HEScheduler(_mock_backends, coptions); + const auto br = scheduler.schedule(*graph); + ASSERT_EQ(br->getBackend(add_op_idx)->config()->id(), "npu"); + ASSERT_EQ(br->getBackend(mul1_op_idx)->config()->id(), "npu"); + ASSERT_EQ(br->getBackend(mul2_op_idx)->config()->id(), "npu"); + ASSERT_EQ(br->getBackend(fc1_op_idx)->config()->id(), "npu"); + ASSERT_EQ(br->getBackend(fc2_op_idx)->config()->id(), "npu"); + ASSERT_EQ(br->getBackend(sub_op_idx)->config()->id(), "npu"); + } +} + +// Test scheduler behavior for branched graph and enabled profiling mode 
+TEST_F(HESchedulerTest, branched_graph_profiling_mode) +{ + const int ET = 1e5; + + // Turn on profiling mode + setProfilingMode(true); + setExecutor(DATAFLOW); + + // Prepare graph + ir::Model model; + auto graph(createBranchedGraph()); + model.push(ir::SubgraphIndex{0}, graph); + OperationIndex add_op_idx(0), mul1_op_idx(1), mul2_op_idx(2), fc1_op_idx(3), fc2_op_idx(4), + sub_op_idx(5); + + // Test 1 + // Expected behaviour: scheduler assigns backends to nodes with unknown execution time + { + // Set execution time for all backends/nodes except for cpu/Sub, npu/Mul, gpu/FC + ExecTime et(_mock_backends); + setOperationExecTime(et, _cpu_backend, "Add", false, OPERATION_SIZE, ET); + setOperationExecTime(et, _cpu_backend, "Mul", false, OPERATION_SIZE, ET + 1); + setOperationExecTime(et, _cpu_backend, "FullyConnected", false, OPERATION_SIZE, ET); + setOperationExecTime(et, _npu_backend, "Add", false, OPERATION_SIZE, ET); + setOperationExecTime(et, _npu_backend, "FullyConnected", false, OPERATION_SIZE, ET); + setOperationExecTime(et, _npu_backend, "Sub", false, OPERATION_SIZE, ET); + setOperationExecTime(et, _gpu_backend, "Add", false, OPERATION_SIZE, ET); + setOperationExecTime(et, _gpu_backend, "Mul", false, OPERATION_SIZE, ET + 1); + setOperationExecTime(et, _gpu_backend, "Sub", false, OPERATION_SIZE, ET); + et.storeOperationsExecTime(); + + // Test scheduler + auto coptions = *onert::compiler::CompilerOptions::fromGlobalConfig(); + auto scheduler = compiler::HEScheduler(_mock_backends, coptions); + const auto br = scheduler.schedule(*graph); + ASSERT_EQ(br->getBackend(mul1_op_idx)->config()->id(), "npu"); + ASSERT_EQ(br->getBackend(mul2_op_idx)->config()->id(), "npu"); + ASSERT_EQ(br->getBackend(fc1_op_idx)->config()->id(), "gpu"); + ASSERT_EQ(br->getBackend(fc2_op_idx)->config()->id(), "gpu"); + ASSERT_EQ(br->getBackend(sub_op_idx)->config()->id(), "cpu"); + } + + // Test 2 + // Expected behaviour: scheduler shuffling backends, so different backends are assigned 
to + // neighbor nodes + { + // Set execution time for rest backends/nodes (cpu/Sub, npu/Mul, gpu/FC) + ExecTime et(_mock_backends); + setOperationExecTime(et, _cpu_backend, "Sub", false, OPERATION_SIZE, ET); + setOperationExecTime(et, _npu_backend, "Mul", false, OPERATION_SIZE, ET + 1); + setOperationExecTime(et, _gpu_backend, "FullyConnected", false, OPERATION_SIZE, ET); + et.storeOperationsExecTime(); + + // Test scheduler + auto coptions = *onert::compiler::CompilerOptions::fromGlobalConfig(); + auto scheduler = compiler::HEScheduler(_mock_backends, coptions); + const auto br = scheduler.schedule(*graph); + ASSERT_NE(br->getBackend(add_op_idx)->config()->id(), + br->getBackend(mul1_op_idx)->config()->id()); + ASSERT_NE(br->getBackend(add_op_idx)->config()->id(), + br->getBackend(fc1_op_idx)->config()->id()); + ASSERT_NE(br->getBackend(mul1_op_idx)->config()->id(), + br->getBackend(mul2_op_idx)->config()->id()); + ASSERT_NE(br->getBackend(fc1_op_idx)->config()->id(), + br->getBackend(fc2_op_idx)->config()->id()); + ASSERT_NE(br->getBackend(mul2_op_idx)->config()->id(), + br->getBackend(sub_op_idx)->config()->id()); + ASSERT_NE(br->getBackend(fc2_op_idx)->config()->id(), + br->getBackend(sub_op_idx)->config()->id()); + } +} + +// TODO: Add tests with unknown execution and permutation time + +} // unnamed namespace diff --git a/runtime/onert/core/src/compiler/Linear.cc b/runtime/onert/core/src/compiler/Linear.cc index 49a989500..663cf5450 100644 --- a/runtime/onert/core/src/compiler/Linear.cc +++ b/runtime/onert/core/src/compiler/Linear.cc @@ -14,207 +14,38 @@ * limitations under the License. 
*/ -#include <algorithm> - #include "Linear.h" -#include "backend/IConfig.h" -#include "backend/IConstantInitializer.h" -#include "backend/ITensorRegister.h" -#include "backend/Backend.h" +#include "../dumper/text/GraphDumper.h" + #include "util/logging.h" +#include <sstream> + namespace onert { namespace compiler { -std::vector<ir::OpSequenceIndex> Linear::linearize(const compiler::LoweredGraph &lowered_graph) +// TODO(easy) Change the LoweredGraph param to Graph +std::vector<ir::OperationIndex> Linear::linearize(const compiler::ILoweredGraph &lowered_graph) { - std::vector<ir::OpSequenceIndex> order; - lowered_graph.iterateTopolOpSeqs( - [&](const ir::OpSequenceIndex &index, const ir::OpSequence &) -> void { - order.emplace_back(index); - }); - return order; + return lowered_graph.graph().topolSortOperations(); } -void Linear::dump(const compiler::LoweredGraph &lowered_graph, - const std::vector<ir::OpSequenceIndex> &order) +// TODO(easy) Change the LoweredGraph param to Graph +void Linear::dump(const compiler::ILoweredGraph &lowered_graph, + const std::vector<ir::OperationIndex> &order) { + for (const auto &ind : order) { - const auto &toString = [](const onert::backend::Backend *backend) { - assert(backend); - std::string str; - str += backend->config()->id(); - return "{" + str + "}"; - }; - - VERBOSE(Linear) << "Final OpSequence" << std::endl; - for (const auto index : order) - { - const auto &op_seq = lowered_graph.op_seqs().at(index); - const auto lower_info = lowered_graph.getLowerInfo(index); - const auto &operations = lowered_graph.graph().operations(); - VERBOSE(Linear) << "* OP_SEQ " << toString(lower_info->backend()) << " " - << ir::getStrFromOpSeq(op_seq, operations) << std::endl; - } + // TODO Could logging system can handle this? 
(Inserting prefix for each line) + std::istringstream iss{dumper::text::formatOperation(lowered_graph.graph(), ind)}; + std::string line; + while (std::getline(iss, line)) + VERBOSE(Linearize) << line << std::endl; } } -void Linear::planTensors(const compiler::LoweredGraph &lowered_graph, - const std::vector<ir::OpSequenceIndex> &order) -{ - const auto &graph = lowered_graph.graph(); - ir::OperandIndexMap<std::shared_ptr<backend::ITensorBuilder>> tensor_builder_map; - - ir::OperandIndexMap<uint32_t> uses_map; - ir::OperandIndexMap<uint32_t> def_map; - ir::OperandIndexSequence constants; - - // Prepare scanning - graph.operands().iterate([&](const ir::OperandIndex &ind, const ir::Operand &obj) { - const auto lower_info = lowered_graph.getLowerInfo(ind); - // TODO Remove if onert doesn't support anymore such as - // GeneratedTests.reshape_quant8_weights_as_inputs - if (lower_info->def_factors().size() == 0 && lower_info->use_factors().size() == 0 && - !graph.getInputs().contains(ind)) - { - VERBOSE(LINEAR) << "Operand #" << ind.value() << " will not be used. no more process." - << std::endl; - return; - } - - // Unused input of subgraph - // TODO Register unused input as nullptr in tensor_builder - if (lower_info->def_factors().size() == 0 && lower_info->use_factors().size() == 0 && - graph.getInputs().contains(ind)) - { - VERBOSE(LINEAR) << "Operand #" << ind.value() << " will not be used. no more process." - << std::endl; - return; - } - - uses_map[ind] = obj.getUses().size(); - def_map[ind] = obj.getDef().valid() ? 
1 : 0; - - bool is_const = obj.isConstant(); - if (is_const) - { - constants.append(ind); - } - - auto factor = lower_info->def_factors().getOnlyElement(); - auto backend = factor.backend(); - auto tensor_builder = lowered_graph.backend_contexts().at(backend)->tensor_builder; - if (!tensor_builder->isRegistered(ind)) - { - // These tensors do not exist in any op_seq (No use and def) - const auto info = obj.info(); - const auto backend_layout = factor.layout(); - // TODO Change tensor info to have permuted shape - tensor_builder->registerTensorInfo(ind, info, backend_layout); - } - - tensor_builder_map[ind] = tensor_builder; - }); - - // If a tensor is model output, increase the use of the tensor. - // This aim is same to above one. - for (const auto &ind : graph.getOutputs() | ir::Remove::DUPLICATED) - { - uses_map[ind]++; - } - - // Start scanning to do notify{First|Last}Use for each tensor - - // If a tensor is a constant, increase the use of the tensor. - // It makes the tensor not be dealloced. It means these will be deallocated last. - // And allocate constant operands first - VERBOSE(LINEAR) << "TENSORS as CONSTANT" << std::endl; - for (const auto &ind : constants) - { - uses_map[ind]++; - tensor_builder_map[ind]->notifyFirstUse(ind); - } - - // Allocate Model's inputs - VERBOSE(LINEAR) << "TENSORS as MODEL INPUT" << std::endl; - for (const auto &ind : graph.getInputs() | ir::Remove::DUPLICATED) - { - auto tensor_builder = tensor_builder_map[ind]; - if (!tensor_builder) // for GeneratedTests.xxx_weights_as_inputs - continue; - tensor_builder->notifyFirstUse(ind); - } - - // At each operation, - // 1. Scan DEF of outputs. If the DEF, allocate it - // 2. Scan USE of inputs. 
Decrease the USE and deallocate if the USE is 0 - VERBOSE(LINEAR) << "TENSORS" << std::endl; - for (const auto op_seq_ind : order) - { - const auto &op_seq = lowered_graph.op_seqs().at(op_seq_ind); - for (const auto &op_idx : op_seq.operations()) - { - for (const auto &ind : graph.operations().at(op_idx).getOutputs() | ir::Remove::DUPLICATED | - ir::Remove::UNDEFINED) - { - assert(def_map.find(ind) != def_map.end()); - if (def_map[ind]) - { - def_map[ind] = 0; - tensor_builder_map[ind]->notifyFirstUse(ind); - } - } - - for (const auto &ind : graph.operations().at(op_idx).getInputs() | ir::Remove::DUPLICATED | - ir::Remove::UNDEFINED) - { - assert(uses_map.find(ind) != uses_map.end()); - assert(uses_map[ind] > 0); - uses_map[ind]--; - if (uses_map[ind] == 0) - { - // plan for deallocation of static tensornode - tensor_builder_map[ind]->notifyLastUse(ind); - - // plan for deallocation of dynamic tensor - auto dyn_tensor_manager = tensor_builder_map[ind]->dynamicTensorManager(); - if (dyn_tensor_manager) - dyn_tensor_manager->planDealloc(op_idx, ind); - } - } - } - } - - // Dispose and validate - for (const auto &ind : graph.getOutputs() | ir::Remove::DUPLICATED) - { - --uses_map[ind]; - if (uses_map[ind] == 0) // To prevent notifyLastUse from being called twice - { - tensor_builder_map[ind]->notifyLastUse(ind); - } - } - - for (const auto &ind : constants) - { - --uses_map[ind]; - if (uses_map[ind] == 0) // To prevent notifyLastUse from being called twice - { - tensor_builder_map[ind]->notifyLastUse(ind); - } - } - - assert( - std::all_of(uses_map.begin(), uses_map.end(), - [](std::pair<const ir::OperandIndex, uint32_t> it) { return it.second == 0; })); - - assert( - std::all_of(def_map.begin(), def_map.end(), - [](std::pair<const ir::OperandIndex, uint32_t> it) { return it.second == 0; })); -} - } // namespace compiler } // namespace onert diff --git a/runtime/onert/core/src/compiler/Linear.h b/runtime/onert/core/src/compiler/Linear.h index 1e24cf92b..4f92dc88d 
100644 --- a/runtime/onert/core/src/compiler/Linear.h +++ b/runtime/onert/core/src/compiler/Linear.h @@ -20,18 +20,8 @@ #include <vector> #include <memory> -#include "ir/OpSequences.h" #include "ir/Index.h" -#include "backend/ITensorBuilder.h" -#include "compiler/LoweredGraph.h" - -namespace onert -{ -namespace ir -{ -struct OperationVisitor; -} // namespace ir -} // namespace onert +#include "compiler/ILoweredGraph.h" namespace onert { @@ -41,11 +31,9 @@ namespace compiler class Linear { public: - static std::vector<ir::OpSequenceIndex> linearize(const compiler::LoweredGraph &lowered_graph); - static void dump(const compiler::LoweredGraph &lowered_graph, - const std::vector<ir::OpSequenceIndex> &order); - static void planTensors(const compiler::LoweredGraph &lowered_graph, - const std::vector<ir::OpSequenceIndex> &order); + static std::vector<ir::OperationIndex> linearize(const compiler::ILoweredGraph &lowered_graph); + static void dump(const compiler::ILoweredGraph &lowered_graph, + const std::vector<ir::OperationIndex> &order); }; } // namespace compiler diff --git a/runtime/onert/core/src/compiler/LoweredGraph.cc b/runtime/onert/core/src/compiler/LoweredGraph.cc index 1489a1884..46a45e44a 100644 --- a/runtime/onert/core/src/compiler/LoweredGraph.cc +++ b/runtime/onert/core/src/compiler/LoweredGraph.cc @@ -16,21 +16,23 @@ #include "compiler/LoweredGraph.h" -#include <assert.h> -#include <sstream> -#include "util/logging.h" -#include "compiler/pass/ConstantInsertionPass.h" -#include "compiler/pass/ConstantLoweringPass.h" -#include "compiler/pass/PermutationOperationPass.h" -#include "compiler/pass/PermutationInsertionPass.h" -#include "compiler/pass/PermutationEliminationPass.h" -#include "ir/GraphIterator.h" -#include "ir/verifier/Verifier.h" +#include "HEScheduler.h" +#include "ManualScheduler.h" +#include "pass/ConstantInsertionPass.h" +#include "pass/ConstantLoweringPass.h" +#include "pass/PassRunner.h" +#include "pass/PermutationEliminationPass.h" +#include 
"pass/PermutationInsertionPass.h" +#include "pass/PermutationOperationPass.h" +#include "../dumper/text/GraphDumper.h" +#include "../ir/verifier/Verifier.h" + #include "backend/Backend.h" -#include "backend/IConfig.h" #include "compiler/BackendResolver.h" -#include "compiler/ManualScheduler.h" -#include "compiler/HEScheduler.h" +#include "util/logging.h" + +#include <cassert> +#include <sstream> namespace onert { @@ -39,18 +41,15 @@ namespace compiler LoweredGraph::LoweredGraph(const ir::Graph &graph, const CompilerOptions &options) : _graph{graph} { - bool linear_executor = (options.executor == "Linear"); + lowerGraph(options); +} +void LoweredGraph::lowerGraph(const CompilerOptions &options) +{ // Build backend contexts auto &backend_manager = BackendManager::get(); - - // Always create Controlflow backend context - auto cf_backend = backend_manager.getControlflow(); - _backend_contexts.emplace( - cf_backend, cf_backend->newContext(_graph, _graph.getKernelBuilder(), linear_executor)); - // Create contexts for other backends - for (auto backend_str : options.backend_list) + for (auto &&backend_str : options.backend_list) { backend_manager.loadBackend(backend_str); auto backend = backend_manager.get(backend_str); @@ -60,12 +59,9 @@ LoweredGraph::LoweredGraph(const ir::Graph &graph, const CompilerOptions &option // we should change it back(throw if backend is not loaded) later. 
if (!backend) { - VERBOSE(LoweredGraph) << "Cannot load backend - " << backend_str; + VERBOSE(LoweredGraph) << "Cannot load backend - " << backend_str << std::endl; continue; } - - _backend_contexts.emplace( - backend, backend->newContext(_graph, _graph.getKernelBuilder(), linear_executor)); } if (backend_manager.num_backends() == 0) throw std::runtime_error{"No available backends loaded."}; @@ -73,317 +69,115 @@ LoweredGraph::LoweredGraph(const ir::Graph &graph, const CompilerOptions &option // TODO Move "schedule" phase out of here // Schedule std::unique_ptr<BackendResolver> backend_resolver; + auto all_backends = backend_manager.getAll(); if (options.he_scheduler) { - auto scheduler = HEScheduler(_backend_contexts, options); + auto scheduler = HEScheduler(all_backends, options); backend_resolver = scheduler.schedule(_graph); _indexed_ranks = scheduler.getIndexedRanks(); } else { - auto scheduler = ManualScheduler(_backend_contexts, options); + auto scheduler = ManualScheduler(all_backends, options); backend_resolver = scheduler.schedule(_graph); } - { - // operand::LowerInfo holder - ir::OperandIndexMap<std::unique_ptr<ir::operand::LowerInfo>> operands_lower_info; - - _graph.operands().iterate([&](const ir::OperandIndex &index, const ir::Operand &) { - operands_lower_info[index] = std::make_unique<ir::operand::LowerInfo>(); - }); - - // Make op_seqs while checking whether a node can be merged into a op_seq. - makeOpSequences(operands_lower_info, options, *backend_resolver); + makeLowerInfo(*backend_resolver); + VERBOSE(LoweredGraph) << "dump before mandatory passes" << std::endl; + dumper::text::dumpLoweredGraph(*this); - _op_seqs.iterate([&](const ir::OpSequenceIndex &, ir::OpSequence &op_seq) { - assert(op_seq.operations().size() > 0); - std::reverse(std::begin(op_seq.operations()), std::end(op_seq.operations())); - }); + // Mandatory passes - kind of legalization(?) 
+ pass::PassRunner{} + .append(std::make_unique<pass::ConstantInsertionPass>(*this)) + .append(std::make_unique<pass::ConstantLoweringPass>(*this)) + .append(std::make_unique<pass::PermutationOperationPass>(*this)) + .append(std::make_unique<pass::PermutationInsertionPass>(*this)) + .run(); - VERBOSE(OpSequences) << "dump without permutation" << std::endl; - dumpOpSequences(_op_seqs, _graph.operations()); + dumpLowerInfo(); - pass::ConstantInsertionPass ci_pass(*this); - ci_pass.run(); + // Optimization passes (optional) + pass::PassRunner{}.append(std::make_unique<pass::PermutationEliminationPass>(*this)).run(); - pass::ConstantLoweringPass cl_pass(*this); - cl_pass.run(); - - // Set LowerInfo for each operand from the operand::LowerInfo holder - manipulateLowerInfo(operands_lower_info, options.is_primary_subgraph); - - dumpLowerInfo(); - } - - // Run Permutation Passes - { - pass::PermutationOperationPass po_pass(*this); - po_pass.run(); - - pass::PermutationInsertionPass pi_pass(*this); - pi_pass.run(); - - pass::PermutationEliminationPass pe_pass(*this); - pe_pass.run(); - - VERBOSE(OpSequences) << "dump with permutation" << std::endl; - dumpOpSequences(_op_seqs, _graph.operations()); - } + VERBOSE(LoweredGraph) << "Dump after all the passes" << std::endl; + for (auto &&operand : _graph.getInputs()) + VERBOSE(LoweredGraph) << "Graph Input : " << operand << std::endl; + for (auto &&operand : _graph.getOutputs()) + VERBOSE(LoweredGraph) << "Graph Output : " << operand << std::endl; + dumper::text::dumpLoweredGraph(*this); // Graph verifications { + assert(ir::verifier::InputOutputChecker().verify(_graph)); assert(ir::verifier::DAGChecker().verify(_graph)); - assert(ir::verifier::EdgeConsistencyChecker().verify(_graph)); + assert(ir::verifier::EdgeChecker().verify(_graph)); } } -const ir::operation::LowerInfo * -LoweredGraph::getLowerInfo(const ir::OpSequenceIndex &op_seq_index) const +void LoweredGraph::makeLowerInfo(const compiler::BackendResolver 
&backend_resolver) { - auto itr = _lower_info_map.op_seq.find(op_seq_index); - if (itr == _lower_info_map.op_seq.end()) - return nullptr; - return itr->second.get(); -} - -void LoweredGraph::setLowerInfo(const ir::OpSequenceIndex &op_seq_index, - std::unique_ptr<ir::operation::LowerInfo> &&lower_info) -{ - _lower_info_map.op_seq.insert(std::make_pair(op_seq_index, std::move(lower_info))); -} + _graph.operands().iterate([&](const ir::OperandIndex &index, const ir::Operand &) { + lower_info().operand.set(index, std::make_unique<OperandLowerInfo>()); + }); -void LoweredGraph::removeLowerInfo(const ir::OpSequenceIndex &op_seq_index) -{ - auto &op_seq_lower_info = _lower_info_map.op_seq; - assert(op_seq_lower_info.find(op_seq_index) != op_seq_lower_info.end()); - for (auto it = op_seq_lower_info.begin(); it != op_seq_lower_info.end(); ++it) - { - if (it->first == op_seq_index) + // Set operand lower info using assigned backends to operations + _graph.operations().iterate([&](const ir::OperationIndex &op_ind, const ir::IOperation &) { + const ir::IOperation &op = _graph.operations().at(op_ind); + auto backend = backend_resolver.getBackend(op_ind); + if (!backend) { - op_seq_lower_info.erase(it); - break; + throw std::runtime_error{"Fail to find backend for " + op.name() + " operation"}; } - } -} - -const ir::operand::LowerInfo *LoweredGraph::getLowerInfo(const ir::OperandIndex &index) const -{ - auto itr = _lower_info_map.operand.find(index); - if (itr == _lower_info_map.operand.end()) - return nullptr; - return itr->second.get(); -} - -ir::operand::LowerInfo *LoweredGraph::getLowerInfo(const ir::OperandIndex &index) -{ - auto itr = _lower_info_map.operand.find(index); - if (itr == _lower_info_map.operand.end()) - return nullptr; - return itr->second.get(); -} - -void LoweredGraph::setLowerInfo(const ir::OperandIndex &index, - std::unique_ptr<ir::operand::LowerInfo> &&lower_info) -{ - _lower_info_map.operand.insert(std::make_pair(index, std::move(lower_info))); -} - 
-void LoweredGraph::removeLowerInfo(const ir::OperandIndex &index) -{ - _lower_info_map.operand.erase(index); -} - -void LoweredGraph::iterateTopolOpSeqs( - const std::function<void(const ir::OpSequenceIndex &, const ir::OpSequence &)> &fn) const -{ - // Topological Sorting for ir::OpSequences - std::vector<ir::OpSequenceIndex> topol_sorted; - ir::PostDfsIterator<true>{}.iterateOpSeqs( - *this, [&](const ir::OpSequenceIndex &index, const ir::OpSequence &) { - topol_sorted.emplace_back(index); - }); - std::reverse(topol_sorted.begin(), topol_sorted.end()); - for (const auto op_seq_idx : topol_sorted) - { - const auto &op_seq = _op_seqs.at(op_seq_idx); - fn(op_seq_idx, op_seq); - } -} - -void LoweredGraph::iterateTopolOpSeqs( - const std::function<void(const ir::OpSequenceIndex &, ir::OpSequence &)> &fn) -{ - // Topological Sorting for ir::OpSequences - std::vector<ir::OpSequenceIndex> topol_sorted; - ir::PostDfsIterator<false>{}.iterateOpSeqs( - *this, [&](const ir::OpSequenceIndex &index, ir::OpSequence &) { - topol_sorted.emplace_back(index); - }); - std::reverse(topol_sorted.begin(), topol_sorted.end()); - for (const auto op_seq_idx : topol_sorted) - { - auto &op_seq = _op_seqs.at(op_seq_idx); - fn(op_seq_idx, op_seq); - } -} - -ir::OpSequenceIndex LoweredGraph::appendFreshSingleOpSequence(const ir::OperationIndex &node_index, - const ir::Operation &node) -{ - // Create a fresh op_seq with one operation, and append it to op_seqs - // Create a fresh op_seq - auto op_seq = std::make_unique<ir::OpSequence>(_graph.layout()); - - // Add an operation - op_seq->appendOperation(node_index); - - // Update input/output - op_seq->setOutputs(node.getOutputs()); - op_seq->setInputs(node.getInputs()); - - return _op_seqs.emplace(std::move(op_seq)); -} - -void LoweredGraph::makeOpSequences( - ir::OperandIndexMap<std::unique_ptr<ir::operand::LowerInfo>> &operands_lower_info, - const CompilerOptions &options, const BackendResolver &backend_resolver) -{ - // if SUBG_MAX_NODE == 0, 
no limit on nodes of a op_seq - const int op_seq_max_node = options.op_seq_max_node; - assert(op_seq_max_node >= 0); - - bool is_profiling = options.he_profiling_mode; - ir::OpSequence *op_seq = nullptr; - ir::OpSequenceIndex op_seq_index; - - // NOTE: The below method appends nodes while making one op_seq if needed. If something better - // ways, happy to update this code. - ir::PostDfsConstIterator{}.iterate( - _graph, [&](const ir::OperationIndex &node_index, const ir::Operation &node) { - // LowerInfo for in/output operands - auto backend = backend_resolver.getBackend(node_index); - - // Get frontend's layout - auto frontend_layout = _graph.layout(); - - // The layout of each backend should be set at another place - // TODO Change setting layout of each backend at another place - auto backend_layout = backend->config()->supportLayout(node, frontend_layout); - - for (auto operand : node.getInputs() | ir::Remove::UNDEFINED) - { - auto &&lower_info = operands_lower_info.at(operand); - lower_info->addUsePermuteFactor(ir::operand::PermuteFactor{backend, backend_layout}); - } - for (auto operand : node.getOutputs()) - { - auto &&lower_info = operands_lower_info.at(operand); - lower_info->addDefPermuteFactor(ir::operand::PermuteFactor{backend, backend_layout}); - } - - bool new_op_seq = (op_seq == nullptr || - (op_seq_max_node != 0 && - op_seq->operations().size() >= static_cast<size_t>(op_seq_max_node))); - - // for profiling each op_seq must contain just one node, - // so that we can measure a node separately - if (new_op_seq || is_profiling || - !mergeable(op_seq_index, node_index, backend_layout, backend_resolver)) - { - auto new_op_seq_index = appendFreshSingleOpSequence(node_index, node); - - // ir::OpSequence LowerInfo - setLowerInfo(new_op_seq_index, - std::make_unique<ir::operation::LowerInfo>(backend, backend_layout)); - - op_seq_index = new_op_seq_index; - op_seq = &(_op_seqs.at(new_op_seq_index)); - - VERBOSE(Lower) << "OpSequence#" << op_seq_index.value() 
<< " is created for " - << "NODE#" << node_index.value() << "(" << node.name() << ")" << std::endl; - } - else - { - op_seq->appendOperation(node_index); - // Set inputs - auto new_inputs = node.getInputs(); - // Add inputs except outputs of the previous node - for (auto ind : op_seq->getInputs()) - { - if (!node.getOutputs().contains(ind)) - new_inputs.append(ind); - } - op_seq->setInputs(new_inputs); - VERBOSE(Lower) << "OpSequence#" << op_seq_index.value() << " merges " - << "NODE#" << node_index.value() << "(" << node.name() << ")" << std::endl; - } - }); -} + auto frontend_layout = _graph.layout(); -void LoweredGraph::manipulateLowerInfo( - ir::OperandIndexMap<std::unique_ptr<ir::operand::LowerInfo>> &operands_lower_info, - bool is_primary) -{ - const auto controlflow_backend = BackendManager::get().getControlflow(); + // The layout of each backend should be set at another place + // TODO Change setting layout of each backend at another place + auto backend_layout = backend->config()->supportLayout(op, frontend_layout); - // TODO Rather than handling primary graph specially, - // let the permute inserted and remove it later - if (is_primary) - { - // TODO Rather than using NHWC Get frontend layout of this node from IR - auto factor = ir::operand::PermuteFactor{controlflow_backend, ir::Layout::NHWC}; - for (auto index : _graph.getInputs() | ir::Remove::UNDEFINED) + for (auto &&ind : op.getInputs() | ir::Remove::UNDEFINED) { - auto &&lower_info = operands_lower_info.at(index); - assert(lower_info->def_factors().empty()); - lower_info->addDefPermuteFactor(factor); + auto &operand_li = lower_info().operand.at(ind); + operand_li.addUsePermuteFactor(PermuteFactor{backend, backend_layout}); } - for (auto index : _graph.getOutputs()) + for (auto &&ind : op.getOutputs() | ir::Remove::UNDEFINED) { - auto &&lower_info = operands_lower_info.at(index); - lower_info->addUsePermuteFactor(factor); + auto &operand_li = lower_info().operand.at(ind); + 
operand_li.addDefPermuteFactor(PermuteFactor{backend, backend_layout}); } - } - else + lower_info().operation.set( + op_ind, std::make_unique<compiler::OperationLowerInfo>(backend, backend_layout)); + }); + + // Handle graph inputs and outputs + const auto builtin_backend = BackendManager::get().getBuiltin(); + auto factor = PermuteFactor{builtin_backend, _graph.layout()}; + for (auto &&index : _graph.getInputs() | ir::Remove::UNDEFINED) { - for (auto index : _graph.getInputs() | ir::Remove::UNDEFINED) - { - auto &&lower_info = operands_lower_info.at(index); - if (!(lower_info->def_factors().size() == 0 && lower_info->use_factors().size() == 0)) - { - // In case of not that Graph's input is not used in any operation and not the graph's - // output. - // In other words, it is not unused input in Graph. - lower_info->addDefPermuteFactor(*lower_info->use_factors().begin()); - } - else - { - // In case of that an operand is Graph's input and not input or output of any operation - lower_info->addDefPermuteFactor(ir::operand::PermuteFactor{ - controlflow_backend, - ir::Layout::NHWC // TODO Get frontend layout of this node from IR - }); - } - } + auto &operand_li = lower_info().operand.at(index); + assert(operand_li.def_factors().empty()); + operand_li.addDefPermuteFactor(factor); } - for (auto index : _graph.getOutputs()) + for (auto &&index : _graph.getOutputs() | ir::Remove::UNDEFINED) { - auto &&lower_info = operands_lower_info.at(index); - if (lower_info->def_factors().size() == 0) - { - // In case of that an operand is Graph's output and not input or output of any operation - lower_info->addDefPermuteFactor(ir::operand::PermuteFactor{ - controlflow_backend, - ir::Layout::NHWC // TODO Get frontend layout of this node from IR - }); - } + auto &operand_li = lower_info().operand.at(index); + operand_li.addUsePermuteFactor(factor); } - // Set LowerInfo for each operand from the operand::LowerInfo holder - _graph.operands().iterate([&](const ir::OperandIndex &index, 
ir::Operand &) { - setLowerInfo(index, std::move(operands_lower_info[index])); + // Handle variable tensors + _graph.operands().iterate([&](const ir::OperandIndex &index, ir::Operand &operand) { + // Some inputs of an operation could be non-constant, but not existed in graph inputs/outputs + // and not undefined operand - these are variable tensors. For example, + // UnidirectionalSequenceLSTM has such inputs. + if (operand.info().isVariable()) + { + // The variable operand with buffer is not supported yet + assert(operand.data() == nullptr); + assert(operand.getUses().size() == 1 && !operand.getDef().valid()); + auto operand_li = lower_info().operand.at(index); + assert(operand_li.def_factors().empty()); + operand_li.addDefPermuteFactor(operand_li.use_factors().getOnlyElement()); + } }); } @@ -395,12 +189,22 @@ void LoweredGraph::dumpLowerInfo() std::map<uint32_t, std::string> dumps; _graph.operands().iterate([&](const ir::OperandIndex &index, ir::Operand &object) { - std::stringstream sstream; - if (!getLowerInfo(index)->def_factors().empty() || !getLowerInfo(index)->use_factors().empty()) + const auto operand_lower_info = lower_info().operand.getRawPtr(index); + assert(operand_lower_info); + if (!operand_lower_info->def_factors().empty() || !operand_lower_info->use_factors().empty()) { - auto factors_to_string = [](const ir::operand::PermuteFactorSet &factors) { + auto shape_to_string = [](const ir::Shape &shape) { + std::stringstream sstream; + sstream << "{ "; + for (auto i = 0; i < shape.rank(); ++i) + sstream << (shape.dim(i)) << " "; + sstream << "}"; + return sstream.str(); + }; + + auto factors_to_string = [](const PermuteFactorSet &factors) { std::string str; - for (auto factor : factors) + for (auto &&factor : factors) { str += factor.backend()->config()->id(); str += "(" + to_string(factor.layout()) + ")"; @@ -409,159 +213,45 @@ void LoweredGraph::dumpLowerInfo() return "{ " + str + "}"; }; - auto operation_index_to_string = [](const 
ir::OperationIndexSet &operations) { - std::string str; - for (auto op : operations) - { - str += std::to_string(op.value()); - str += " "; - } - return "{ " + str + "}"; + auto operation_index_set_to_string = [](const ir::OperationIndexSet &operations) { + std::stringstream sstream; + sstream << "{ "; + for (auto &&op : operations) + sstream << op << " "; + sstream << "}"; + return sstream.str(); + }; + + auto data_to_str = [](const ir::Data *data) { + return (data ? (std::to_string(data->size()) + " bytes") : "N/A"); }; - const auto lower_info = getLowerInfo(index); - const auto &shape = object.shape(); - std::string def_ops = - object.getDef().valid() ? std::to_string(object.getDef().value()) : "N/A"; - std::string use_ops = operation_index_to_string(object.getUses()); - std::string def_layouts = factors_to_string(lower_info->def_factors()); - std::string use_layouts = factors_to_string(lower_info->use_factors()); - sstream << "Operand #" << index.value() << " LowerInfo" << std::endl; - sstream << " - Shape : { "; - for (auto i = 0; i < shape.rank(); ++i) - { - sstream << (shape.dim(i)) << " "; - } - sstream << "}" << std::endl; - sstream << " - Def ir::Operations : " << def_ops << std::endl; - sstream << " - Use ir::Operations : " << use_ops << std::endl; - sstream << " - Lower Info" << std::endl; - sstream << " - Def Backends : " << def_layouts << std::endl; - sstream << " - Use Backends : " << use_layouts << std::endl; + std::string shape_str = shape_to_string(object.shape()); + std::string def_op = operation_index_set_to_string({object.getDef()}); + std::string use_ops = operation_index_set_to_string(object.getUses()); + std::string def_factors = factors_to_string(operand_lower_info->def_factors()); + std::string use_factors = factors_to_string(operand_lower_info->use_factors()); + std::stringstream sstream; + sstream << "Operand " << index << " Info" << std::endl; + sstream << " - Shape : " << shape_str << std::endl; + sstream << " - Def/Uses : Def " << 
def_op << " Uses " << use_ops << std::endl; + sstream << " - Data : " << data_to_str(object.data()) << std::endl; + sstream << " - LowerInfo : Def " << def_factors << " Uses " << use_factors << std::endl; + dumps.emplace(index.value(), sstream.str()); } - dumps.emplace(index.value(), sstream.str()); }); for (const auto &e : dumps) { if (!e.second.empty()) { - VERBOSE(Lower) << e.second; + std::istringstream iss(e.second); + std::string line; + while (std::getline(iss, line)) + VERBOSE(Lower) << line << std::endl; } } } -bool LoweredGraph::mergeable(const ir::OpSequenceIndex &op_seq_index, - const ir::OperationIndex &node_index, ir::Layout layout, - const BackendResolver &backend_resolver) -{ - // Are they mergeable? - // 1. the same backend id and layout? - // 2. Is op_seq or node branched? - // 3. if 1 is true, the op_seq and a node are connected? - const auto &op_seq = _op_seqs.at(op_seq_index); - const auto &node = _graph.operations().at(node_index); - - // The same backend id and layout? - { - const auto op_seq_backend_layout = getLowerInfo(op_seq_index)->layout(); - const auto &op_seq_backend_id = getLowerInfo(op_seq_index)->backend()->config()->id(); - const auto &node_backend_id = backend_resolver.getBackend(node_index)->config()->id(); - VERBOSE(Lower) << "OpSequence#" << op_seq_index.value() << " { " << op_seq_backend_id << "(" - << to_string(op_seq_backend_layout) << ") } " - << " NODE#" << node_index.value() << " (" << node.name() << ") { " - << node_backend_id << "(" << to_string(layout) << ") } " << std::endl; - if (op_seq_backend_id != node_backend_id || op_seq_backend_layout != layout) - return false; - } - - // Branched? 
- { - std::unordered_set<ir::OperationIndex> branched_set; - - // Check for branching up - for (const auto &input : op_seq.getInputs() | ir::Remove::DUPLICATED | ir::Remove::UNDEFINED) - { - const auto &input_obj = _graph.operands().at(input); - auto def = input_obj.getDef(); - if (def.valid()) - { - branched_set.insert(def); - if (branched_set.size() > 1) - { - return false; - } - } - } - branched_set.clear(); - - // Check for branching down - for (const auto &output : node.getOutputs() | ir::Remove::DUPLICATED) - { - // TODO Fix this workaround for the case of model outputs that are used by another operation - // This is needed since the branching is decided by operation, but for model outputs, - // there is controlflow backen(use backend) but no actual use operation exists - if (_graph.getOutputs().contains(output)) - return false; - - const auto &output_obj = _graph.operands().at(output); - for (const auto &use : output_obj.getUses()) - { - branched_set.insert(use); - if (branched_set.size() > 1) - { - return false; - } - } - } - } - - // Connected? - // an input of one node is an output of the other node? or vice-versa? - { - const auto &node_inputs = node.getInputs(); - const auto &node_outputs = node.getOutputs(); - - // op_seq's operations are in order so that we just check the first and the last - std::vector<ir::OperationIndex> op_seq_ops{op_seq.operations()[0]}; - if (op_seq.operations().size() > 1) - op_seq_ops.emplace_back(op_seq.operations()[op_seq.operations().size() - 1]); - - for (const auto &n_index : op_seq_ops) - { - const auto &n = _graph.operations().at(n_index); - - // node's output == op_seq's input? 
- for (const auto input : n.getInputs() | ir::Remove::UNDEFINED) - { - if (node_outputs.contains(input)) - { - VERBOSE(Lower) << "OpSequence#" << op_seq_index.value() << " 's NODE#" << n_index.value() - << "(" << n.name() << ") is connected to NODE#" << node_index.value() - << "(" << node.name() << ")" << std::endl; - return true; - } - } - - // node's input == op_seq's output? - for (const auto output : n.getOutputs()) - { - if (node_inputs.contains(output)) - { - VERBOSE(Lower) << "OpSequence#" << op_seq_index.value() << " 's NODE#" << n_index.value() - << " (" << n.name() << ") is connected to NODE#" << node_index.value() - << std::endl; - return true; - } - } - } - - VERBOSE(Lower) << "OpSequence#" << op_seq_index.value() << " is not connected to NODE#" - << node_index.value() << "(" << node.name() << ")" << std::endl; - } - - return false; -} - } // namespace compiler } // namespace onert diff --git a/runtime/onert/core/src/compiler/ManualScheduler.cc b/runtime/onert/core/src/compiler/ManualScheduler.cc index ed49ee56f..ccd08893f 100644 --- a/runtime/onert/core/src/compiler/ManualScheduler.cc +++ b/runtime/onert/core/src/compiler/ManualScheduler.cc @@ -29,9 +29,9 @@ namespace onert namespace compiler { -ManualScheduler::ManualScheduler(const backend::BackendContexts &backend_contexts, +ManualScheduler::ManualScheduler(const std::vector<const backend::Backend *> &backends, const compiler::CompilerOptions &options) - : _backend_contexts{backend_contexts}, _options{options} + : _backends{backends}, _options{options} { } @@ -42,7 +42,7 @@ std::unique_ptr<BackendResolver> ManualScheduler::schedule(const ir::Graph &grap // This fallback will be used in case that `backend_for_all` is unavailable auto fallback = [&]() -> const backend::Backend * { - for (auto backend_id : _options.backend_list) + for (auto &&backend_id : _options.backend_list) { auto backend = resolveBackend(backend_id); if (backend) @@ -58,20 +58,20 @@ std::unique_ptr<BackendResolver> 
ManualScheduler::schedule(const ir::Graph &grap VERBOSE(ManualScheduler) << "Default backend for all ops: " << backend_all->config()->id() << std::endl; - graph.operations().iterate([&](const ir::OperationIndex &index, const ir::Operation &) { + graph.operations().iterate([&](const ir::OperationIndex &index, const ir::IOperation &) { backend_resolver->setBackend(index, backend_all); }); // 2. Backend per operation type std::unordered_map<ir::OpCode, backend::Backend *> op_type_map; - for (auto &pair : manual_options.opcode_to_backend) + for (const auto &pair : manual_options.opcode_to_backend) { op_type_map.emplace(pair.first, BackendManager::get().get(pair.second)); } // By default, Custom uses cpu backend op_type_map[ir::OpCode::Custom] = BackendManager::get().get("cpu"); - graph.operations().iterate([&](const ir::OperationIndex &index, const ir::Operation &operation) { + graph.operations().iterate([&](const ir::OperationIndex &index, const ir::IOperation &operation) { auto itr = op_type_map.find(operation.opcode()); if (itr != op_type_map.end()) { @@ -80,7 +80,7 @@ std::unique_ptr<BackendResolver> ManualScheduler::schedule(const ir::Graph &grap }); // 3. Backend per operation - for (auto &pair : manual_options.index_to_backend) + for (const auto &pair : manual_options.index_to_backend) { const auto &key = pair.first; const auto &val = pair.second; @@ -88,22 +88,21 @@ std::unique_ptr<BackendResolver> ManualScheduler::schedule(const ir::Graph &grap try { graph.operations().at(key); // Check if exist, or this will throw - backend_resolver->setBackend( - key, BackendManager::get().get( - val)); // TODO Ensure this backend is available in backend contexts + backend_resolver->setBackend(key, BackendManager::get().get(val)); } catch (...) 
{ - VERBOSE(ManualScheduler) << "Invalid value while OperationIndex to Backend mapping : @" - << key.value() << " -> \"" << val << "\"" << std::endl; + VERBOSE(ManualScheduler) << "Invalid value while OperationIndex to Backend mapping : @" << key + << " -> \"" << val << "\"" << std::endl; } } // Dump final assignment - backend_resolver->iterate([&](const ir::OperationIndex &index, const backend::Backend &backend) { - VERBOSE(ManualScheduler) << "backend for operation #" << index.value() << ": " - << backend.config()->id() << std::endl; - }); + WHEN_LOG_ENABLED(backend_resolver->iterate( + [&](const ir::OperationIndex &index, const backend::Backend &backend) { + VERBOSE(ManualScheduler) << "backend for " << index << ": " << backend.config()->id() + << std::endl; + })); return backend_resolver; } @@ -113,7 +112,7 @@ const backend::Backend *ManualScheduler::resolveBackend(const std::string &id, { // Ensure if the backend is available in the current backend context const backend::Backend *backend = BackendManager::get().get(id); - if (!backend || _backend_contexts.find(backend) == _backend_contexts.end()) + if (!backend || std::find(_backends.begin(), _backends.end(), backend) == _backends.end()) { backend = fallback; } diff --git a/runtime/onert/core/src/compiler/ManualScheduler.h b/runtime/onert/core/src/compiler/ManualScheduler.h index 41503f7ff..18732d744 100644 --- a/runtime/onert/core/src/compiler/ManualScheduler.h +++ b/runtime/onert/core/src/compiler/ManualScheduler.h @@ -28,7 +28,7 @@ namespace compiler class ManualScheduler : public IScheduler { public: - ManualScheduler(const backend::BackendContexts &backend_contexts, + ManualScheduler(const std::vector<const backend::Backend *> &backends, const compiler::CompilerOptions &options); std::unique_ptr<BackendResolver> schedule(const ir::Graph &graph) override; @@ -37,7 +37,7 @@ private: const backend::Backend *fallback = nullptr); private: - const backend::BackendContexts &_backend_contexts; + std::vector<const 
backend::Backend *> _backends; compiler::CompilerOptions _options; }; diff --git a/runtime/onert/core/src/compiler/MultiModelCompiler.cc b/runtime/onert/core/src/compiler/MultiModelCompiler.cc new file mode 100644 index 000000000..7fdf700c7 --- /dev/null +++ b/runtime/onert/core/src/compiler/MultiModelCompiler.cc @@ -0,0 +1,230 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "MultiModelCompiler.h" + +#include "CompilerHelpers.h" +#include "ExecutorFactory.h" +#include "ShapeValidator.h" +#include "pass/ConstantOutputPass.h" +#include "pass/OddOutputPass.h" +#include "pass/PassRunner.h" +#include "pass/UnusedOperandEliminationPass.h" +#include "../dumper/dot/DotDumper.h" +#include "../exec/MultiModelExecutors.h" +#include "../ir/OperationDumper.h" +#include "../ir/verifier/Verifier.h" + +#include "compiler/StaticShapeInferer.h" + +#include <misc/string_helpers.h> +#include <misc/polymorphic_downcast.h> + +namespace onert +{ +namespace compiler +{ + +MultiModelCompiler::MultiModelCompiler(const std::shared_ptr<ir::NNPkg> &nnpkg, + CompilerOptions *copts) + : _nnpkg{nnpkg}, _options{copts} +{ + // DO NOTHING +} + +std::shared_ptr<CompilerArtifact> MultiModelCompiler::compile(void) +{ + /*************************************************** + * Prepare compilation phase + ***************************************************/ + { + if (!_options) + throw 
std::runtime_error{"Empty compile option"}; + + // Mode check + // TODO handle option for each model + if (_options->he_profiling_mode) + throw std::runtime_error("NYI: Profiling mode for multiple model is not supported yet"); + + _options->forceInternalOptions(); + _options->verboseOptions(); + } + + // NYI: allow one model compilation + auto const model_count = _nnpkg->model_count(); + for (uint16_t i = 0; i < model_count; i++) + { + if (!_nnpkg->model(ir::ModelIndex{i})->hasOnly<ir::Graph>()) + throw std::runtime_error("MultiModelCompiler can only compile models for inference."); + } + + for (uint16_t i = 0; i < model_count; i++) + { + _nnpkg->model(ir::ModelIndex{i})->iterate([&](const ir::SubgraphIndex &, ir::IGraph &graph) { + auto &subg = nnfw::misc::polymorphic_downcast<ir::Graph &>(graph); + + // Mandatory passes + pass::PassRunner{} + .append(std::make_unique<pass::ConstantOutputPass>(subg)) + .append(std::make_unique<pass::OddOutputPass>(subg)) + .run(); + + // Optimizations + pass::PassRunner{}.append(std::make_unique<pass::UnusedOperandEliminationPass>(subg)).run(); + }); + } + + /*************************************************** + * Backend independent analysis & optimization phase + ***************************************************/ + // TODO Handle dump level for each model + auto dump_level = static_cast<dumper::dot::DotDumper::Level>(_options->graph_dump_level); + onert::dumper::dot::DotDumper dot_dumper(dump_level); + + // Tracing context + // TODO Support tracing_ctx for multiple model + std::unique_ptr<util::TracingCtx> tracing_ctx = nullptr; + + // Model edge context: copy model edge context + auto model_edges = std::make_unique<ir::ModelEdges>(_nnpkg->model_edges()); + + // Custom kernels + std::unordered_map<ir::ModelIndex, std::shared_ptr<backend::custom::IKernelBuilder>> + custom_kernel_builders; + for (uint16_t i = 0; i < model_count; i++) + { + auto const model_index = ir::ModelIndex{i}; + custom_kernel_builders[model_index] = 
_nnpkg->model(model_index)->getKernelBuilder(); + } + + // Lower: Assign backend + std::unordered_map<ir::ModelIndex, + std::unordered_map<ir::SubgraphIndex, std::unique_ptr<compiler::LoweredGraph>>> + lowered_subgs; + + for (uint16_t i = 0; i < model_count; i++) + { + auto const model_index = ir::ModelIndex{i}; + auto model = _nnpkg->model(model_index); + + model->iterate([&](const ir::SubgraphIndex &subg_index, ir::IGraph &graph) { + auto &subg = nnfw::misc::polymorphic_downcast<ir::Graph &>(graph); + + dot_dumper.dump(subg, + nnfw::misc::str("before_lower_model-", i, "-subg-", subg_index.value())); + // Lower: Assign backend + lowered_subgs[model_index][subg_index] = + std::make_unique<compiler::LoweredGraph>(subg, *_options); + // Set tracing_ctx for copied graph + if (tracing_ctx != nullptr) + tracing_ctx->setSubgraphIndex(&(lowered_subgs[model_index][subg_index]->graph()), + subg_index.value()); + }); + } + + _nnpkg.reset(); + + for (const auto &pair : lowered_subgs) + { + const auto &model_index = pair.first; + const auto &model_lsubg = pair.second; + + for (const auto &pair_inner : model_lsubg) + { + const auto &subg_index = pair_inner.first; + const auto &lowered_subg = pair_inner.second; + dot_dumper.dump(*lowered_subg, nnfw::misc::str("after_lower_model-", model_index.value(), + "-subg-", subg_index.value())); + } + } + + // Shape inference. + for (auto &&pair : lowered_subgs) + { + auto &model_lsubgs = pair.second; + // Run the StaticShapeInfer of primary subg. 
All child StaticShapeInferers are called + // recursively + std::unordered_map<ir::SubgraphIndex, std::unique_ptr<StaticShapeInferer>> inferers = + createStaticShapeInferers(model_lsubgs); + + const auto primary_subg_idx = ir::SubgraphIndex{0}; + inferers.at(primary_subg_idx)->infer(); + + for (const auto &pair_inferer : inferers) + { + const auto inferer = pair_inferer.second.get(); + inferer->dump(); + } + } + + // Shape validation + // TODO Move shape independent feature check from ShapeValidator to OperationValidator + // TODO Move ShapeValidator into shape inference + // - Check input tensor shape validation + // - Check parameter value validation which valid value is depend on input tensor shape + // - Output tensor shape validation check is needless because + // static/dynamic shape inferer will make valid output shape + for (const auto &pair : lowered_subgs) + { + const auto &model_lsubgs = pair.second; + + for (const auto &pair_inner : model_lsubgs) + { + const auto &lowered_subg = pair_inner.second; + compiler::ShapeValidator{lowered_subg->graph()}(); + } + } + + /************************************************************* + * Backend independent analysis & optimization phase finished + *************************************************************/ + auto executors = std::make_shared<exec::MultiModelExecutors>(std::move(model_edges)); + for (auto &&pair : lowered_subgs) + { + auto const &model_index = pair.first; + auto &model_lsubgs = pair.second; + + for (auto &&pair_inner : model_lsubgs) + { + auto const subg_index = pair_inner.first; + auto &lowered_subg = pair_inner.second; + auto const indexed_ranks = lowered_subg->indexed_ranks(); + + ir::OperationDumper dumper("Executor generation of Subgraph " + + std::to_string(subg_index.value())); + lowered_subg->graph().operations().iterate( + [&](const ir::OperationIndex &, const ir::IOperation &op) { op.accept(dumper); }); + + ExecutorFactoryArgs args; + args.tracing_ctx = tracing_ctx.get(); + 
args.options = _options; + args.model_index = model_index; + args.custom_kernel_builder = custom_kernel_builders[model_index]; + auto executor = std::unique_ptr<exec::IExecutor>{ + ExecutorFactory::get().create(std::move(lowered_subg), executors, args)}; + executor->setIndexedRanks(indexed_ranks); + executors->emplace(model_index, subg_index, std::move(executor)); + } + } + + /******************************** + * Code generation phase finished + ********************************/ + return std::make_shared<CompilerArtifact>(executors, std::move(tracing_ctx)); +} + +} // namespace compiler +} // namespace onert diff --git a/runtime/onert/core/src/compiler/MultiModelCompiler.h b/runtime/onert/core/src/compiler/MultiModelCompiler.h new file mode 100644 index 000000000..7e202a71f --- /dev/null +++ b/runtime/onert/core/src/compiler/MultiModelCompiler.h @@ -0,0 +1,68 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +/** + * @file MultiModelCompiler.h + * @brief This file contains MultiModelCompiler class to define and run compilation phase + */ + +#ifndef __ONERT_COMPILER_MULTI_MODEL_COMPILER_H__ +#define __ONERT_COMPILER_MULTI_MODEL_COMPILER_H__ + +#include "compiler/CompilerOptions.h" +#include "compiler/ICompiler.h" +#include "ir/NNPkg.h" + +namespace onert +{ +namespace compiler +{ + +/** + * @brief Class to compile NN package + */ +class MultiModelCompiler final : public ICompiler +{ +public: + /** + * @brief Construct a new Compiler object for NN package + * @param[in] nnpkg NN package to compile + * @param[in] copts Compiler option for package + */ + MultiModelCompiler(const std::shared_ptr<ir::NNPkg> &nnpkg, CompilerOptions *copts); + + /** + * @brief Destroy the MultiModelCompiler object + */ + ~MultiModelCompiler() = default; + +public: + /** + * @brief Do compilation with the options + * + * @return std::shared_ptr<CompilerArtifact> MultiModelExecutors as a result of compilation + */ + std::shared_ptr<CompilerArtifact> compile(void); + +private: + std::shared_ptr<ir::NNPkg> _nnpkg; + CompilerOptions *_options; +}; + +} // namespace compiler +} // namespace onert + +#endif // __ONERT_COMPILER_MULTI_MODEL_COMPILER_H__ diff --git a/runtime/onert/core/src/ir/operation/LowerInfo.cc b/runtime/onert/core/src/compiler/OperationLowerInfo.cc index 249918bd6..e8a438130 100644 --- a/runtime/onert/core/src/ir/operation/LowerInfo.cc +++ b/runtime/onert/core/src/compiler/OperationLowerInfo.cc @@ -14,21 +14,18 @@ * limitations under the License. 
*/ -#include "ir/operation/LowerInfo.h" +#include "compiler/OperationLowerInfo.h" namespace onert { -namespace ir -{ -namespace operation +namespace compiler { -LowerInfo::LowerInfo(const backend::Backend *backend, Layout layout) - : _permute_factor{backend, layout} +OperationLowerInfo::OperationLowerInfo(const backend::Backend *backend, ir::Layout layout) + : _permute_factor{backend, layout} { // DO NOTHING } -} // namespace operation -} // namespace ir +} // namespace compiler } // namespace onert diff --git a/runtime/onert/core/src/compiler/OperationValidator.cc b/runtime/onert/core/src/compiler/OperationValidator.cc deleted file mode 100644 index f7f659e3e..000000000 --- a/runtime/onert/core/src/compiler/OperationValidator.cc +++ /dev/null @@ -1,1053 +0,0 @@ -/* - * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "OperationValidator.h" - -#include <typeinfo> - -#include "ir/Graph.h" -#include "ir/operation/LowerInfo.h" - -#include "util/logging.h" -#include "util/Utils.h" - -#define OP_REQUIRES(EXP) \ - do \ - { \ - if (!(EXP)) \ - throw std::runtime_error("OperationValidator failed at line " + std::to_string(__LINE__)); \ - } while (0) - -namespace onert -{ -namespace compiler -{ - -OperationValidator::OperationValidator(const ir::Graph &graph) - : _graph{graph}, _ctx{graph.operands()}, _current_op_seq_layout{ir::Layout::UNKNOWN} -{ -} - -void OperationValidator::checkUnaryOp(const ir::Operation &node) -{ - const auto output_index{node.getOutputs().at(0)}; - const auto input_index{node.getInputs().at(0)}; - - // Check if I/O types match - OP_REQUIRES(_ctx.at(output_index).typeInfo().type() == _ctx.at(input_index).typeInfo().type()); - - if (_ctx.at(output_index).info().isDynamic()) - return; - - // Check if I/O shapes match - OP_REQUIRES(_ctx.at(output_index).shape() == _ctx.at(input_index).shape()); -} - -void OperationValidator::operator()() -{ - // There is no reason for each subgraph to have subgraphs since compiler has subgraphs when - // creating Compiler - assert(_graph.subgraphs() == nullptr); - - _current_op_seq_layout = _graph.layout(); - - _graph.operations().iterate( - [&](const ir::OperationIndex &, const ir::Operation &node) { node.accept(*this); }); -} - -void OperationValidator::visit(const ir::operation::BatchMatMul &node) -{ - const auto lhs_index(node.getInputs().at(ir::operation::BatchMatMul::Input::LHS)); - const auto rhs_index(node.getInputs().at(ir::operation::BatchMatMul::Input::RHS)); - const auto out_index{node.getOutputs().at(0)}; - - // Constant lhs and rhs is not implemented yet - OP_REQUIRES(!_ctx.at(lhs_index).isConstant() && !_ctx.at(rhs_index).isConstant()); - - if (_ctx.at(out_index).info().isDynamic()) - return; - - OP_REQUIRES(_ctx.at(lhs_index).shape().rank() <= 4); - OP_REQUIRES(_ctx.at(rhs_index).shape().rank() <= 4); 
- OP_REQUIRES(_ctx.at(lhs_index).shape().rank() >= 2); - OP_REQUIRES(_ctx.at(rhs_index).shape().rank() >= 2); -} - -void OperationValidator::visit(const ir::operation::BatchToSpaceND &node) -{ - const auto ofm_index{node.getOutputs().at(0)}; - if (_ctx.at(ofm_index).info().isDynamic()) - return; - - const auto ifm_index{node.getInputs().at(ir::operation::BatchToSpaceND::Input::INPUT)}; - const auto block_size_index{ - node.getInputs().at(ir::operation::BatchToSpaceND::Input::BLOCK_SIZE)}; - - const auto frontend_layout = _current_op_seq_layout; - const auto input_shape = _ctx.at(ifm_index).shape().asFeature(frontend_layout); - const auto output_shape = _ctx.at(ofm_index).shape().asFeature(frontend_layout); - - // All requirement as per NNAPI specification. - OP_REQUIRES(_ctx.at(ifm_index).shape().rank() == 4); - OP_REQUIRES(_ctx.at(ofm_index).shape().rank() == 4); - OP_REQUIRES(_ctx.at(block_size_index).shape().rank() == 1); - - OP_REQUIRES(_ctx.at(block_size_index).shape().dim(0) == 2); - - OP_REQUIRES(_ctx.at(block_size_index).isConstant()); - - OP_REQUIRES(input_shape.C == output_shape.C); -} - -void OperationValidator::visit(const ir::operation::Comparison &node) -{ - const auto output_index{node.getOutputs().at(0)}; - // This validator does not check shape. So checking isDynamic() is skipped. 
- - const auto lhs_index{node.getInputs().at(ir::operation::Comparison::Input::INPUT0)}; - const auto rhs_index{node.getInputs().at(ir::operation::Comparison::Input::INPUT1)}; - - OP_REQUIRES(_ctx.at(lhs_index).typeInfo().type() == _ctx.at(rhs_index).typeInfo().type()); - OP_REQUIRES(_ctx.at(output_index).typeInfo().type() == ir::DataType::BOOL8); -} - -void OperationValidator::visit(const ir::operation::Softmax &node) -{ - VERBOSE(Softmax) << "Configure SOFTMAX operation" << std::endl; - - const auto output_index{node.getOutputs().at(0)}; - if (_ctx.at(output_index).info().isDynamic()) - return; - - const auto input_index{node.getInputs().at(0)}; - - OP_REQUIRES(_ctx.at(output_index).shape().rank() == _ctx.at(input_index).shape().rank()); -} - -void OperationValidator::visit(const ir::operation::InstanceNorm &node) -{ - const auto ofm_index{node.getOutputs().at(0)}; - if (_ctx.at(ofm_index).info().isDynamic()) - return; - - const auto ifm_index{node.getInputs().at(ir::operation::InstanceNorm::Input::INPUT)}; - const auto gamma_index{node.getInputs().at(ir::operation::InstanceNorm::Input::GAMMA)}; - const auto beta_index{node.getInputs().at(ir::operation::InstanceNorm::Input::BETA)}; - - OP_REQUIRES(_ctx.at(ifm_index).shape().rank() == 4); - OP_REQUIRES(_ctx.at(ifm_index).shape() == _ctx.at(ofm_index).shape()); - OP_REQUIRES(_ctx.at(gamma_index).shape().rank() == 1); - OP_REQUIRES(_ctx.at(beta_index).shape().rank() == 1); -} - -void OperationValidator::visit(const ir::operation::Pool2D &node) -{ - const auto ofm_index{node.getOutputs().at(0)}; - if (_ctx.at(ofm_index).info().isDynamic()) - return; - - const auto ifm_index{node.getInputs().at(ir::operation::Pool2D::Input::INPUT)}; - - OP_REQUIRES(_ctx.at(ifm_index).shape().rank() == 4); -} - -void OperationValidator::visit(const ir::operation::Permute &node) -{ - VERBOSE(Permute) << "Configure Permute operation" << std::endl; - - const auto output_index{node.getOutputs().at(0)}; - if 
(_ctx.at(output_index).info().isDynamic()) - return; - - const auto input_index{node.getInputs().at(0)}; - - OP_REQUIRES(_ctx.at(output_index).shape().rank() == _ctx.at(input_index).shape().rank()); -} - -void OperationValidator::visit(const ir::operation::Reduce &node) -{ - VERBOSE(Permute) << "Configure " + node.name() + " operation" << std::endl; - - const auto output_index{node.getOutputs().at(0)}; - if (_ctx.at(output_index).info().isDynamic()) - return; - - const auto input_index{node.getInputs().at(ir::operation::Reduce::Input::INPUT)}; - const auto input_shape = _ctx.at(input_index).shape(); - const auto output_shape = _ctx.at(output_index).shape(); - - OP_REQUIRES(input_shape.rank() <= 4); - OP_REQUIRES(output_shape.rank() <= input_shape.rank()); - - // NOTE For the 4-dimensions, if the rank of input and output are different, this runtime only - // supports cases reducing height and width or reducing depth. - // TODO We have to support all cases of dimensions up to 4. - // For correct permuting, we have to set output's shape to be equal in dimension position of the - // input. But the positions of the same dimensions in the input and output may be set differently. - // For example {2,3,4,5}(input's shape) can be reduced to {3,5}(output's shape). The original - // output shape should be {1,3,1,5}, but real output shape may be {3,5}. If you simply try to - // extend it in 4 dimensions, it should be {1,1,3,5}. - // Even if output shape is changed to {1,3,1,5}, there is another problem. It is that shape of - // output tensor used at next operation is changed to {1,3,1,5} after this operation even if the - // next operation is not desired. 
- if (input_shape.rank() == 4 && input_shape.rank() != output_shape.rank()) - { - if (output_shape.rank() == 2) - { - // Reducing HW - OP_REQUIRES(input_shape.dim(0) == output_shape.dim(0) && - input_shape.dim(3) == output_shape.dim(1)); - } - else if (output_shape.rank() == 3) - { - // Reducing C or - // (Reducing H and C(input and output) == 1) or (Reducing W and C(input and output) == 1) - OP_REQUIRES((input_shape.dim(0) == output_shape.dim(0) && - input_shape.dim(1) == output_shape.dim(1) && - input_shape.dim(2) == output_shape.dim(2)) || - (input_shape.dim(0) == output_shape.dim(0) && - (input_shape.dim(1) == output_shape.dim(1) || - input_shape.dim(2) == output_shape.dim(1)) && - input_shape.dim(3) == 1 && output_shape.dim(2) == 1)); - } - } -} - -void OperationValidator::visit(const ir::operation::Transpose &node) -{ - const auto output_index{node.getOutputs().at(0)}; - if (_ctx.at(output_index).info().isDynamic()) - return; - - const auto input_index{node.getInputs().at(ir::operation::Transpose::Input::INPUT)}; - const auto &perm{node.param().perm}; - - const auto &output_shape = _ctx.at(output_index).shape(); - const auto &input_shape = _ctx.at(input_index).shape(); - - OP_REQUIRES(input_shape.rank() == static_cast<int>(perm.size())); - OP_REQUIRES(input_shape.rank() == output_shape.rank()); -} - -void OperationValidator::visit(const ir::operation::RNN &node) -{ - // NOTE This validation is for static rnn(non-dynamic shape), but not for dynamic rnn - // TODO Support dynamic rnn - const auto output_index{node.getOutputs().at(ir::operation::RNN::Output::OUTPUT)}; - if (_ctx.at(output_index).info().isDynamic()) - return; - - const auto hidden_state_out_index{ - node.getOutputs().at(ir::operation::RNN::Output::HIDDEN_STATE_OUT)}; - - const auto input_index{node.getInputs().at(ir::operation::RNN::Input::INPUT)}; - const auto weights_index{node.getInputs().at(ir::operation::RNN::Input::WEIGHTS)}; - const auto recurrent_weights_index{ - 
node.getInputs().at(ir::operation::RNN::Input::RECURRENT_WEIGHTS)}; - const auto bias_index{node.getInputs().at(ir::operation::RNN::Input::BIAS)}; - const auto hidden_state_in_index{node.getInputs().at(ir::operation::RNN::Input::HIDDEN_STATE_IN)}; - - const auto batch_size = _ctx.at(output_index).shape().dim(0); - const auto num_units = _ctx.at(output_index).shape().dim(1); - - OP_REQUIRES(_ctx.at(output_index).shape().rank() == 2 && - _ctx.at(hidden_state_out_index).shape().rank() == 2 && - _ctx.at(input_index).shape().rank() == 2 && - _ctx.at(weights_index).shape().rank() == 2 && - _ctx.at(recurrent_weights_index).shape().rank() == 2 && - _ctx.at(hidden_state_in_index).shape().rank() == 2); - OP_REQUIRES(_ctx.at(bias_index).shape().rank() == 1); - - OP_REQUIRES(batch_size == _ctx.at(input_index).shape().dim(0) && - batch_size == _ctx.at(hidden_state_in_index).shape().dim(0) && - batch_size == _ctx.at(hidden_state_out_index).shape().dim(0)); - OP_REQUIRES(_ctx.at(input_index).shape().dim(1) == _ctx.at(weights_index).shape().dim(1)); - - OP_REQUIRES(num_units == _ctx.at(weights_index).shape().dim(0) && - num_units == _ctx.at(recurrent_weights_index).shape().dim(0) && - num_units == _ctx.at(bias_index).shape().dim(0)); - OP_REQUIRES(num_units == _ctx.at(output_index).shape().dim(1) && - num_units == _ctx.at(recurrent_weights_index).shape().dim(1) && - num_units == _ctx.at(hidden_state_in_index).shape().dim(1) && - num_units == _ctx.at(hidden_state_out_index).shape().dim(1)); -} - -void OperationValidator::visit(const ir::operation::SpaceToBatchND &node) -{ - const auto ofm_index{node.getOutputs().at(0)}; - if (_ctx.at(ofm_index).info().isDynamic()) - return; - - const auto ifm_index{node.getInputs().at(ir::operation::SpaceToBatchND::Input::INPUT)}; - const auto block_size_index{ - node.getInputs().at(ir::operation::SpaceToBatchND::Input::BLOCK_SIZE)}; - const auto paddings_index{node.getInputs().at(ir::operation::SpaceToBatchND::Input::PADDINGS)}; - - const auto 
frontend_layout = _current_op_seq_layout; - const auto input_shape = _ctx.at(ifm_index).shape().asFeature(frontend_layout); - const auto output_shape = _ctx.at(ofm_index).shape().asFeature(frontend_layout); - - // All requirement as per NNAPI specification. - OP_REQUIRES(_ctx.at(ifm_index).shape().rank() == 4); - OP_REQUIRES(_ctx.at(ofm_index).shape().rank() == 4); - OP_REQUIRES(_ctx.at(block_size_index).shape().rank() == 1); - OP_REQUIRES(_ctx.at(paddings_index).shape().rank() == 2); - - OP_REQUIRES(_ctx.at(block_size_index).shape().dim(0) == 2); - OP_REQUIRES(_ctx.at(paddings_index).shape().dim(0) == 2); - OP_REQUIRES(_ctx.at(paddings_index).shape().dim(1) == 2); - - OP_REQUIRES(_ctx.at(block_size_index).isConstant()); - OP_REQUIRES(_ctx.at(paddings_index).isConstant()); - - OP_REQUIRES(input_shape.C == output_shape.C); -} - -void OperationValidator::visit(const ir::operation::SpaceToDepth &node) -{ - const auto ofm_index{node.getOutputs().at(0)}; - if (_ctx.at(ofm_index).info().isDynamic()) - return; - - const auto ifm_index{node.getInputs().at(ir::operation::SpaceToDepth::Input::INPUT)}; - - const auto frontend_layout = _current_op_seq_layout; - const auto input_shape = _ctx.at(ifm_index).shape().asFeature(frontend_layout); - const auto output_shape = _ctx.at(ofm_index).shape().asFeature(frontend_layout); - const auto block_size = node.param().block_size; - - // All assertions as per NNAPI specification. 
- OP_REQUIRES(_ctx.at(ifm_index).shape().rank() == 4); - OP_REQUIRES(_ctx.at(ofm_index).shape().rank() == 4); - OP_REQUIRES((block_size >= 1) && (input_shape.H % block_size == 0) && - (input_shape.W % block_size == 0)); - OP_REQUIRES(input_shape.N == output_shape.N); - OP_REQUIRES(input_shape.C * block_size * block_size == output_shape.C); -} - -void OperationValidator::visit(const ir::operation::ElementwiseActivation &node) -{ - checkUnaryOp(node); -} - -void OperationValidator::visit(const ir::operation::ElementwiseBinary &node) -{ - const auto output_index{node.getOutputs().at(0)}; - const auto lhs_index{node.getInputs().at(ir::operation::ElementwiseBinary::Input::LHS)}; - const auto rhs_index{node.getInputs().at(ir::operation::ElementwiseBinary::Input::RHS)}; - - OP_REQUIRES(_ctx.at(lhs_index).typeInfo().type() == _ctx.at(rhs_index).typeInfo().type()); - OP_REQUIRES(_ctx.at(lhs_index).typeInfo().type() == _ctx.at(output_index).typeInfo().type()); -} - -void OperationValidator::visit(const ir::operation::ElementwiseUnary &node) -{ - const auto output_index{node.getOutputs().at(0)}; - const auto input_index{node.getInputs().at(ir::operation::ElementwiseUnary::Input::INPUT)}; - - OP_REQUIRES(node.getInputs().size() == 1); - OP_REQUIRES(node.getOutputs().size() == 1); - - // Check if I/O types match - if (node.param().op_type == ir::operation::ElementwiseUnary::Type::DEQUANTIZE) - { - OP_REQUIRES(_ctx.at(input_index).typeInfo().type() == ir::DataType::QUANT_UINT8_ASYMM); - OP_REQUIRES(_ctx.at(output_index).typeInfo().type() == ir::DataType::FLOAT32); - } - else if (node.param().op_type == ir::operation::ElementwiseUnary::Type::QUANTIZE) - { - OP_REQUIRES(_ctx.at(input_index).typeInfo().type() == ir::DataType::FLOAT32); - OP_REQUIRES(_ctx.at(output_index).typeInfo().type() == ir::DataType::QUANT_UINT8_ASYMM); - } - else if (node.param().op_type != ir::operation::ElementwiseUnary::Type::CAST) - { - OP_REQUIRES(_ctx.at(output_index).typeInfo().type() == 
_ctx.at(input_index).typeInfo().type()); - } - - if (_ctx.at(output_index).info().isDynamic()) - return; - - OP_REQUIRES(_ctx.at(output_index).shape() == _ctx.at(input_index).shape()); -} - -void OperationValidator::visit(const ir::operation::EmbeddingLookup &node) -{ - const auto output_index{node.getOutputs().at(0)}; - const auto lookups_index{node.getInputs().at(ir::operation::EmbeddingLookup::Input::LOOKUPS)}; - const auto values_index{node.getInputs().at(ir::operation::EmbeddingLookup::Input::VALUES)}; - - const auto &output_obj = _ctx.at(output_index); - const auto &lookups_obj = _ctx.at(lookups_index); - const auto &values_obj = _ctx.at(values_index); - - // Verify operand here, not at SimpleEmbeddingLookup::configure() to avoid acl's modifying - // TensorShape sometimes(Issue: https://github.sec.samsung.net/STAR/nnfw/issues/729) - { - OP_REQUIRES(lookups_obj.typeInfo().type() == ir::DataType::INT32); - - if (_ctx.at(output_index).info().isDynamic()) - return; - - const auto &output_shape = output_obj.shape(); - const auto &lookups_shape = lookups_obj.shape(); - const auto &values_shape = values_obj.shape(); - - OP_REQUIRES(lookups_shape.rank() == 1); - OP_REQUIRES(values_shape.rank() >= 2); - - // output should be a n-D tensor with the same rank and shape as the values tensor, except for - // the first dimension which has the same size as lookups' only dimension. 
- OP_REQUIRES(output_shape.rank() == values_shape.rank()); - OP_REQUIRES(output_shape.dim(0) == lookups_shape.dim(0)); - for (int n = 1; n < output_shape.rank(); ++n) - { - OP_REQUIRES(output_shape.dim(n) == values_shape.dim(n)); - } - } -} - -void OperationValidator::visit(const ir::operation::ExpandDims &node) -{ - const auto output_index{node.getOutputs().at(0)}; - const auto input_index{node.getInputs().at(ir::operation::ExpandDims::Input::INPUT)}; - const auto axis_index{node.getInputs().at(ir::operation::ExpandDims::Input::AXIS)}; - - OP_REQUIRES(_ctx.at(output_index).typeInfo().type() == _ctx.at(input_index).typeInfo().type()); - OP_REQUIRES(_ctx.at(axis_index).typeInfo().type() == ir::DataType::INT32); - - if (_ctx.at(axis_index).info().isDynamic()) - return; - OP_REQUIRES(_ctx.at(axis_index).shape().rank() <= 1); -} - -void OperationValidator::visit(const ir::operation::HashtableLookup &node) -{ - const auto output_index{node.getOutputs().at(ir::operation::HashtableLookup::Output::OUTPUT)}; - const auto hits_index{node.getOutputs().at(ir::operation::HashtableLookup::Output::HITS)}; - - const auto lookups_index{node.getInputs().at(ir::operation::HashtableLookup::Input::LOOKUPS)}; - const auto keys_index{node.getInputs().at(ir::operation::HashtableLookup::Input::KEYS)}; - const auto values_index{node.getInputs().at(ir::operation::HashtableLookup::Input::VALUES)}; - - const auto &output_obj = _ctx.at(output_index); - const auto &hits_obj = _ctx.at(hits_index); - - const auto &lookups_obj = _ctx.at(lookups_index); - const auto &keys_obj = _ctx.at(keys_index); - const auto &values_obj = _ctx.at(values_index); - - OP_REQUIRES(lookups_obj.typeInfo().type() == ir::DataType::INT32); - OP_REQUIRES(keys_obj.typeInfo().type() == ir::DataType::INT32); - OP_REQUIRES(hits_obj.typeInfo().type() == ir::DataType::QUANT_UINT8_ASYMM); - - if (_ctx.at(output_index).info().isDynamic()) - return; - - const auto &output_shape = output_obj.shape(); - const auto &lookups_shape = 
lookups_obj.shape(); - const auto &keys_shape = keys_obj.shape(); - const auto &values_shape = values_obj.shape(); - - OP_REQUIRES(values_shape.rank() == output_shape.rank()); - OP_REQUIRES(lookups_shape.rank() == 1); - OP_REQUIRES(keys_shape.rank() == 1); - OP_REQUIRES(values_shape.dim(0) == keys_shape.dim(0)); - OP_REQUIRES(lookups_shape.dim(0) == output_shape.dim(0)); -} - -void OperationValidator::visit(const ir::operation::TransposeConv &node) -{ - // param check - OP_REQUIRES((node.param().padding.type == ir::PaddingType::SAME) || - (node.param().padding.type == ir::PaddingType::VALID)); - - // shape check - const auto ofm_index{node.getOutputs().at(0)}; - if (_ctx.at(ofm_index).info().isDynamic()) - return; - - const auto ifm_index{node.getInputs().at(ir::operation::TransposeConv::Input::INPUT)}; - const auto ker_index{node.getInputs().at(ir::operation::TransposeConv::Input::KERNEL)}; - - // Only 4D tensors are supported - OP_REQUIRES(_ctx.at(ofm_index).shape().rank() == 4); - OP_REQUIRES(_ctx.at(ofm_index).shape().rank() == _ctx.at(ifm_index).shape().rank()); - OP_REQUIRES(_ctx.at(ofm_index).shape().rank() == _ctx.at(ker_index).shape().rank()); - - const auto frontend_layout = _current_op_seq_layout; - const auto ofm_shape = _ctx.at(ofm_index).shape().asFeature(frontend_layout); - const auto ifm_shape = _ctx.at(ifm_index).shape().asFeature(frontend_layout); - // The kernel has only IHWO layout on frontend - // So ker_shape is treated here below - // I -> N - // H -> H - // W -> W - // O -> C - const auto ker_shape = _ctx.at(ker_index).shape().asFeature(ir::Layout::NHWC); - - OP_REQUIRES(ifm_shape.N == ofm_shape.N); - OP_REQUIRES(ifm_shape.C == ker_shape.C); - OP_REQUIRES(ker_shape.N == ofm_shape.C); -} - -void OperationValidator::visit(const ir::operation::Gather &node) -{ - const auto ofm_index{node.getOutputs().at(0)}; - if (_ctx.at(ofm_index).info().isDynamic()) - return; - - const auto ifm_index{node.getInputs().at(ir::operation::Gather::Input::INPUT)}; 
- const auto indices_index{node.getInputs().at(ir::operation::Gather::Input::INDICES)}; - - const auto ifm_shape = _ctx.at(ifm_index).shape(); - const auto indices_shape = _ctx.at(indices_index).shape(); - const auto ofm_shape = _ctx.at(ofm_index).shape(); - - OP_REQUIRES(ifm_shape.rank() <= 4); - OP_REQUIRES(indices_shape.rank() <= 3); - OP_REQUIRES(ofm_shape.rank() <= 4); -} - -void OperationValidator::visit(const ir::operation::DepthToSpace &node) -{ - // param check - int32_t block_size = node.param().block_size; - - OP_REQUIRES(block_size > 0); - - // shape check - const auto output_index{node.getOutputs().at(0)}; - if (_ctx.at(output_index).info().isDynamic()) - return; - - const auto input_index{node.getInputs().at(ir::operation::DepthToSpace::Input::INPUT)}; - - const auto frontend_layout = _current_op_seq_layout; - const auto output_shape = _ctx.at(output_index).shape().asFeature(frontend_layout); - const auto input_shape = _ctx.at(input_index).shape().asFeature(frontend_layout); - - OP_REQUIRES(_ctx.at(input_index).shape().rank() == 4); - OP_REQUIRES(_ctx.at(output_index).shape().rank() == 4); - - { - OP_REQUIRES(output_shape.N == input_shape.N); - OP_REQUIRES(output_shape.H == input_shape.H * block_size); - OP_REQUIRES(output_shape.W == input_shape.W * block_size); - OP_REQUIRES(input_shape.C % (block_size * block_size) == 0); - OP_REQUIRES(output_shape.C == input_shape.C / (block_size * block_size)); - } -} - -void OperationValidator::visit(const ir::operation::Pack &node) -{ - // param check - const auto num{node.param().num}; - const auto axis{node.param().axis}; - OP_REQUIRES(num == static_cast<int32_t>(node.getInputs().size())); - - const auto output_index{node.getOutputs().at(0)}; - if (_ctx.at(output_index).info().isDynamic()) - return; - - // shape check - const auto &output_shape = _ctx.at(output_index).shape(); - const auto output_rank = static_cast<int32_t>(output_shape.rank()); - - const auto input1_index{node.getInputs().at(0)}; - const auto 
input_shape = _ctx.at(input1_index).shape(); - - OP_REQUIRES(axis >= -output_rank && axis < output_rank); - for (const auto &index : node.getInputs()) - { - OP_REQUIRES(input_shape == _ctx.at(index).shape()); - } -} - -void OperationValidator::visit(const ir::operation::LSTM &node) -{ - // NOTE This validation is for static rnn(non-dynamic shape), but not for dynamic rnn - // TODO Support dynamic rnn - const auto output_index{node.getOutputs().at(ir::operation::LSTM::Output::OUTPUT)}; - if (_ctx.at(output_index).info().isDynamic()) - return; - - const auto scratch_buffer_index{ - node.getOutputs().at(ir::operation::LSTM::Output::SCRATCH_BUFFER)}; - const auto output_state_out_index{ - node.getOutputs().at(ir::operation::LSTM::Output::OUTPUT_STATE_OUT)}; - const auto cell_state_out_index{ - node.getOutputs().at(ir::operation::LSTM::Output::CELL_STATE_OUT)}; - - const auto input_index{node.getInputs().at(ir::operation::LSTM::Input::INPUT)}; - const auto input_to_input_weights_index{ - node.getInputs().at(ir::operation::LSTM::Input::INPUT_TO_INPUT_WEIGHTS)}; - const auto input_to_forget_weights_index{ - node.getInputs().at(ir::operation::LSTM::Input::INPUT_TO_FORGET_WEIGHTS)}; - const auto input_to_cell_weights_index{ - node.getInputs().at(ir::operation::LSTM::Input::INPUT_TO_CELL_WEIGHTS)}; - const auto input_to_output_weights_index{ - node.getInputs().at(ir::operation::LSTM::Input::INPUT_TO_OUTPUT_WEIGHTS)}; - const auto recurrent_to_input_weights_index{ - node.getInputs().at(ir::operation::LSTM::Input::RECURRENT_TO_INPUT_WEIGHTS)}; - const auto recurrent_to_forget_weights_index{ - node.getInputs().at(ir::operation::LSTM::Input::RECURRENT_TO_FORGET_WEIGHTS)}; - const auto recurrent_to_cell_weights_index{ - node.getInputs().at(ir::operation::LSTM::Input::RECURRENT_TO_CELL_WEIGHTS)}; - const auto recurrent_to_output_weights_index{ - node.getInputs().at(ir::operation::LSTM::Input::RECURRENT_TO_OUTPUT_WEIGHTS)}; - const auto cell_to_input_weights_index{ - 
node.getInputs().at(ir::operation::LSTM::Input::CELL_TO_INPUT_WEIGHTS)}; - const auto cell_to_forget_weights_index{ - node.getInputs().at(ir::operation::LSTM::Input::CELL_TO_FORGET_WEIGHTS)}; - const auto cell_to_output_weights_index{ - node.getInputs().at(ir::operation::LSTM::Input::CELL_TO_OUTPUT_WEIGHTS)}; - const auto input_gate_bias_index{ - node.getInputs().at(ir::operation::LSTM::Input::INPUT_GATE_BIAS)}; - const auto forget_gate_bias_index{ - node.getInputs().at(ir::operation::LSTM::Input::FORGET_GATE_BIAS)}; - const auto cell_bias_index{node.getInputs().at(ir::operation::LSTM::Input::CELL_BIAS)}; - const auto output_gate_bias_index{ - node.getInputs().at(ir::operation::LSTM::Input::OUTPUT_GATE_BIAS)}; - const auto projection_weights_index{ - node.getInputs().at(ir::operation::LSTM::Input::PROJECTION_WEIGHTS)}; - const auto projection_bias_index{ - node.getInputs().at(ir::operation::LSTM::Input::PROJECTION_BIAS)}; - const auto output_state_in_index{ - node.getInputs().at(ir::operation::LSTM::Input::OUTPUT_STATE_IN)}; - const auto cell_state_in_index{node.getInputs().at(ir::operation::LSTM::Input::CELL_STATE_IN)}; - - OP_REQUIRES(_ctx.at(scratch_buffer_index).shape().rank() == 2 && - _ctx.at(output_state_out_index).shape().rank() == 2 && - _ctx.at(cell_state_out_index).shape().rank() == 2 && - _ctx.at(output_index).shape().rank() == 2 && - _ctx.at(input_index).shape().rank() == 2 && - _ctx.at(input_to_input_weights_index).shape().rank() == 2 && - _ctx.at(input_to_forget_weights_index).shape().rank() == 2 && - _ctx.at(input_to_cell_weights_index).shape().rank() == 2 && - _ctx.at(input_to_output_weights_index).shape().rank() == 2 && - _ctx.at(recurrent_to_input_weights_index).shape().rank() == 2 && - _ctx.at(recurrent_to_forget_weights_index).shape().rank() == 2 && - _ctx.at(recurrent_to_cell_weights_index).shape().rank() == 2 && - _ctx.at(recurrent_to_output_weights_index).shape().rank() == 2 && - _ctx.at(projection_weights_index).shape().rank() == 2 && - 
_ctx.at(output_state_in_index).shape().rank() == 2 && - _ctx.at(cell_state_in_index).shape().rank() == 2); - - OP_REQUIRES(_ctx.at(cell_to_input_weights_index).shape().rank() == 1 && - _ctx.at(cell_to_forget_weights_index).shape().rank() == 1 && - _ctx.at(cell_to_output_weights_index).shape().rank() == 1 && - _ctx.at(input_gate_bias_index).shape().rank() == 1 && - _ctx.at(forget_gate_bias_index).shape().rank() == 1 && - _ctx.at(cell_bias_index).shape().rank() == 1 && - _ctx.at(output_gate_bias_index).shape().rank() == 1 && - _ctx.at(projection_bias_index).shape().rank() == 1); - - // CIFG assertion - OP_REQUIRES((_ctx.at(input_to_input_weights_index).shape().dim(0) == 0 && - _ctx.at(input_to_input_weights_index).shape().dim(1) == 0 && - _ctx.at(recurrent_to_input_weights_index).shape().dim(0) == 0 && - _ctx.at(recurrent_to_input_weights_index).shape().dim(1) == 0 && - _ctx.at(input_gate_bias_index).shape().dim(0) == 0 && - _ctx.at(cell_to_input_weights_index).shape().dim(0) == 0) || - (_ctx.at(input_to_input_weights_index).shape().dim(0) != 0 && - _ctx.at(input_to_input_weights_index).shape().dim(1) != 0 && - _ctx.at(recurrent_to_input_weights_index).shape().dim(0) != 0 && - _ctx.at(recurrent_to_input_weights_index).shape().dim(1) != 0 && - _ctx.at(input_gate_bias_index).shape().dim(0) != 0)); - - // Peephole assertion - OP_REQUIRES((_ctx.at(cell_to_forget_weights_index).shape().dim(0) == 0 && - _ctx.at(cell_to_output_weights_index).shape().dim(0) == 0) || - (_ctx.at(cell_to_forget_weights_index).shape().dim(0) != 0 && - _ctx.at(cell_to_output_weights_index).shape().dim(0) != 0)); - - bool has_input_to_input_weights = _ctx.at(input_to_input_weights_index).shape().dim(0) != 0 && - _ctx.at(input_to_input_weights_index).shape().dim(1) != 0; - bool has_recurrent_to_input_weights = - _ctx.at(recurrent_to_input_weights_index).shape().dim(0) != 0 && - _ctx.at(recurrent_to_input_weights_index).shape().dim(1) != 0; - bool has_input_gate_bias = 
_ctx.at(input_gate_bias_index).shape().dim(0) != 0; - bool has_cell_to_input_weights = _ctx.at(cell_to_input_weights_index).shape().dim(0) != 0; - bool has_cell_to_forget_weights = _ctx.at(cell_to_forget_weights_index).shape().dim(0) != 0; - bool has_cell_to_output_weights = _ctx.at(cell_to_output_weights_index).shape().dim(0) != 0; - bool has_projection_weights = _ctx.at(projection_weights_index).shape().dim(0) != 0 && - _ctx.at(projection_weights_index).shape().dim(1) != 0; - bool has_projection_bias = _ctx.at(projection_bias_index).shape().dim(0); - - // NOTE The cell_to_input_weights do not exist in non-peephole although regular LSTM(non-CIFG). - // true: no CIFG - // false: CIFG - bool has_cifg_param = has_input_to_input_weights && has_recurrent_to_input_weights; - - // NOTE The cell_to_input_weights do not exist in regular CIFG although peephole. - // true: peephole - // false: no peephole - bool has_peephole_param = has_cell_to_forget_weights && has_cell_to_output_weights; - - // NOTE The projection weights may have data but the projection bias may not. 
- bool has_projection_param = has_projection_weights; - - const auto batch_size = _ctx.at(input_index).shape().dim(0); - OP_REQUIRES(batch_size == _ctx.at(output_state_in_index).shape().dim(0) && - batch_size == _ctx.at(cell_state_in_index).shape().dim(0) && - batch_size == _ctx.at(scratch_buffer_index).shape().dim(0) && - batch_size == _ctx.at(output_state_out_index).shape().dim(0) && - batch_size == _ctx.at(cell_state_out_index).shape().dim(0) && - batch_size == _ctx.at(output_index).shape().dim(0)); - - const auto input_size = _ctx.at(input_index).shape().dim(1); - OP_REQUIRES(input_size == _ctx.at(input_to_forget_weights_index).shape().dim(1) && - input_size == _ctx.at(input_to_cell_weights_index).shape().dim(1) && - input_size == _ctx.at(input_to_output_weights_index).shape().dim(1)); - - const auto num_units = _ctx.at(cell_state_out_index).shape().dim(1); - OP_REQUIRES(num_units == _ctx.at(input_to_forget_weights_index).shape().dim(0) && - num_units == _ctx.at(input_to_cell_weights_index).shape().dim(0) && - num_units == _ctx.at(input_to_output_weights_index).shape().dim(0) && - num_units == _ctx.at(recurrent_to_forget_weights_index).shape().dim(0) && - num_units == _ctx.at(recurrent_to_cell_weights_index).shape().dim(0) && - num_units == _ctx.at(recurrent_to_output_weights_index).shape().dim(0) && - num_units == _ctx.at(forget_gate_bias_index).shape().dim(0) && - num_units == _ctx.at(cell_bias_index).shape().dim(0) && - num_units == _ctx.at(output_gate_bias_index).shape().dim(0) && - num_units == _ctx.at(cell_state_in_index).shape().dim(1) && - (((num_units * 3) == _ctx.at(scratch_buffer_index).shape().dim(1)) || - ((num_units * 4) == _ctx.at(scratch_buffer_index).shape().dim(1)))); - - const auto output_size = _ctx.at(output_index).shape().dim(1); - OP_REQUIRES(output_size == _ctx.at(recurrent_to_forget_weights_index).shape().dim(1) && - output_size == _ctx.at(recurrent_to_cell_weights_index).shape().dim(1) && - output_size == 
_ctx.at(recurrent_to_output_weights_index).shape().dim(1) && - output_size == _ctx.at(output_state_in_index).shape().dim(1) && - output_size == _ctx.at(output_state_out_index).shape().dim(1)); - - if (has_cifg_param) - { - OP_REQUIRES(input_size == _ctx.at(input_to_input_weights_index).shape().dim(1)); - OP_REQUIRES(num_units == _ctx.at(input_to_input_weights_index).shape().dim(0) && - num_units == _ctx.at(recurrent_to_input_weights_index).shape().dim(0) && - (num_units == _ctx.at(cell_to_input_weights_index).shape().dim(0) || - _ctx.at(cell_to_input_weights_index).shape().dim(0) == 0 /* non-peephole */) && - num_units == _ctx.at(input_gate_bias_index).shape().dim(0)); - OP_REQUIRES(output_size == _ctx.at(recurrent_to_input_weights_index).shape().dim(1)); - OP_REQUIRES(has_input_to_input_weights && has_recurrent_to_input_weights && - has_input_gate_bias); - if (has_cell_to_input_weights) - { - // NOTE The cell_to_input_weights exist only in case of non-CIFG and peephole. - OP_REQUIRES(has_peephole_param); - } - OP_REQUIRES(_ctx.at(scratch_buffer_index).shape().dim(1) == num_units * 4); - } - else - { - OP_REQUIRES(_ctx.at(scratch_buffer_index).shape().dim(1) == num_units * 3); - } - - if (has_peephole_param) - { - OP_REQUIRES(num_units == _ctx.at(cell_to_forget_weights_index).shape().dim(0) && - num_units == _ctx.at(cell_to_output_weights_index).shape().dim(0) && - (num_units == _ctx.at(cell_to_input_weights_index).shape().dim(0) || - _ctx.at(cell_to_input_weights_index).shape().dim(0) == 0 /* CIFG */)); - } - - if (has_projection_param) - { - OP_REQUIRES(num_units == _ctx.at(projection_weights_index).shape().dim(1)); - OP_REQUIRES(output_size == _ctx.at(projection_weights_index).shape().dim(0)); - if (has_projection_bias) - { - OP_REQUIRES(output_size == _ctx.at(projection_bias_index).shape().dim(0)); - } - } -} - -void OperationValidator::visit(const ir::operation::L2Normalization &node) -{ - const auto ofm_index{node.getOutputs().at(0)}; - if 
(_ctx.at(ofm_index).info().isDynamic()) - return; - - const auto ifm_index{node.getInputs().at(ir::operation::L2Normalization::Input::INPUT)}; - - auto ifm_shape = _ctx.at(ifm_index).shape(); - auto ofm_shape = _ctx.at(ofm_index).shape(); - - OP_REQUIRES(ifm_shape.rank() == ofm_shape.rank()); - - for (auto i = 0; i < ifm_shape.rank(); i++) - { - OP_REQUIRES(ifm_shape.dim(i) == ofm_shape.dim(i)); - } -} - -void OperationValidator::visit(const ir::operation::Unpack &node) -{ - const auto num{node.param().num}; - OP_REQUIRES(num == static_cast<int32_t>(node.getOutputs().size())); - const auto axis{node.param().axis}; - - const auto output_index{node.getInputs().at(0)}; - if (_ctx.at(output_index).info().isDynamic()) - return; - - const auto input_index{node.getInputs().at(ir::operation::Unpack::Input::INPUT)}; - - const auto &input_shape = _ctx.at(input_index).shape(); - const auto input_rank = static_cast<int32_t>(input_shape.rank()); - - OP_REQUIRES(axis >= -input_rank && axis < input_rank); -} - -void OperationValidator::visit(const ir::operation::Pad &node) -{ - const auto pad_index{node.getInputs().at(ir::operation::Pad::Input::PAD)}; - OP_REQUIRES(_ctx.at(pad_index).typeInfo().type() == ir::DataType::INT32); - - const auto output_index{node.getInputs().at(0)}; - if (_ctx.at(output_index).info().isDynamic()) - return; - - const auto input_index{node.getInputs().at(ir::operation::Pad::Input::INPUT)}; - - const auto &pad_shape = _ctx.at(pad_index).shape(); - const auto input_rank = static_cast<int32_t>(_ctx.at(input_index).shape().rank()); - - OP_REQUIRES(pad_shape.rank() == 2); - OP_REQUIRES(pad_shape.dim(0) == input_rank); - OP_REQUIRES(pad_shape.dim(1) == 2); - OP_REQUIRES(_ctx.at(input_index).shape().rank() == _ctx.at(output_index).shape().rank()); -} - -void OperationValidator::visit(const ir::operation::Select &node) -{ - const auto output_index{node.getOutputs().at(0)}; - // This validator does not check shape. So checking isDynamic() is skipped. 
- - const auto condition_index{node.getInputs().at(ir::operation::Select::Input::CONDITION)}; - const auto input_true_index{node.getInputs().at(ir::operation::Select::Input::INPUT_TRUE)}; - const auto input_false_index{node.getInputs().at(ir::operation::Select::Input::INPUT_FALSE)}; - UNUSED_RELEASE(output_index); - UNUSED_RELEASE(input_true_index); - UNUSED_RELEASE(input_false_index); - - OP_REQUIRES(_ctx.at(condition_index).typeInfo().type() == ir::DataType::BOOL8); -} - -void OperationValidator::visit(const ir::operation::StridedSlice &node) -{ - const auto output_index{node.getOutputs().at(0)}; - const auto input_index{node.getInputs().at(ir::operation::StridedSlice::Input::INPUT)}; - const auto starts_index{node.getInputs().at(ir::operation::StridedSlice::Input::STARTS)}; - const auto ends_index{node.getInputs().at(ir::operation::StridedSlice::Input::ENDS)}; - const auto strides_index{node.getInputs().at(ir::operation::StridedSlice::Input::STRIDES)}; - - UNUSED_RELEASE(starts_index); - UNUSED_RELEASE(ends_index); - UNUSED_RELEASE(strides_index); - - OP_REQUIRES(_ctx.at(output_index).typeInfo().type() == _ctx.at(input_index).typeInfo().type()); - - if (_ctx.at(output_index).info().isDynamic()) - return; - - OP_REQUIRES(_ctx.at(input_index).shape().rank() <= 4); -} - -void OperationValidator::visit(const ir::operation::Split &node) -{ - const auto input_index{node.getInputs().at(ir::operation::Split::Input::INPUT)}; - - if (_ctx.at(input_index).info().isDynamic()) - return; - - const auto num_splits = node.param().num_splits; - const auto input_rank = _ctx.at(input_index).shape().rank(); - const auto axis = node.param().axis < 0 ? 
node.param().axis + input_rank : node.param().axis; - - OP_REQUIRES(num_splits > 0 && num_splits <= 0xFFFF); - OP_REQUIRES(axis >= 0 && axis < input_rank); - OP_REQUIRES(node.getOutputs().size() == static_cast<uint32_t>(num_splits)); - - OP_REQUIRES(_ctx.at(input_index).shape().dim(axis) % num_splits == 0); -} - -void OperationValidator::visit(const ir::operation::Shape &node) -{ - const auto output_index{node.getOutputs().at(0)}; - if (_ctx.at(output_index).info().isDynamic()) - return; - - const auto input_index{node.getInputs().at(0)}; - UNUSED_RELEASE(input_index); - OP_REQUIRES(_ctx.at(output_index).shape().rank() == 1); -} - -void OperationValidator::visit(const ir::operation::ResizeBilinear &node) -{ - const auto output_index{node.getOutputs().at(0)}; - const auto input_index{node.getInputs().at(ir::operation::ResizeBilinear::Input::INPUT)}; - - if (_ctx.at(output_index).info().isDynamic()) - { - return; - } - OP_REQUIRES(_ctx.at(input_index).shape().rank() == 4); - OP_REQUIRES(_ctx.at(output_index).shape().rank() == 4); - - auto align_corners = node.param().align_corners; - auto half_pixel_centers = node.param().half_pixel_centers; - - OP_REQUIRES(!align_corners || !half_pixel_centers); -} - -void OperationValidator::visit(const ir::operation::Reverse &node) -{ - const auto output_index{node.getOutputs().at(0)}; - const auto input_index{node.getInputs().at(ir::operation::Reverse::Input::INPUT)}; - const auto axis_index{node.getInputs().at(ir::operation::Reverse::Input::AXIS)}; - - OP_REQUIRES(_ctx.at(axis_index).typeInfo().type() == ir::DataType::INT32); - OP_REQUIRES(_ctx.at(output_index).typeInfo().type() == _ctx.at(input_index).typeInfo().type()); - - if (_ctx.at(output_index).info().isDynamic()) - return; - OP_REQUIRES(_ctx.at(output_index).shape() == _ctx.at(input_index).shape()); -} - -void OperationValidator::visit(const ir::operation::If &) -{ - // TODO Add to validate with subgraphs -} - -void OperationValidator::visit(const ir::operation::While 
&node) -{ - // This validator does not check shape. So checking isDynamic() is skipped. - - OP_REQUIRES(node.getInputs().size() == node.getOutputs().size()); - // TODO Add to validate with subgraphs -} - -void OperationValidator::visit(const ir::operation::SquaredDifference &node) -{ - const auto output_index{node.getOutputs().at(0)}; - const auto lhs_index{node.getInputs().at(ir::operation::SquaredDifference::Input::LHS)}; - const auto rhs_index{node.getInputs().at(ir::operation::SquaredDifference::Input::RHS)}; - - // Check for Type equivalence - OP_REQUIRES(_ctx.at(output_index).typeInfo().type() == _ctx.at(lhs_index).typeInfo().type()); - OP_REQUIRES(_ctx.at(lhs_index).typeInfo().type() == _ctx.at(rhs_index).typeInfo().type()); - - // Check for dimension constraints - if (_ctx.at(output_index).info().isDynamic()) - return; - - auto output_shape = _ctx.at(output_index).shape(); - auto lhs_shape = _ctx.at(lhs_index).shape(); - auto rhs_shape = _ctx.at(rhs_index).shape(); - // Check for output rank - OP_REQUIRES(output_shape.rank() == std::max(lhs_shape.rank(), rhs_shape.rank())); - auto min_rank = std::min(lhs_shape.rank(), rhs_shape.rank()); - - for (int idx = 1; idx <= min_rank; idx++) - { - int l_idx = lhs_shape.rank() - idx; - int r_idx = rhs_shape.rank() - idx; - int out_idx = output_shape.rank() - idx; - - OP_REQUIRES((l_idx >= 0) && (r_idx >= 0) && (out_idx >= 0)); - - auto l_dims = lhs_shape.dim(l_idx); - auto r_dims = rhs_shape.dim(r_idx); - auto out_dims = output_shape.dim(out_idx); - - OP_REQUIRES(((l_dims == r_dims) && (out_dims == l_dims)) || - ((l_dims == 1) && (out_dims == r_dims)) || ((r_dims == 1) && (out_dims == l_dims))); - } - auto &tmp_shape = (lhs_shape.rank() > rhs_shape.rank()) ? 
lhs_shape : rhs_shape; - for (int idx = min_rank + 1; idx <= output_shape.rank(); idx++) - { - int out_idx = output_shape.rank() - idx; - int tmp_idx = tmp_shape.rank() - idx; - - OP_REQUIRES((out_idx >= 0) && (tmp_idx >= 0) && - (output_shape.dim(out_idx) == tmp_shape.dim(tmp_idx))); - } -} -void OperationValidator::visit(const ir::operation::Tile &node) -{ - const auto output_index{node.getOutputs().at(0)}; - if (_ctx.at(output_index).info().isDynamic()) - return; - - const auto input_index{node.getInputs().at(0)}; - const auto multiple_index{node.getInputs().at(1)}; - - OP_REQUIRES(_ctx.at(multiple_index).shape().rank() == 1); - OP_REQUIRES(_ctx.at(multiple_index).shape().dim(0) == _ctx.at(input_index).shape().rank()); - OP_REQUIRES(_ctx.at(input_index).shape().rank() == _ctx.at(output_index).shape().rank()); -} - -void OperationValidator::visit(const ir::operation::Range &node) -{ - const auto output_index{node.getOutputs().at(0)}; - const auto start_index{node.getInputs().at(ir::operation::Range::Input::START)}; - const auto limit_index{node.getInputs().at(ir::operation::Range::Input::LIMIT)}; - const auto delta_index{node.getInputs().at(ir::operation::Range::Input::DELTA)}; - - // Check for dimension constraints - if (_ctx.at(output_index).info().isDynamic()) - return; - - OP_REQUIRES(_ctx.at(start_index).shape().rank() == 0); - OP_REQUIRES(_ctx.at(limit_index).shape().rank() == 0); - OP_REQUIRES(_ctx.at(delta_index).shape().rank() == 0); -} - -void OperationValidator::visit(const ir::operation::MatrixBandPart &node) -{ - const auto output_index{node.getOutputs().at(0)}; - const auto input_index{node.getInputs().at(ir::operation::MatrixBandPart::Input::INPUT)}; - const auto num_lower_index{ - node.getInputs().at(ir::operation::MatrixBandPart::Input::NUM_LOWER_DIAG)}; - const auto num_upper_index{ - node.getInputs().at(ir::operation::MatrixBandPart::Input::NUM_UPPER_DIAG)}; - - // Check for dimension constraints - if (_ctx.at(output_index).info().isDynamic()) 
- return; - - OP_REQUIRES(_ctx.at(input_index).shape().rank() >= 2); // input must be more than 2 dim matrix - OP_REQUIRES(_ctx.at(num_upper_index).shape().rank() == 0); // num_lower must be scalar - OP_REQUIRES(_ctx.at(num_lower_index).shape().rank() == 0); // num_upper must be scalar -} - -void OperationValidator::visit(const ir::operation::LogSoftmax &node) -{ - VERBOSE(LogSoftmax) << "Configure LOGSOFTMAX operation" << std::endl; - - const auto output_index{node.getOutputs().at(0)}; - if (_ctx.at(output_index).info().isDynamic()) - return; - - const auto input_index{node.getInputs().at(0)}; - - OP_REQUIRES(_ctx.at(output_index).shape().rank() == _ctx.at(input_index).shape().rank()); -} - -} // namespace compiler -} // namespace onert diff --git a/runtime/onert/core/src/compiler/ParamChecker.h b/runtime/onert/core/src/compiler/ParamChecker.h deleted file mode 100644 index 61429d521..000000000 --- a/runtime/onert/core/src/compiler/ParamChecker.h +++ /dev/null @@ -1,73 +0,0 @@ -/* - * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -/** - * @file ParamChecker.h - * @brief This file contains ParamChecker to check\n - * operations' parameters are compilable at machine independent phase\n - * ex) Check param is constant - */ -#ifndef __ONERT_COMPILER_PARAM_CHECKER_H__ -#define __ONERT_COMPILER_PARAM_CHECKER_H__ - -#include "ir/OperationVisitor.h" - -namespace onert -{ -namespace ir -{ -class Graph; -} // namespace ir -} // namespace onert - -namespace onert -{ -namespace compiler -{ - -class ParamChecker : public ir::OperationVisitor -{ -public: - /** - * @brief Construct a new Param Checker object (deleted) - */ - ParamChecker(void) = delete; - /** - * @brief Construct a new Param Checker object - * @param[in] model Graph model to check - */ - ParamChecker(std::shared_ptr<ir::Graph> model) : _model{model} {} - -public: - /** - * @brief Run parameter analysis - */ - void operator()(); - /** - * @brief Return analysis result if model have non-const parameter - * @return @c true if there is non-const parameter, otherwise @c false - */ - bool haveNoneConstParam(void) { return _nonConstParam; } - -private: - const std::shared_ptr<ir::Graph> _model; - bool _nonConstParam{false}; -}; - -} // namespace compiler -} // namespace onert - -#endif // __ONERT_COMPILER_OPERATION_VALIDATOR_H__ diff --git a/runtime/onert/core/src/compiler/PermuteFactor.cc b/runtime/onert/core/src/compiler/PermuteFactor.cc new file mode 100644 index 000000000..f0081a2a4 --- /dev/null +++ b/runtime/onert/core/src/compiler/PermuteFactor.cc @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "compiler/PermuteFactor.h" + +#include <assert.h> +#include <ostream> + +#include "backend/Backend.h" + +std::ostream &operator<<(std::ostream &os, const onert::compiler::PermuteFactor &obj) +{ + assert(obj.backend() && obj.backend()->config()); + return os << "(" << obj.backend()->config()->id() << "/" << to_string(obj.layout()) << ")"; +} diff --git a/runtime/onert/core/src/compiler/ShapeValidator.cc b/runtime/onert/core/src/compiler/ShapeValidator.cc new file mode 100644 index 000000000..0cd14c186 --- /dev/null +++ b/runtime/onert/core/src/compiler/ShapeValidator.cc @@ -0,0 +1,1132 @@ +/* + * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "ShapeValidator.h" + +#include <typeinfo> + +#include "ir/Graph.h" +#include "util/logging.h" +#include "util/Utils.h" + +#define OP_REQUIRES(EXP) \ + do \ + { \ + if (!(EXP)) \ + throw std::runtime_error("ShapeValidator failed at line " + std::to_string(__LINE__)); \ + } while (0) + +namespace onert +{ +namespace compiler +{ + +ShapeValidator::ShapeValidator(const ir::Graph &graph) : _graph{graph} {} + +void ShapeValidator::checkUnaryOp(const ir::Operation &node) +{ + const auto &operands = _graph.operands(); + const auto output_index{node.getOutputs().at(0)}; + const auto input_index{node.getInputs().at(0)}; + + if (operands.at(output_index).info().isDynamic()) + return; + + // Check if I/O shapes match + OP_REQUIRES(operands.at(output_index).shape() == operands.at(input_index).shape()); +} + +void ShapeValidator::operator()() +{ + _graph.operations().iterate( + [&](const ir::OperationIndex &, const ir::IOperation &node) { node.accept(*this); }); +} + +void ShapeValidator::visit(const ir::operation::BatchMatMul &node) +{ + const auto &operands = _graph.operands(); + const auto lhs_index(node.getInputs().at(ir::operation::BatchMatMul::Input::LHS)); + const auto rhs_index(node.getInputs().at(ir::operation::BatchMatMul::Input::RHS)); + const auto out_index{node.getOutputs().at(0)}; + + if (operands.at(out_index).info().isDynamic()) + return; + + OP_REQUIRES(operands.at(lhs_index).shape().rank() <= 4); + OP_REQUIRES(operands.at(rhs_index).shape().rank() <= 4); + OP_REQUIRES(operands.at(lhs_index).shape().rank() >= 2); + OP_REQUIRES(operands.at(rhs_index).shape().rank() >= 2); +} + +void ShapeValidator::visit(const ir::operation::BatchToSpaceND &node) +{ + const auto &operands = _graph.operands(); + const auto ofm_index{node.getOutputs().at(0)}; + if (operands.at(ofm_index).info().isDynamic()) + return; + + const auto ifm_index{node.getInputs().at(ir::operation::BatchToSpaceND::Input::INPUT)}; + const auto block_size_index{ + 
node.getInputs().at(ir::operation::BatchToSpaceND::Input::BLOCK_SIZE)}; + + const auto frontend_layout = _graph.layout(); + const auto input_shape = operands.at(ifm_index).shape().asFeature(frontend_layout); + const auto output_shape = operands.at(ofm_index).shape().asFeature(frontend_layout); + + // All requirement as per NNAPI specification. + OP_REQUIRES(operands.at(ifm_index).shape().rank() == 4); + OP_REQUIRES(operands.at(ofm_index).shape().rank() == 4); + OP_REQUIRES(operands.at(block_size_index).shape().rank() == 1); + + OP_REQUIRES(operands.at(block_size_index).shape().dim(0) == 2); + + if (node.getInputs().size() != 2) + { + const auto crops_index{node.getInputs().at(ir::operation::BatchToSpaceND::Input::CROPS_DATA)}; + OP_REQUIRES(operands.at(crops_index).shape().rank() == 2); + OP_REQUIRES(operands.at(crops_index).shape().dim(0) == + (operands.at(ifm_index).shape().rank() - 2)); + OP_REQUIRES(operands.at(crops_index).shape().dim(1) == 2); + } + + OP_REQUIRES(input_shape.C == output_shape.C); +} + +void ShapeValidator::visit(const ir::operation::BCQFullyConnected &node) +{ + const auto &operands = _graph.operands(); + const auto ofm_index{node.getOutputs().at(0)}; + if (operands.at(ofm_index).info().isDynamic()) + return; + + const auto ifm_index{node.getInputs().at(ir::operation::BCQFullyConnected::Input::INPUT)}; + const auto weight_scales_index{ + node.getInputs().at(ir::operation::BCQFullyConnected::Input::WEIGHTS_SCALES)}; + const auto weight_binary_index{ + node.getInputs().at(ir::operation::BCQFullyConnected::Input::WEIGHTS_BINARY)}; + const auto weight_cluster_index{ + node.getInputs().at(ir::operation::BCQFullyConnected::Input::WEIGHTS_CLUSTERS)}; + const auto bias_index{node.getInputs().at(ir::operation::BCQFullyConnected::Input::BIAS)}; + + OP_REQUIRES(operands.at(ifm_index).shape().rank() == 2); + OP_REQUIRES(operands.at(ofm_index).shape().rank() == 2); + OP_REQUIRES(operands.at(weight_scales_index).shape().rank() == 1); + 
OP_REQUIRES(operands.at(weight_binary_index).shape().rank() == 2); + OP_REQUIRES(operands.at(weight_cluster_index).shape().rank() == 2); + + OP_REQUIRES(operands.at(ifm_index).shape().dim(1) == operands.at(ofm_index).shape().dim(1)); + + OP_REQUIRES(operands.at(weight_cluster_index).shape().dim(0) > 0); + OP_REQUIRES(operands.at(weight_cluster_index).shape().dim(1) == 2); + + // more shape validation will be done inside kernel. + + OP_REQUIRES(!operands.exist(bias_index) || operands.at(bias_index).shape().rank() == 1); +} + +void ShapeValidator::visit(const ir::operation::BCQGather &node) +{ + const auto &operands = _graph.operands(); + const auto ofm_index{node.getOutputs().at(0)}; + if (operands.at(ofm_index).info().isDynamic()) + return; + + const auto indices_index{node.getInputs().at(ir::operation::BCQGather::Input::INDICES)}; + const auto input_binary_index{node.getInputs().at(ir::operation::BCQGather::Input::INPUT_BINARY)}; + const auto input_scales_index{node.getInputs().at(ir::operation::BCQGather::Input::INPUT_SCALES)}; + const auto input_clusters_index{ + node.getInputs().at(ir::operation::BCQGather::Input::INPUT_CLUSTERS)}; + + OP_REQUIRES(operands.at(indices_index).shape().rank() <= + 2); // TODO : support rank up to 4 or more + OP_REQUIRES(operands.at(input_binary_index).shape().rank() == 2); + OP_REQUIRES(operands.at(input_scales_index).shape().rank() == 1); + OP_REQUIRES(operands.at(input_clusters_index).shape().rank() == 2); + + OP_REQUIRES(operands.at(input_clusters_index).shape().dim(0) > 0); + OP_REQUIRES(operands.at(input_clusters_index).shape().dim(1) == 2); + + // more shape validation will be done inside kernel. 
+} + +void ShapeValidator::visit(const ir::operation::Conv2D &node) +{ + const auto &operands = _graph.operands(); + const auto ofm_index{node.getOutputs().at(0)}; + if (operands.at(ofm_index).info().isDynamic()) + return; + + const auto ifm_index{node.getInputs().at(ir::operation::Conv2D::Input::INPUT)}; + const auto ker_index{node.getInputs().at(ir::operation::Conv2D::Input::KERNEL)}; + const auto bias_index{node.getInputs().at(ir::operation::Conv2D::Input::BIAS)}; + + OP_REQUIRES(operands.at(ifm_index).shape().rank() == 4); + OP_REQUIRES(operands.at(ker_index).shape().rank() == 4); + OP_REQUIRES(!operands.exist(bias_index) || operands.at(bias_index).shape().rank() == 1); + OP_REQUIRES(operands.at(ofm_index).shape().rank() == 4); +} + +void ShapeValidator::visit(const ir::operation::Comparison &) +{ + // TODO Shape validation of comparison +} + +void ShapeValidator::visit(const ir::operation::DepthwiseConv2D &node) +{ + const auto &operands = _graph.operands(); + const auto ofm_index{node.getOutputs().at(0)}; + if (operands.at(ofm_index).info().isDynamic()) + return; + + const auto ifm_index{node.getInputs().at(ir::operation::DepthwiseConv2D::Input::INPUT)}; + const auto ker_index{node.getInputs().at(ir::operation::DepthwiseConv2D::Input::KERNEL)}; + const auto bias_index{node.getInputs().at(ir::operation::DepthwiseConv2D::Input::BIAS)}; + + OP_REQUIRES(operands.at(ifm_index).shape().rank() == 4); + OP_REQUIRES(operands.at(ker_index).shape().rank() == 4); + OP_REQUIRES(!operands.exist(bias_index) || operands.at(bias_index).shape().rank() == 1); + OP_REQUIRES(operands.at(ofm_index).shape().rank() == 4); +} + +void ShapeValidator::visit(const ir::operation::FullyConnected &node) +{ + const auto &operands = _graph.operands(); + const auto ofm_index{node.getOutputs().at(0)}; + if (operands.at(ofm_index).info().isDynamic()) + return; + + const auto ifm_index{node.getInputs().at(ir::operation::FullyConnected::Input::INPUT)}; + const auto 
ker_index{node.getInputs().at(ir::operation::FullyConnected::Input::WEIGHT)}; + const auto bias_index{node.getInputs().at(ir::operation::FullyConnected::Input::BIAS)}; + + OP_REQUIRES(operands.at(ifm_index).shape().rank() >= 2); + OP_REQUIRES(operands.at(ker_index).shape().rank() == 2); + OP_REQUIRES(!operands.exist(bias_index) || operands.at(bias_index).shape().rank() == 1); +} + +void ShapeValidator::visit(const ir::operation::Softmax &node) +{ + const auto &operands = _graph.operands(); + const auto output_index{node.getOutputs().at(0)}; + if (operands.at(output_index).info().isDynamic()) + return; + + const auto input_index{node.getInputs().at(0)}; + + OP_REQUIRES(operands.at(output_index).shape().rank() == operands.at(input_index).shape().rank()); +} + +void ShapeValidator::visit(const ir::operation::InstanceNorm &node) +{ + const auto &operands = _graph.operands(); + const auto ofm_index{node.getOutputs().at(0)}; + if (operands.at(ofm_index).info().isDynamic()) + return; + + const auto ifm_index{node.getInputs().at(ir::operation::InstanceNorm::Input::INPUT)}; + const auto gamma_index{node.getInputs().at(ir::operation::InstanceNorm::Input::GAMMA)}; + const auto beta_index{node.getInputs().at(ir::operation::InstanceNorm::Input::BETA)}; + + OP_REQUIRES(operands.at(ifm_index).shape().rank() == 4); + OP_REQUIRES(operands.at(ifm_index).shape() == operands.at(ofm_index).shape()); + OP_REQUIRES(operands.at(gamma_index).shape().rank() == 1); + OP_REQUIRES(operands.at(beta_index).shape().rank() == 1); +} + +void ShapeValidator::visit(const ir::operation::Pool2D &node) +{ + const auto &operands = _graph.operands(); + const auto ofm_index{node.getOutputs().at(0)}; + if (operands.at(ofm_index).info().isDynamic()) + return; + + const auto ifm_index{node.getInputs().at(ir::operation::Pool2D::Input::INPUT)}; + + OP_REQUIRES(operands.at(ifm_index).shape().rank() == 4); +} + +void ShapeValidator::visit(const ir::operation::Permute &node) +{ + const auto &operands = 
_graph.operands(); + const auto output_index{node.getOutputs().at(0)}; + if (operands.at(output_index).info().isDynamic()) + return; + + const auto input_index{node.getInputs().at(0)}; + + OP_REQUIRES(operands.at(output_index).shape().rank() == operands.at(input_index).shape().rank()); +} + +void ShapeValidator::visit(const ir::operation::Reduce &node) +{ + const auto &operands = _graph.operands(); + const auto output_index{node.getOutputs().at(0)}; + if (operands.at(output_index).info().isDynamic()) + return; + + const auto &input_index{node.getInputs().at(ir::operation::Reduce::Input::INPUT)}; + const auto &input_shape = operands.at(input_index).shape(); + const auto &output_shape = operands.at(output_index).shape(); + + OP_REQUIRES(input_shape.rank() <= 4); + OP_REQUIRES(output_shape.rank() <= input_shape.rank()); + + // NOTE For the 4-dimensions, if the rank of input and output are different, this runtime only + // supports cases reducing height and width or reducing depth. + // TODO We have to support all cases of dimensions up to 4. + // For correct permuting, we have to set output's shape to be equal in dimension position of the + // input. But the positions of the same dimensions in the input and output may be set differently. + // For example {2,3,4,5}(input's shape) can be reduced to {3,5}(output's shape). The original + // output shape should be {1,3,1,5}, but real output shape may be {3,5}. If you simply try to + // extend it in 4 dimensions, it should be {1,1,3,5}. + // Even if output shape is changed to {1,3,1,5}, there is another problem. It is that shape of + // output tensor used at next operation is changed to {1,3,1,5} after this operation even if the + // next operation is not desired. 
+ if (input_shape.rank() == 4 && input_shape.rank() != output_shape.rank()) + { + if (output_shape.rank() == 2) + { + // Reducing HW + OP_REQUIRES(input_shape.dim(0) == output_shape.dim(0) && + input_shape.dim(3) == output_shape.dim(1)); + } + else if (output_shape.rank() == 3) + { + // Reducing C or + // (Reducing H and C(input and output) == 1) or (Reducing W and C(input and output) == 1) + OP_REQUIRES( + (input_shape.dim(0) == output_shape.dim(0) && input_shape.dim(1) == output_shape.dim(1) && + input_shape.dim(2) == output_shape.dim(2)) || + (input_shape.dim(0) == output_shape.dim(0) && + (input_shape.dim(1) == output_shape.dim(1) || input_shape.dim(2) == output_shape.dim(1)) && + input_shape.dim(3) == 1 && output_shape.dim(2) == 1)); + } + } +} + +void ShapeValidator::visit(const ir::operation::Transpose &node) +{ + const auto &operands = _graph.operands(); + const auto output_index{node.getOutputs().at(0)}; + if (operands.at(output_index).info().isDynamic()) + return; + + const auto input_index{node.getInputs().at(ir::operation::Transpose::Input::INPUT)}; + const auto perm_index{node.getInputs().at(ir::operation::Transpose::Input::PERMUTATION)}; + + const auto &output_shape = operands.at(output_index).shape(); + const auto &input_shape = operands.at(input_index).shape(); + + OP_REQUIRES(operands.at(perm_index).shape().num_elements() == 0 || + input_shape.rank() == + static_cast<int>(operands.at(perm_index).shape().num_elements())); + OP_REQUIRES(input_shape.rank() == output_shape.rank()); +} + +void ShapeValidator::visit(const ir::operation::RNN &node) +{ + // NOTE This validation is for static rnn(non-dynamic shape), but not for dynamic rnn + // TODO Support dynamic rnn + const auto &operands = _graph.operands(); + const auto output_index{node.getOutputs().at(ir::operation::RNN::Output::OUTPUT)}; + if (operands.at(output_index).info().isDynamic()) + return; + + const auto hidden_state_out_index{ + 
node.getOutputs().at(ir::operation::RNN::Output::HIDDEN_STATE_OUT)}; + + const auto input_index{node.getInputs().at(ir::operation::RNN::Input::INPUT)}; + const auto weights_index{node.getInputs().at(ir::operation::RNN::Input::WEIGHTS)}; + const auto recurrent_weights_index{ + node.getInputs().at(ir::operation::RNN::Input::RECURRENT_WEIGHTS)}; + const auto bias_index{node.getInputs().at(ir::operation::RNN::Input::BIAS)}; + const auto hidden_state_in_index{node.getInputs().at(ir::operation::RNN::Input::HIDDEN_STATE_IN)}; + + const auto batch_size = operands.at(output_index).shape().dim(0); + const auto num_units = operands.at(output_index).shape().dim(1); + + OP_REQUIRES(operands.at(output_index).shape().rank() == 2 && + operands.at(hidden_state_out_index).shape().rank() == 2 && + operands.at(input_index).shape().rank() == 2 && + operands.at(weights_index).shape().rank() == 2 && + operands.at(recurrent_weights_index).shape().rank() == 2 && + operands.at(hidden_state_in_index).shape().rank() == 2); + OP_REQUIRES(operands.at(bias_index).shape().rank() == 1); + + OP_REQUIRES(batch_size == operands.at(input_index).shape().dim(0) && + batch_size == operands.at(hidden_state_in_index).shape().dim(0) && + batch_size == operands.at(hidden_state_out_index).shape().dim(0)); + OP_REQUIRES(operands.at(input_index).shape().dim(1) == operands.at(weights_index).shape().dim(1)); + + OP_REQUIRES(num_units == operands.at(weights_index).shape().dim(0) && + num_units == operands.at(recurrent_weights_index).shape().dim(0) && + num_units == operands.at(bias_index).shape().dim(0)); + OP_REQUIRES(num_units == operands.at(output_index).shape().dim(1) && + num_units == operands.at(recurrent_weights_index).shape().dim(1) && + num_units == operands.at(hidden_state_in_index).shape().dim(1) && + num_units == operands.at(hidden_state_out_index).shape().dim(1)); +} + +void ShapeValidator::visit(const ir::operation::SpaceToBatchND &node) +{ + const auto &operands = _graph.operands(); + const auto 
ofm_index{node.getOutputs().at(0)}; + if (operands.at(ofm_index).info().isDynamic()) + return; + + const auto ifm_index{node.getInputs().at(ir::operation::SpaceToBatchND::Input::INPUT)}; + const auto block_size_index{ + node.getInputs().at(ir::operation::SpaceToBatchND::Input::BLOCK_SIZE)}; + const auto paddings_index{node.getInputs().at(ir::operation::SpaceToBatchND::Input::PADDINGS)}; + + const auto frontend_layout = _graph.layout(); + const auto input_shape = operands.at(ifm_index).shape().asFeature(frontend_layout); + const auto output_shape = operands.at(ofm_index).shape().asFeature(frontend_layout); + + // All requirement as per NNAPI specification. + OP_REQUIRES(operands.at(ifm_index).shape().rank() == 4); + OP_REQUIRES(operands.at(ofm_index).shape().rank() == 4); + OP_REQUIRES(operands.at(block_size_index).shape().rank() == 1); + OP_REQUIRES(operands.at(paddings_index).shape().rank() == 2); + + OP_REQUIRES(operands.at(block_size_index).shape().dim(0) == 2); + OP_REQUIRES(operands.at(paddings_index).shape().dim(0) == 2); + OP_REQUIRES(operands.at(paddings_index).shape().dim(1) == 2); + + OP_REQUIRES(input_shape.C == output_shape.C); +} + +void ShapeValidator::visit(const ir::operation::SpaceToDepth &node) +{ + const auto &operands = _graph.operands(); + const auto ofm_index{node.getOutputs().at(0)}; + if (operands.at(ofm_index).info().isDynamic()) + return; + + const auto ifm_index{node.getInputs().at(ir::operation::SpaceToDepth::Input::INPUT)}; + + const auto frontend_layout = _graph.layout(); + const auto input_shape = operands.at(ifm_index).shape().asFeature(frontend_layout); + const auto output_shape = operands.at(ofm_index).shape().asFeature(frontend_layout); + const auto block_size = node.param().block_size; + + // All assertions as per NNAPI specification. 
+ OP_REQUIRES(operands.at(ifm_index).shape().rank() == 4); + OP_REQUIRES(operands.at(ofm_index).shape().rank() == 4); + OP_REQUIRES((input_shape.H % block_size == 0) && (input_shape.W % block_size == 0)); + OP_REQUIRES(input_shape.N == output_shape.N); + OP_REQUIRES(input_shape.C * block_size * block_size == output_shape.C); +} + +void ShapeValidator::visit(const ir::operation::ElementwiseActivation &node) { checkUnaryOp(node); } + +void ShapeValidator::visit(const ir::operation::ElementwiseBinary &) +{ + // TODO Shape validation of ElementwiseBinary +} + +void ShapeValidator::visit(const ir::operation::ElementwiseUnary &node) +{ + const auto &operands = _graph.operands(); + const auto output_index{node.getOutputs().at(0)}; + const auto input_index{node.getInputs().at(ir::operation::ElementwiseUnary::Input::INPUT)}; + + if (operands.at(output_index).info().isDynamic()) + return; + + OP_REQUIRES(operands.at(output_index).shape() == operands.at(input_index).shape()); +} + +void ShapeValidator::visit(const ir::operation::EmbeddingLookup &node) +{ + const auto &operands = _graph.operands(); + const auto output_index{node.getOutputs().at(0)}; + const auto lookups_index{node.getInputs().at(ir::operation::EmbeddingLookup::Input::LOOKUPS)}; + const auto values_index{node.getInputs().at(ir::operation::EmbeddingLookup::Input::VALUES)}; + + const auto &output_obj = operands.at(output_index); + const auto &lookups_obj = operands.at(lookups_index); + const auto &values_obj = operands.at(values_index); + + // Verify operand here, not at SimpleEmbeddingLookup::configure() to avoid acl's modifying + // TensorShape sometimes(Issue: https://github.sec.samsung.net/STAR/nnfw/issues/729) + { + if (operands.at(output_index).info().isDynamic()) + return; + + const auto &output_shape = output_obj.shape(); + const auto &lookups_shape = lookups_obj.shape(); + const auto &values_shape = values_obj.shape(); + + OP_REQUIRES(lookups_shape.rank() == 1); + OP_REQUIRES(values_shape.rank() >= 2); + 
+ // output should be a n-D tensor with the same rank and shape as the values tensor, except for + // the first dimension which has the same size as lookups' only dimension. + OP_REQUIRES(output_shape.rank() == values_shape.rank()); + OP_REQUIRES(output_shape.dim(0) == lookups_shape.dim(0)); + for (int n = 1; n < output_shape.rank(); ++n) + { + OP_REQUIRES(output_shape.dim(n) == values_shape.dim(n)); + } + } +} + +void ShapeValidator::visit(const ir::operation::ExpandDims &node) +{ + const auto &operands = _graph.operands(); + const auto axis_index{node.getInputs().at(ir::operation::ExpandDims::Input::AXIS)}; + + if (operands.at(axis_index).info().isDynamic()) + return; + OP_REQUIRES(operands.at(axis_index).shape().rank() <= 1); +} + +void ShapeValidator::visit(const ir::operation::HashtableLookup &node) +{ + const auto &operands = _graph.operands(); + const auto output_index{node.getOutputs().at(ir::operation::HashtableLookup::Output::OUTPUT)}; + const auto lookups_index{node.getInputs().at(ir::operation::HashtableLookup::Input::LOOKUPS)}; + const auto keys_index{node.getInputs().at(ir::operation::HashtableLookup::Input::KEYS)}; + const auto values_index{node.getInputs().at(ir::operation::HashtableLookup::Input::VALUES)}; + + const auto &output_obj = operands.at(output_index); + const auto &lookups_obj = operands.at(lookups_index); + const auto &keys_obj = operands.at(keys_index); + const auto &values_obj = operands.at(values_index); + + if (operands.at(output_index).info().isDynamic()) + return; + + const auto &output_shape = output_obj.shape(); + const auto &lookups_shape = lookups_obj.shape(); + const auto &keys_shape = keys_obj.shape(); + const auto &values_shape = values_obj.shape(); + + OP_REQUIRES(values_shape.rank() == output_shape.rank()); + OP_REQUIRES(lookups_shape.rank() == 1); + OP_REQUIRES(keys_shape.rank() == 1); + OP_REQUIRES(values_shape.dim(0) == keys_shape.dim(0)); + OP_REQUIRES(lookups_shape.dim(0) == output_shape.dim(0)); +} + +void 
ShapeValidator::visit(const ir::operation::TransposeConv &node) +{ + // shape check + const auto &operands = _graph.operands(); + const auto ofm_index{node.getOutputs().at(0)}; + + if (operands.at(ofm_index).info().isDynamic()) + return; + + const auto ifm_index{node.getInputs().at(ir::operation::TransposeConv::Input::INPUT)}; + const auto ker_index{node.getInputs().at(ir::operation::TransposeConv::Input::KERNEL)}; + + // Only 4D tensors are supported + OP_REQUIRES(operands.at(ofm_index).shape().rank() == 4); + OP_REQUIRES(operands.at(ofm_index).shape().rank() == operands.at(ifm_index).shape().rank()); + OP_REQUIRES(operands.at(ofm_index).shape().rank() == operands.at(ker_index).shape().rank()); + + const auto frontend_layout = _graph.layout(); + const auto ofm_shape = operands.at(ofm_index).shape().asFeature(frontend_layout); + const auto ifm_shape = operands.at(ifm_index).shape().asFeature(frontend_layout); + // The kernel has only IHWO layout on frontend + // So ker_shape is treated here below + // I -> N + // H -> H + // W -> W + // O -> C + const auto ker_shape = operands.at(ker_index).shape().asFeature(ir::Layout::NHWC); + + OP_REQUIRES(ifm_shape.N == ofm_shape.N); + OP_REQUIRES(ifm_shape.C == ker_shape.C); + OP_REQUIRES(ker_shape.N == ofm_shape.C); +} + +void ShapeValidator::visit(const ir::operation::Gather &node) +{ + const auto &operands = _graph.operands(); + const auto ofm_index{node.getOutputs().at(0)}; + if (operands.at(ofm_index).info().isDynamic()) + return; + + const auto ifm_index{node.getInputs().at(ir::operation::Gather::Input::INPUT)}; + const auto indices_index{node.getInputs().at(ir::operation::Gather::Input::INDICES)}; + + const auto &ifm_shape = operands.at(ifm_index).shape(); + const auto &indices_shape = operands.at(indices_index).shape(); + const auto &ofm_shape = operands.at(ofm_index).shape(); + + OP_REQUIRES(ifm_shape.rank() <= 4); + OP_REQUIRES(indices_shape.rank() <= 3); + OP_REQUIRES(ofm_shape.rank() <= 4); +} + +void 
ShapeValidator::visit(const ir::operation::DepthToSpace &node) +{ + const auto &operands = _graph.operands(); + int32_t block_size = node.param().block_size; + + // shape check + const auto output_index{node.getOutputs().at(0)}; + if (operands.at(output_index).info().isDynamic()) + return; + + const auto input_index{node.getInputs().at(ir::operation::DepthToSpace::Input::INPUT)}; + + const auto frontend_layout = _graph.layout(); + const auto output_shape = operands.at(output_index).shape().asFeature(frontend_layout); + const auto input_shape = operands.at(input_index).shape().asFeature(frontend_layout); + + OP_REQUIRES(operands.at(input_index).shape().rank() == 4); + OP_REQUIRES(operands.at(output_index).shape().rank() == 4); + + { + OP_REQUIRES(output_shape.N == input_shape.N); + OP_REQUIRES(output_shape.H == input_shape.H * block_size); + OP_REQUIRES(output_shape.W == input_shape.W * block_size); + OP_REQUIRES(input_shape.C % (block_size * block_size) == 0); + OP_REQUIRES(output_shape.C == input_shape.C / (block_size * block_size)); + } +} + +void ShapeValidator::visit(const ir::operation::Pack &node) +{ + const auto &operands = _graph.operands(); + const auto axis{node.param().axis}; + const auto output_index{node.getOutputs().at(0)}; + if (operands.at(output_index).info().isDynamic()) + return; + + // shape check + const auto &output_shape = operands.at(output_index).shape(); + const auto output_rank = static_cast<int32_t>(output_shape.rank()); + + const auto input1_index{node.getInputs().at(0)}; + const auto &input_shape = operands.at(input1_index).shape(); + + OP_REQUIRES(axis >= -output_rank && axis < output_rank); + for (const auto &index : node.getInputs()) + { + OP_REQUIRES(input_shape == operands.at(index).shape()); + } +} + +void ShapeValidator::visit(const ir::operation::LSTM &node) +{ + // NOTE This validation is for static rnn(non-dynamic shape), but not for dynamic rnn + // TODO Support dynamic rnn + const auto &operands = _graph.operands(); + const 
auto output_index{node.getOutputs().at(ir::operation::LSTM::Output::OUTPUT)}; + if (operands.at(output_index).info().isDynamic()) + return; + + const auto scratch_buffer_index{ + node.getOutputs().at(ir::operation::LSTM::Output::SCRATCH_BUFFER)}; // Optional + const auto output_state_out_index{ + node.getOutputs().at(ir::operation::LSTM::Output::OUTPUT_STATE_OUT)}; // Optional + const auto cell_state_out_index{ + node.getOutputs().at(ir::operation::LSTM::Output::CELL_STATE_OUT)}; // Optional + + const auto input_index{node.getInputs().at(ir::operation::LSTM::Input::INPUT)}; + const auto input_to_input_weights_index{ + node.getInputs().at(ir::operation::LSTM::Input::INPUT_TO_INPUT_WEIGHTS)}; // Optional + const auto input_to_forget_weights_index{ + node.getInputs().at(ir::operation::LSTM::Input::INPUT_TO_FORGET_WEIGHTS)}; + const auto input_to_cell_weights_index{ + node.getInputs().at(ir::operation::LSTM::Input::INPUT_TO_CELL_WEIGHTS)}; + const auto input_to_output_weights_index{ + node.getInputs().at(ir::operation::LSTM::Input::INPUT_TO_OUTPUT_WEIGHTS)}; + const auto recurrent_to_input_weights_index{ + node.getInputs().at(ir::operation::LSTM::Input::RECURRENT_TO_INPUT_WEIGHTS)}; // Optional + const auto recurrent_to_forget_weights_index{ + node.getInputs().at(ir::operation::LSTM::Input::RECURRENT_TO_FORGET_WEIGHTS)}; + const auto recurrent_to_cell_weights_index{ + node.getInputs().at(ir::operation::LSTM::Input::RECURRENT_TO_CELL_WEIGHTS)}; + const auto recurrent_to_output_weights_index{ + node.getInputs().at(ir::operation::LSTM::Input::RECURRENT_TO_OUTPUT_WEIGHTS)}; + const auto cell_to_input_weights_index{ + node.getInputs().at(ir::operation::LSTM::Input::CELL_TO_INPUT_WEIGHTS)}; // Optional + const auto cell_to_forget_weights_index{ + node.getInputs().at(ir::operation::LSTM::Input::CELL_TO_FORGET_WEIGHTS)}; // Optional + const auto cell_to_output_weights_index{ + node.getInputs().at(ir::operation::LSTM::Input::CELL_TO_OUTPUT_WEIGHTS)}; // Optional + const auto 
input_gate_bias_index{ + node.getInputs().at(ir::operation::LSTM::Input::INPUT_GATE_BIAS)}; // Optional + const auto forget_gate_bias_index{ + node.getInputs().at(ir::operation::LSTM::Input::FORGET_GATE_BIAS)}; + const auto cell_bias_index{node.getInputs().at(ir::operation::LSTM::Input::CELL_BIAS)}; + const auto output_gate_bias_index{ + node.getInputs().at(ir::operation::LSTM::Input::OUTPUT_GATE_BIAS)}; + const auto projection_weights_index{ + node.getInputs().at(ir::operation::LSTM::Input::PROJECTION_WEIGHTS)}; // Optional + const auto projection_bias_index{ + node.getInputs().at(ir::operation::LSTM::Input::PROJECTION_BIAS)}; // Optional + const auto output_state_in_index{ + node.getInputs().at(ir::operation::LSTM::Input::OUTPUT_STATE_IN)}; + const auto cell_state_in_index{node.getInputs().at(ir::operation::LSTM::Input::CELL_STATE_IN)}; + + OP_REQUIRES(operands.at(input_index).shape().rank() == operands.at(output_index).shape().rank()); + for (int i = 0; i < operands.at(input_index).shape().rank() - 1; ++i) + { + OP_REQUIRES(operands.at(input_index).shape().dim(i) == + operands.at(output_index).shape().dim(i)); + } + OP_REQUIRES((operands.at(output_index).shape().rank() == 2 || + operands.at(output_index).shape().rank() == 3) && + (operands.at(input_index).shape().rank() == 2 || + operands.at(input_index).shape().rank() == 3) && + (!operands.exist(input_to_input_weights_index) || + operands.at(input_to_input_weights_index).shape().rank() == 2) && + operands.at(input_to_forget_weights_index).shape().rank() == 2 && + operands.at(input_to_cell_weights_index).shape().rank() == 2 && + operands.at(input_to_output_weights_index).shape().rank() == 2 && + (!operands.exist(recurrent_to_input_weights_index) || + operands.at(recurrent_to_input_weights_index).shape().rank() == 2) && + operands.at(recurrent_to_forget_weights_index).shape().rank() == 2 && + operands.at(recurrent_to_cell_weights_index).shape().rank() == 2 && + 
operands.at(recurrent_to_output_weights_index).shape().rank() == 2 && + (!operands.exist(projection_weights_index) || + operands.at(projection_weights_index).shape().rank() == 2) && + operands.at(output_state_in_index).shape().rank() == 2 && + operands.at(cell_state_in_index).shape().rank() == 2); + + OP_REQUIRES((!operands.exist(cell_to_input_weights_index) || + operands.at(cell_to_input_weights_index).shape().rank() == 1) && + (!operands.exist(cell_to_forget_weights_index) || + operands.at(cell_to_forget_weights_index).shape().rank() == 1) && + (!operands.exist(cell_to_output_weights_index) || + operands.at(cell_to_output_weights_index).shape().rank() == 1) && + (!operands.exist(input_gate_bias_index) || + operands.at(input_gate_bias_index).shape().rank() == 1) && + operands.at(forget_gate_bias_index).shape().rank() == 1 && + operands.at(cell_bias_index).shape().rank() == 1 && + operands.at(output_gate_bias_index).shape().rank() == 1 && + (!operands.exist(projection_bias_index) || + operands.at(projection_bias_index).shape().rank() == 1)); + + // CIFG assertion + OP_REQUIRES(((!operands.exist(input_to_input_weights_index) || + (operands.at(input_to_input_weights_index).shape().dim(0) == 0 && + operands.at(input_to_input_weights_index).shape().dim(1) == 0)) && + (!operands.exist(recurrent_to_input_weights_index) || + (operands.at(recurrent_to_input_weights_index).shape().dim(0) == 0 && + operands.at(recurrent_to_input_weights_index).shape().dim(1) == 0)) && + (!operands.exist(input_gate_bias_index) || + operands.at(input_gate_bias_index).shape().dim(0) == 0) && + (!operands.exist(cell_to_input_weights_index) || + operands.at(cell_to_input_weights_index).shape().dim(0) == 0)) || + ((operands.exist(input_to_input_weights_index) && + (operands.at(input_to_input_weights_index).shape().dim(0) != 0 && + operands.at(input_to_input_weights_index).shape().dim(1) != 0)) && + (operands.exist(recurrent_to_input_weights_index) && + 
(operands.at(recurrent_to_input_weights_index).shape().dim(0) != 0 && + operands.at(recurrent_to_input_weights_index).shape().dim(1) != 0)) && + (operands.exist(input_gate_bias_index) && + operands.at(input_gate_bias_index).shape().dim(0) != 0))); + + // Peephole assertion + OP_REQUIRES(((!operands.exist(cell_to_forget_weights_index) || + operands.at(cell_to_forget_weights_index).shape().dim(0) == 0) && + (!operands.exist(cell_to_output_weights_index) || + operands.at(cell_to_output_weights_index).shape().dim(0) == 0)) || + ((operands.exist(cell_to_forget_weights_index) && + operands.at(cell_to_forget_weights_index).shape().dim(0) != 0) && + (operands.exist(cell_to_output_weights_index) && + operands.at(cell_to_output_weights_index).shape().dim(0) != 0))); + + bool has_input_to_input_weights = + operands.exist(input_to_input_weights_index) && + (operands.at(input_to_input_weights_index).shape().dim(0) != 0 && + operands.at(input_to_input_weights_index).shape().dim(1) != 0); + bool has_recurrent_to_input_weights = + operands.exist(recurrent_to_input_weights_index) && + (operands.at(recurrent_to_input_weights_index).shape().dim(0) != 0 && + operands.at(recurrent_to_input_weights_index).shape().dim(1) != 0); + bool has_input_gate_bias = + operands.exist(input_gate_bias_index) && operands.at(input_gate_bias_index).shape().dim(0) != 0; + bool has_cell_to_input_weights = operands.exist(cell_to_input_weights_index) && + operands.at(cell_to_input_weights_index).shape().dim(0) != 0; + bool has_cell_to_forget_weights = operands.exist(cell_to_forget_weights_index) && + operands.at(cell_to_forget_weights_index).shape().dim(0) != 0; + bool has_cell_to_output_weights = operands.exist(cell_to_output_weights_index) && + operands.at(cell_to_output_weights_index).shape().dim(0) != 0; + bool has_projection_weights = operands.exist(projection_weights_index) && + (operands.at(projection_weights_index).shape().dim(0) != 0 && + operands.at(projection_weights_index).shape().dim(1) != 0); 
+ bool has_projection_bias = + operands.exist(projection_bias_index) && operands.at(projection_bias_index).shape().dim(0) != 0; + + // NOTE The cell_to_input_weights do not exist in non-peephole although regular LSTM(non-CIFG). + // true: no CIFG + // false: CIFG + bool has_cifg_param = has_input_to_input_weights && has_recurrent_to_input_weights; + + // NOTE The cell_to_input_weights do not exist in regular CIFG although peephole. + // true: peephole + // false: no peephole + bool has_peephole_param = has_cell_to_forget_weights && has_cell_to_output_weights; + + // NOTE The projection weights may have data but the projection bias may not. + bool has_projection_param = has_projection_weights; + + const auto batch_size = (operands.at(input_index).shape().rank() == 3 && node.param().time_major) + ? operands.at(input_index).shape().dim(1) + : operands.at(input_index).shape().dim(0); + OP_REQUIRES(batch_size == operands.at(output_state_in_index).shape().dim(0) && + batch_size == operands.at(cell_state_in_index).shape().dim(0)); + + const auto input_size = + operands.at(input_index).shape().dim(operands.at(input_index).shape().rank() - 1); + OP_REQUIRES(input_size == operands.at(input_to_forget_weights_index).shape().dim(1) && + input_size == operands.at(input_to_cell_weights_index).shape().dim(1) && + input_size == operands.at(input_to_output_weights_index).shape().dim(1)); + + const auto num_units = operands.at(input_to_output_weights_index).shape().dim(0); + OP_REQUIRES(num_units == operands.at(input_to_cell_weights_index).shape().dim(0) && + num_units == operands.at(input_to_output_weights_index).shape().dim(0) && + num_units == operands.at(recurrent_to_forget_weights_index).shape().dim(0) && + num_units == operands.at(recurrent_to_cell_weights_index).shape().dim(0) && + num_units == operands.at(recurrent_to_output_weights_index).shape().dim(0) && + num_units == operands.at(forget_gate_bias_index).shape().dim(0) && + num_units == 
operands.at(cell_bias_index).shape().dim(0) && + num_units == operands.at(output_gate_bias_index).shape().dim(0) && + num_units == operands.at(cell_state_in_index).shape().dim(1)); + + const auto output_size = + operands.at(output_index).shape().dim(operands.at(output_index).shape().rank() - 1); + OP_REQUIRES(output_size == operands.at(recurrent_to_forget_weights_index).shape().dim(1) && + output_size == operands.at(recurrent_to_cell_weights_index).shape().dim(1) && + output_size == operands.at(recurrent_to_output_weights_index).shape().dim(1) && + output_size == operands.at(output_state_in_index).shape().dim(1)); + + if (has_cifg_param) + { + OP_REQUIRES(input_size == operands.at(input_to_input_weights_index).shape().dim(1)); + OP_REQUIRES( + num_units == operands.at(input_to_input_weights_index).shape().dim(0) && + num_units == operands.at(recurrent_to_input_weights_index).shape().dim(0) && + ((operands.exist(cell_to_input_weights_index) && + num_units == operands.at(cell_to_input_weights_index).shape().dim(0)) || + (!operands.exist(cell_to_input_weights_index) || + operands.at(cell_to_input_weights_index).shape().dim(0) == 0) /* non-peephole */) && + num_units == operands.at(input_gate_bias_index).shape().dim(0)); + OP_REQUIRES(output_size == operands.at(recurrent_to_input_weights_index).shape().dim(1)); + OP_REQUIRES(has_input_to_input_weights && has_recurrent_to_input_weights && + has_input_gate_bias); + if (has_cell_to_input_weights) + { + // NOTE The cell_to_input_weights exist only in case of non-CIFG and peephole. 
+ OP_REQUIRES(has_peephole_param); + } + if (operands.exist(scratch_buffer_index)) + OP_REQUIRES(operands.at(scratch_buffer_index).shape().dim(1) == num_units * 4); + } + else + { + if (operands.exist(scratch_buffer_index)) + OP_REQUIRES(operands.at(scratch_buffer_index).shape().dim(1) == num_units * 3); + } + + if (has_peephole_param) + { + OP_REQUIRES(num_units == operands.at(cell_to_forget_weights_index).shape().dim(0) && + num_units == operands.at(cell_to_output_weights_index).shape().dim(0) && + (num_units == operands.at(cell_to_input_weights_index).shape().dim(0) || + operands.at(cell_to_input_weights_index).shape().dim(0) == 0 /* CIFG */)); + } + + if (has_projection_param) + { + OP_REQUIRES(num_units == operands.at(projection_weights_index).shape().dim(1)); + OP_REQUIRES(output_size == operands.at(projection_weights_index).shape().dim(0)); + if (has_projection_bias) + { + OP_REQUIRES(output_size == operands.at(projection_bias_index).shape().dim(0)); + } + } + + if (operands.exist(scratch_buffer_index)) + { + OP_REQUIRES(operands.at(scratch_buffer_index).shape().rank() == 2); + OP_REQUIRES(batch_size == operands.at(scratch_buffer_index).shape().dim(0)); + } + + if (operands.exist(output_state_out_index)) + { + OP_REQUIRES(operands.at(output_state_out_index).shape().rank() == 2); + OP_REQUIRES(batch_size == operands.at(output_state_out_index).shape().dim(0)); + OP_REQUIRES(output_size == operands.at(output_state_out_index).shape().dim(1)); + } + + if (operands.exist(cell_state_out_index)) + { + OP_REQUIRES(operands.at(cell_state_out_index).shape().rank() == 2); + OP_REQUIRES(batch_size == operands.at(cell_state_out_index).shape().dim(0)); + OP_REQUIRES(num_units == operands.at(cell_state_out_index).shape().dim(1)); + } +} + +void ShapeValidator::visit(const ir::operation::L2Normalization &node) +{ + const auto &operands = _graph.operands(); + const auto ofm_index{node.getOutputs().at(0)}; + if (operands.at(ofm_index).info().isDynamic()) + return; + + const 
auto ifm_index{node.getInputs().at(ir::operation::L2Normalization::Input::INPUT)}; + + auto ifm_shape = operands.at(ifm_index).shape(); + auto ofm_shape = operands.at(ofm_index).shape(); + + OP_REQUIRES(ifm_shape.rank() == ofm_shape.rank()); + + for (auto i = 0; i < ifm_shape.rank(); i++) + { + OP_REQUIRES(ifm_shape.dim(i) == ofm_shape.dim(i)); + } +} + +void ShapeValidator::visit(const ir::operation::Unpack &node) +{ + const auto &operands = _graph.operands(); + const auto axis{node.param().axis}; + const auto output_index{node.getInputs().at(0)}; + if (operands.at(output_index).info().isDynamic()) + return; + + const auto input_index{node.getInputs().at(ir::operation::Unpack::Input::INPUT)}; + + const auto &input_shape = operands.at(input_index).shape(); + const auto input_rank = static_cast<int32_t>(input_shape.rank()); + + OP_REQUIRES(axis >= -input_rank && axis < input_rank); +} + +void ShapeValidator::visit(const ir::operation::Pad &node) +{ + const auto &operands = _graph.operands(); + const auto pad_index{node.getInputs().at(ir::operation::Pad::Input::PAD)}; + OP_REQUIRES(operands.at(pad_index).typeInfo().type() == ir::DataType::INT32); + + const auto output_index{node.getInputs().at(0)}; + if (operands.at(output_index).info().isDynamic()) + return; + + const auto input_index{node.getInputs().at(ir::operation::Pad::Input::INPUT)}; + + const auto &pad_shape = operands.at(pad_index).shape(); + const auto input_rank = static_cast<int32_t>(operands.at(input_index).shape().rank()); + + OP_REQUIRES(pad_shape.rank() == 2); + OP_REQUIRES(pad_shape.dim(0) == input_rank); + OP_REQUIRES(pad_shape.dim(1) == 2); + OP_REQUIRES(operands.at(input_index).shape().rank() == operands.at(output_index).shape().rank()); +} + +void ShapeValidator::visit(const ir::operation::Select &) +{ + // TODO Shape validation of select +} + +void ShapeValidator::visit(const ir::operation::StridedSlice &node) +{ + const auto &operands = _graph.operands(); + const auto 
output_index{node.getOutputs().at(0)}; + const auto input_index{node.getInputs().at(ir::operation::StridedSlice::Input::INPUT)}; + + if (operands.at(output_index).info().isDynamic()) + return; + + OP_REQUIRES(operands.at(input_index).shape().rank() <= 4); +} + +void ShapeValidator::visit(const ir::operation::Split &node) +{ + const auto &operands = _graph.operands(); + const auto output_index{node.getOutputs().at(0)}; + if (operands.at(output_index).info().isDynamic()) + return; + + const auto input_index{node.getInputs().at(ir::operation::Split::Input::INPUT)}; + const auto axis_index{node.getInputs().at(ir::operation::Split::Input::AXIS)}; + + const auto num_splits = node.param().num_splits; + const auto input_rank = operands.at(input_index).shape().rank(); + auto axis = *reinterpret_cast<const int32_t *>(operands.at(axis_index).data()->base()); + axis = axis < 0 ? axis + input_rank : axis; + + OP_REQUIRES(axis >= 0 && axis < input_rank); + OP_REQUIRES(operands.at(input_index).shape().dim(axis) % num_splits == 0); +} + +void ShapeValidator::visit(const ir::operation::Shape &node) +{ + const auto &operands = _graph.operands(); + const auto output_index{node.getOutputs().at(0)}; + if (operands.at(output_index).info().isDynamic()) + return; + + const auto input_index{node.getInputs().at(0)}; + UNUSED_RELEASE(input_index); + OP_REQUIRES(operands.at(output_index).shape().rank() == 1); +} + +void ShapeValidator::visit(const ir::operation::ResizeBilinear &node) +{ + const auto &operands = _graph.operands(); + const auto output_index{node.getOutputs().at(0)}; + const auto input_index{node.getInputs().at(ir::operation::ResizeBilinear::Input::INPUT)}; + + if (operands.at(output_index).info().isDynamic()) + { + return; + } + OP_REQUIRES(operands.at(input_index).shape().rank() == 4); + OP_REQUIRES(operands.at(output_index).shape().rank() == 4); +} + +void ShapeValidator::visit(const ir::operation::Reverse &node) +{ + const auto &operands = _graph.operands(); + const auto 
output_index{node.getOutputs().at(0)}; + const auto input_index{node.getInputs().at(ir::operation::Reverse::Input::INPUT)}; + + if (operands.at(output_index).info().isDynamic()) + return; + OP_REQUIRES(operands.at(output_index).shape() == operands.at(input_index).shape()); +} + +void ShapeValidator::visit(const ir::operation::If &) +{ + // TODO Add to validate with subgraphs +} + +void ShapeValidator::visit(const ir::operation::While &) +{ + // This validator does not check shape. So checking isDynamic() is skipped. + // TODO Add to validate with subgraphs +} + +void ShapeValidator::visit(const ir::operation::SquaredDifference &node) +{ + const auto &operands = _graph.operands(); + const auto output_index{node.getOutputs().at(0)}; + const auto lhs_index{node.getInputs().at(ir::operation::SquaredDifference::Input::LHS)}; + const auto rhs_index{node.getInputs().at(ir::operation::SquaredDifference::Input::RHS)}; + + // Check for dimension constraints + if (operands.at(output_index).info().isDynamic()) + return; + + auto output_shape = operands.at(output_index).shape(); + auto lhs_shape = operands.at(lhs_index).shape(); + auto rhs_shape = operands.at(rhs_index).shape(); + // Check for output rank + OP_REQUIRES(output_shape.rank() == std::max(lhs_shape.rank(), rhs_shape.rank())); + auto min_rank = std::min(lhs_shape.rank(), rhs_shape.rank()); + + for (int idx = 1; idx <= min_rank; idx++) + { + int l_idx = lhs_shape.rank() - idx; + int r_idx = rhs_shape.rank() - idx; + int out_idx = output_shape.rank() - idx; + + OP_REQUIRES((l_idx >= 0) && (r_idx >= 0) && (out_idx >= 0)); + + auto l_dims = lhs_shape.dim(l_idx); + auto r_dims = rhs_shape.dim(r_idx); + auto out_dims = output_shape.dim(out_idx); + + OP_REQUIRES(((l_dims == r_dims) && (out_dims == l_dims)) || + ((l_dims == 1) && (out_dims == r_dims)) || ((r_dims == 1) && (out_dims == l_dims))); + } + auto &tmp_shape = (lhs_shape.rank() > rhs_shape.rank()) ? 
lhs_shape : rhs_shape; + for (int idx = min_rank + 1; idx <= output_shape.rank(); idx++) + { + int out_idx = output_shape.rank() - idx; + int tmp_idx = tmp_shape.rank() - idx; + + OP_REQUIRES((out_idx >= 0) && (tmp_idx >= 0) && + (output_shape.dim(out_idx) == tmp_shape.dim(tmp_idx))); + } +} +void ShapeValidator::visit(const ir::operation::Tile &node) +{ + const auto &operands = _graph.operands(); + const auto output_index{node.getOutputs().at(0)}; + if (operands.at(output_index).info().isDynamic()) + return; + + const auto input_index{node.getInputs().at(0)}; + const auto multiple_index{node.getInputs().at(1)}; + + OP_REQUIRES(operands.at(multiple_index).shape().rank() == 1); + OP_REQUIRES(operands.at(multiple_index).shape().dim(0) == + operands.at(input_index).shape().rank()); + OP_REQUIRES(operands.at(input_index).shape().rank() == operands.at(output_index).shape().rank()); +} + +void ShapeValidator::visit(const ir::operation::Range &node) +{ + const auto &operands = _graph.operands(); + const auto output_index{node.getOutputs().at(0)}; + const auto start_index{node.getInputs().at(ir::operation::Range::Input::START)}; + const auto limit_index{node.getInputs().at(ir::operation::Range::Input::LIMIT)}; + const auto delta_index{node.getInputs().at(ir::operation::Range::Input::DELTA)}; + + // Check for dimension constraints + if (operands.at(output_index).info().isDynamic()) + return; + + OP_REQUIRES(operands.at(start_index).shape().rank() == 0); + OP_REQUIRES(operands.at(limit_index).shape().rank() == 0); + OP_REQUIRES(operands.at(delta_index).shape().rank() == 0); +} + +void ShapeValidator::visit(const ir::operation::MatrixBandPart &node) +{ + const auto &operands = _graph.operands(); + const auto output_index{node.getOutputs().at(0)}; + const auto input_index{node.getInputs().at(ir::operation::MatrixBandPart::Input::INPUT)}; + const auto num_lower_index{ + node.getInputs().at(ir::operation::MatrixBandPart::Input::NUM_LOWER_DIAG)}; + const auto num_upper_index{ + 
node.getInputs().at(ir::operation::MatrixBandPart::Input::NUM_UPPER_DIAG)}; + + // Check for dimension constraints + if (operands.at(output_index).info().isDynamic()) + return; + + OP_REQUIRES(operands.at(input_index).shape().rank() >= 2); // input must be more than 2 dim matrix + OP_REQUIRES(operands.at(num_upper_index).shape().rank() == 0); // num_lower must be scalar + OP_REQUIRES(operands.at(num_lower_index).shape().rank() == 0); // num_upper must be scalar +} + +void ShapeValidator::visit(const ir::operation::LogSoftmax &node) +{ + const auto &operands = _graph.operands(); + const auto output_index{node.getOutputs().at(0)}; + if (operands.at(output_index).info().isDynamic()) + return; + + const auto input_index{node.getInputs().at(0)}; + + OP_REQUIRES(operands.at(output_index).shape().rank() == operands.at(input_index).shape().rank()); +} + +} // namespace compiler +} // namespace onert diff --git a/runtime/onert/core/src/compiler/OperationValidator.h b/runtime/onert/core/src/compiler/ShapeValidator.h index deb6357bb..da83a432a 100644 --- a/runtime/onert/core/src/compiler/OperationValidator.h +++ b/runtime/onert/core/src/compiler/ShapeValidator.h @@ -14,8 +14,8 @@ * limitations under the License. 
*/ -#ifndef __ONERT_COMPILER_OPERATION_VALIDATOR_H__ -#define __ONERT_COMPILER_OPERATION_VALIDATOR_H__ +#ifndef __ONERT_COMPILER_SHAPE_VALIDATOR_H__ +#define __ONERT_COMPILER_SHAPE_VALIDATOR_H__ #include "ir/Layout.h" #include "ir/OperationVisitor.h" @@ -34,19 +34,29 @@ namespace onert namespace compiler { -class OperationValidator : public ir::OperationVisitor +class ShapeValidator : public ir::OperationVisitor { public: - OperationValidator(void) = delete; - OperationValidator(const ir::Graph &graph); + ShapeValidator(void) = delete; + ShapeValidator(const ir::Graph &graph); + ShapeValidator(const ShapeValidator &) = delete; + ShapeValidator(ShapeValidator &&) = delete; + ~ShapeValidator() = default; public: + ShapeValidator &operator=(const ShapeValidator &) = delete; + ShapeValidator &operator=(ShapeValidator &&) = delete; void operator()(); public: void visit(const ir::operation::BatchMatMul &node) override; void visit(const ir::operation::BatchToSpaceND &node) override; + void visit(const ir::operation::BCQFullyConnected &node) override; + void visit(const ir::operation::BCQGather &node) override; + void visit(const ir::operation::Conv2D &node) override; void visit(const ir::operation::Comparison &node) override; + void visit(const ir::operation::DepthwiseConv2D &node) override; + void visit(const ir::operation::FullyConnected &node) override; void visit(const ir::operation::Softmax &node) override; void visit(const ir::operation::InstanceNorm &node) override; void visit(const ir::operation::Permute &node) override; @@ -88,13 +98,10 @@ private: void checkUnaryOp(const ir::Operation &node); private: - // TODO Remove _ctx field const ir::Graph &_graph; - const ir::Operands &_ctx; - ir::Layout _current_op_seq_layout; }; } // namespace compiler } // namespace onert -#endif // __ONERT_COMPILER_OPERATION_VALIDATOR_H__ +#endif // __ONERT_COMPILER_SHAPE_VALIDATOR_H__ diff --git a/runtime/onert/core/src/compiler/StaticShapeInference.cc 
b/runtime/onert/core/src/compiler/StaticShapeInference.cc deleted file mode 100644 index 4eba1ff49..000000000 --- a/runtime/onert/core/src/compiler/StaticShapeInference.cc +++ /dev/null @@ -1,1096 +0,0 @@ -/* - * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "compiler/StaticShapeInference.h" -#include "util/ShapeInference.h" -#include "util/logging.h" - -#include <sstream> - -namespace onert -{ -namespace compiler -{ - -bool StaticShapeInferer::infer(const ir::OpSequence &op_seq) -{ - bool has_dynamic_tensor = false; - - for (const auto &operation_idx : op_seq.operations()) - { - auto &op = _operations.at(operation_idx); - auto opcode = op.opcode(); - - _return_has_dynamic_tensor = false; // this is used as a return value inside operation's visit() - - // IF: need shape inference for then, else - // While: need shape inference for condition, body - if (opcode == ir::OpCode::If || opcode == ir::OpCode::While) - { - op.accept(*this); - } - else - { - _return_has_dynamic_tensor = checkDynamicInput(op); - - if (_return_has_dynamic_tensor) - { - setDynamicOutput(op); - } - else - { - op.accept(*this); - } - } - - has_dynamic_tensor = has_dynamic_tensor || _return_has_dynamic_tensor; - } - - return has_dynamic_tensor; -} - -bool StaticShapeInferer::checkDynamicInput(const ir::Operation &op) -{ - for (auto input_idx : op.getInputs() | ir::Remove::UNDEFINED | 
ir::Remove::DUPLICATED) - { - if (_operands.at(input_idx).info().isDynamic()) - { - return true; - } - } - - return false; -} - -void StaticShapeInferer::setDynamicOutput(const ir::Operation &op) -{ - for (auto output_idx : op.getOutputs()) - { - _operands.at(output_idx).info().setDynamic(); - } -} - -void StaticShapeInferer::handleBinaryArithmeticOp(const ir::Operation &op, - const ir::OperandIndex lhs_idx, - const ir::OperandIndex rhs_idx) -{ - const auto &lhs = _operands.at(lhs_idx); - const auto &rhs = _operands.at(rhs_idx); - - const auto output_idx = op.getOutputs().at(0); - ir::Operand &output = _operands.at(output_idx); - - // re-sizing output shape - ir::Shape new_shape = shape_inference::inferEltwiseShape(lhs.info().shape(), rhs.info().shape()); - output.info().shape(new_shape); -} - -void StaticShapeInferer::handleSimpleUnaryOp(const ir::Operation &op, - const ir::OperandIndex input_idx) -{ - const auto &input = _operands.at(input_idx); - - // get mutable output operand - const auto output_idx = op.getOutputs().at(0); - ir::Operand &output = _operands.at(output_idx); - - // re-sizing output shape - ir::Shape new_shape = input.info().shape(); - output.info().shape(new_shape); -} - -void StaticShapeInferer::dump() -{ - auto get_shape_str = [](const ir::Shape &shape) { - std::stringstream sstream; - sstream << "shape : {"; - for (int i = 0; i < shape.rank(); i++) - { - if (i == 0) - sstream << shape.dim(i); - else - sstream << " " << shape.dim(i); - } - sstream << "}"; - return sstream.str(); - }; - - for (const auto &pair : _lowered_subgs) - { - const auto index = pair.first; - const auto &lowered_subg = pair.second; - VERBOSE(StaticShapeInferer) << "SubGraph #" << index.value() << std::endl; - lowered_subg->graph().operands().iterate( - [&](const ir::OperandIndex &ind, const ir::Operand &operand) { - VERBOSE(StaticShapeInferer) << "Operand #" << ind.value() << ", " - << (operand.info().isDynamic() ? 
"Dynamic" : "Static") << ", " - << get_shape_str(operand.info().shape()) << std::endl; - }); - } -} - -void StaticShapeInferer::visit(const ir::operation::ArgMax &op) -{ - const auto input_idx{op.getInputs().at(ir::operation::ArgMax::Input::INPUT)}; - const auto &input = _operands.at(input_idx); - - // get mutable output operand - const auto output_idx = op.getOutputs().at(0); - ir::Operand &output = _operands.at(output_idx); - const auto rank = input.info().shape().rank(); - const auto axis = ((op.param().axis < 0) ? rank + op.param().axis : op.param().axis); - - assert(0 <= axis && axis < rank); - - // re-sizing output shape - ir::Shape new_shape = shape_inference::inferArgMaxShape(input.info().shape(), axis, rank); - output.info().shape(new_shape); -} - -void StaticShapeInferer::visit(const ir::operation::BatchMatMul &op) -{ - const auto lhs_index = op.getInputs().at(ir::operation::BatchMatMul::Input::LHS); - const auto rhs_index = op.getInputs().at(ir::operation::BatchMatMul::Input::RHS); - const auto output_index = op.getOutputs().at(0); - const auto lhs = _operands.at(lhs_index); - const auto rhs = _operands.at(rhs_index); - auto &output = _operands.at(output_index); - auto new_shape = shape_inference::inferBatchMatMulShape(lhs.shape(), rhs.shape(), op.param()); - output.info().shape(new_shape); -} - -void StaticShapeInferer::visit(const ir::operation::BinaryArithmetic &op) -{ - handleBinaryArithmeticOp(op, op.getInputs().at(ir::operation::BinaryArithmetic::Input::LHS), - op.getInputs().at(ir::operation::BinaryArithmetic::Input::RHS)); -} - -void StaticShapeInferer::visit(const ir::operation::BroadcastTo &op) -{ - // get mutable output operand - const auto output_idx = op.getOutputs().at(0); - ir::Operand &output = _operands.at(output_idx); - - const auto shape_idx{op.getInputs().at(ir::operation::BroadcastTo::Input::SHAPE)}; - const auto &shape = _operands.at(shape_idx); - - if (!shape.isConstant()) - { - output.info().setDynamic(); - 
_return_has_dynamic_tensor = true; - return; - } - - // assert(shape.typeInfo().type() == ir::DataType::INT32); - auto shape_buffer = reinterpret_cast<const int32_t *>(shape.data()->base()); - - // re-sizing output shape - ir::Shape new_shape = shape_inference::inferBroadcastToShape(shape.info().shape(), shape_buffer); - output.info().shape(new_shape); -} - -void StaticShapeInferer::visit(const ir::operation::Comparison &op) -{ - handleBinaryArithmeticOp(op, op.getInputs().at(ir::operation::Comparison::Input::INPUT0), - op.getInputs().at(ir::operation::Comparison::Input::INPUT1)); -} - -void StaticShapeInferer::visit(const ir::operation::Concat &op) -{ - const auto input_count = op.getInputs().size(); - - const auto output_idx = op.getOutputs().at(0); - ir::Operand &output = _operands.at(output_idx); - - shape_inference::Shapes input_shapes; - for (uint32_t i = 0; i < input_count; i++) - { - const auto input_idx{op.getInputs().at(i)}; - const auto &input = _operands.at(input_idx); - input_shapes.emplace_back(input.shape()); - } - - ir::Shape out_shape = shape_inference::inferConcatShape(input_shapes, op.param()); - - // re-sizing output shape - output.info().shape(out_shape); -} - -void StaticShapeInferer::visit(const ir::operation::Conv2D &op) -{ - const auto input_idx{op.getInputs().at(ir::operation::Conv2D::Input::INPUT)}; - const auto &input = _operands.at(input_idx); - const auto ker_idx{op.getInputs().at(ir::operation::Conv2D::Input::KERNEL)}; - const auto &ker = _operands.at(ker_idx); - const auto output_idx = op.getOutputs().at(0); - ir::Operand &output = _operands.at(output_idx); - - // re-sizing output shape - ir::Shape new_shape = - shape_inference::inferConv2DShape(input.info().shape(), ker.info().shape(), op.param()); - output.info().shape(new_shape); -} - -void StaticShapeInferer::visit(const ir::operation::ElementwiseActivation &op) -{ - handleSimpleUnaryOp(op, op.getInputs().at(ir::operation::ElementwiseActivation::Input::INPUT)); -} - -void 
StaticShapeInferer::visit(const ir::operation::ElementwiseBinary &op) -{ - handleBinaryArithmeticOp(op, op.getInputs().at(ir::operation::ElementwiseBinary::Input::LHS), - op.getInputs().at(ir::operation::ElementwiseBinary::Input::RHS)); -} - -void StaticShapeInferer::visit(const ir::operation::ElementwiseUnary &op) -{ - handleSimpleUnaryOp(op, op.getInputs().at(ir::operation::ElementwiseUnary::Input::INPUT)); -} - -void StaticShapeInferer::visit(const ir::operation::ExpandDims &op) -{ - const auto input_idx{op.getInputs().at(ir::operation::ExpandDims::Input::INPUT)}; - const auto &input = _operands.at(input_idx); - const auto axis_idx{op.getInputs().at(ir::operation::ExpandDims::Input::AXIS)}; - const auto &axis = _operands.at(axis_idx); - const auto output_idx = op.getOutputs().at(0); - ir::Operand &output = _operands.at(output_idx); - - if (!axis.isConstant()) - { - output.info().setDynamic(); - _return_has_dynamic_tensor = true; - return; - } - - // even when axis is constant, output shape should be recalculated since user might call - // nnfw_set_input_tensorinfo(input, some_new_shape) - auto axis_buf = reinterpret_cast<const int32_t *>(axis.data()->base()); - assert(axis_buf); - - // re-sizing output shape - ir::Shape new_shape = shape_inference::inferExpandDimsShape(input.info().shape(), axis_buf[0]); - output.info().shape(new_shape); -} - -void StaticShapeInferer::visit(const ir::operation::Fill &op) -{ - const auto input_idx{op.getInputs().at(ir::operation::Fill::Input::INPUT)}; - const auto &input = _operands.at(input_idx); - const auto output_idx = op.getOutputs().at(0); - ir::Operand &output = _operands.at(output_idx); - - if (!input.isConstant()) - { - output.info().setDynamic(); - _return_has_dynamic_tensor = true; - return; - } - - assert(input.typeInfo().type() == ir::DataType::INT32); - - auto input_buf = reinterpret_cast<const int32_t *>(input.data()->base()); - assert(input_buf); - - // re-sizing output shape - ir::Shape new_shape = 
shape_inference::inferFillShape(input.info().shape(), input_buf); - output.info().shape(new_shape); -} - -void StaticShapeInferer::visit(const ir::operation::FullyConnected &op) -{ - const auto input_idx{op.getInputs().at(ir::operation::FullyConnected::Input::INPUT)}; - const auto &input = _operands.at(input_idx); - - const auto ker_idx{op.getInputs().at(ir::operation::FullyConnected::Input::WEIGHT)}; - const auto &ker = _operands.at(ker_idx); - - // get mutable output operand - const auto output_idx = op.getOutputs().at(0); - ir::Operand &output = _operands.at(output_idx); - // re-sizing output shape - ir::Shape new_shape = - shape_inference::inferFullyConnectedShape(input.info().shape(), ker.info().shape()); - output.info().shape(new_shape); -} - -void StaticShapeInferer::visit(const ir::operation::FusedBatchNorm &op) -{ - handleSimpleUnaryOp(op, op.getInputs().at(ir::operation::FusedBatchNorm::Input::INPUT)); -} - -void StaticShapeInferer::visit(const ir::operation::Gather &op) -{ - const auto input_idx{op.getInputs().at(ir::operation::Gather::Input::INPUT)}; - const auto &input = _operands.at(input_idx); - - // get mutable output operand - const auto output_idx = op.getOutputs().at(0); - ir::Operand &output = _operands.at(output_idx); - - const auto indices_idx{op.getInputs().at(ir::operation::Gather::Input::INDICES)}; - const auto &indices = _operands.at(indices_idx); - const auto rank = input.info().shape().rank(); - const auto axis = ((op.param().axis < 0) ? 
rank + op.param().axis : op.param().axis); - - assert(0 <= axis && axis < rank); - - // re-sizing output shape - ir::Shape new_shape = - shape_inference::inferGatherShape(input.info().shape(), indices.info().shape(), axis, rank); - output.info().shape(new_shape); -} - -void StaticShapeInferer::visit(const ir::operation::If &op) -{ - auto &then_graph = _lowered_subgs.at(op.param().then_subg_index)->graph(); - auto &else_graph = _lowered_subgs.at(op.param().else_subg_index)->graph(); - const std::vector<ir::OperandIndex> inputs{op.getInputs().begin() + 1, op.getInputs().end()}; - const auto &outputs = op.getOutputs(); - - // re-sizing input shapes of then subgraph - const auto &then_inputs = then_graph.getInputs(); - assert(inputs.size() == then_inputs.size()); - for (size_t i = 0; i < inputs.size(); ++i) - { - auto &then_input = then_graph.operands().at(then_inputs.at(i)); - if (_operands.at(inputs.at(i)).info().isDynamic()) - { - then_input.info().setDynamic(); - } - else - { - auto new_shape = _operands.at(inputs.at(i)).info().shape(); - then_input.info().shape(new_shape); - } - } - - // re-sizing input shapes of else subgraph - const auto &else_inputs = else_graph.getInputs(); - assert(inputs.size() == else_inputs.size()); - for (size_t i = 0; i < inputs.size(); ++i) - { - auto &else_input = else_graph.operands().at(else_inputs.at(i)); - if (_operands.at(inputs.at(i)).info().isDynamic()) - { - else_input.info().setDynamic(); - } - else - { - const auto &new_shape = _operands.at(inputs.at(i)).info().shape(); - else_input.info().shape(new_shape); - } - } - - // re-sizing operands of then subgraph - StaticShapeInferer then_inferer(op.param().then_subg_index, _lowered_subgs); - _lowered_subgs.at(op.param().then_subg_index) - ->iterateTopolOpSeqs([&](const ir::OpSequenceIndex &, ir::OpSequence &op_seq) { - bool has_dynamic_tensor = then_inferer.infer(op_seq); - op_seq.has_dynamic_tensor(has_dynamic_tensor); - }); - - // re-sizing operands of else subgraph - 
StaticShapeInferer else_inferer(op.param().else_subg_index, _lowered_subgs); - _lowered_subgs.at(op.param().else_subg_index) - ->iterateTopolOpSeqs([&](const ir::OpSequenceIndex &, ir::OpSequence &op_seq) { - bool has_dynamic_tensor = else_inferer.infer(op_seq); - op_seq.has_dynamic_tensor(has_dynamic_tensor); - }); - - // re-sizing output shapes - const auto &then_outputs = _lowered_subgs.at(op.param().then_subg_index)->graph().getOutputs(); - const auto &else_outputs = _lowered_subgs.at(op.param().else_subg_index)->graph().getOutputs(); - assert(outputs.size() == then_outputs.size()); - assert(outputs.size() == else_outputs.size()); - for (size_t i = 0; i < outputs.size(); ++i) - { - const auto &then_output = then_graph.operands().at(then_outputs.at(i)); - const auto &else_output = else_graph.operands().at(else_outputs.at(i)); - auto &output = _operands.at(outputs.at(i)); - if (!then_output.info().isDynamic() && !else_output.info().isDynamic() && - then_output.shape() == else_output.shape()) - { - output.info().shape(then_output.shape()); - } - else - { - output.info().setDynamic(); - _return_has_dynamic_tensor = true; - } - } -} - -void StaticShapeInferer::visit(const ir::operation::L2Normalization &op) -{ - handleSimpleUnaryOp(op, op.getInputs().at(ir::operation::L2Normalization::Input::INPUT)); -} - -void StaticShapeInferer::visit(const ir::operation::MatrixBandPart &op) -{ - handleSimpleUnaryOp(op, op.getInputs().at(ir::operation::MatrixBandPart::Input::INPUT)); -} - -void StaticShapeInferer::visit(const ir::operation::OneHot &op) -{ - const auto indice_idx{op.getInputs().at(ir::operation::OneHot::Input::INDICES)}; - const auto &indice = _operands.at(indice_idx); - const auto depth_idx{op.getInputs().at(ir::operation::OneHot::Input::DEPTH)}; - const auto &depth = _operands.at(depth_idx); - - const auto axis = op.param().axis; - - auto output_idx = op.getOutputs().at(0); - ir::Operand &output = _operands.at(output_idx); - - if (!depth.isConstant()) - { - 
output.info().setDynamic(); - _return_has_dynamic_tensor = true; - return; - } - - const auto *depth_buf = reinterpret_cast<const int32_t *>(depth.data()->base()); - assert(depth_buf); - // re-sizing output shape - ir::Shape new_shape = shape_inference::inferOnehotShape(indice.info().shape(), *depth_buf, axis); - output.info().shape(new_shape); -} - -void StaticShapeInferer::visit(const ir::operation::Pack &op) -{ - const auto input_idx{op.getInputs().at(0)}; - const auto &input = _operands.at(input_idx); - - // get mutable output operand - const auto output_idx = op.getOutputs().at(0); - ir::Operand &output = _operands.at(output_idx); - - const auto rank = input.shape().rank() + 1; - const auto axis = ((op.param().axis < 0) ? rank + op.param().axis : op.param().axis); - const auto num = op.param().num; - - assert(0 <= axis && axis < rank); - - // re-sizing output shape - ir::Shape new_shape = shape_inference::inferPackShape(input.info().shape(), axis, rank, num); - output.info().shape(new_shape); -} - -void StaticShapeInferer::visit(const ir::operation::Pad &op) -{ - const auto input_idx{op.getInputs().at(ir::operation::Pad::Input::INPUT)}; - const auto &input = _operands.at(input_idx); - - const auto pad_idx{op.getInputs().at(ir::operation::Pad::Input::PAD)}; - const auto &pad = _operands.at(pad_idx); - - // get mutable output operand - const auto output_idx = op.getOutputs().at(0); - ir::Operand &output = _operands.at(output_idx); - - // if pad is not constant, output also becomes dynamic - if (!pad.isConstant()) - { - output.info().setDynamic(); - _return_has_dynamic_tensor = true; - return; - } - - // re-sizing output shape - const auto new_shape = shape_inference::inferPadShape( - input.shape(), reinterpret_cast<const int32_t *>(pad.data()->base()), - pad.shape().num_elements()); - output.info().shape(new_shape); -} - -void StaticShapeInferer::visit(const ir::operation::Permute &op) -{ - const auto input_idx{op.getInputs().at(0)}; - const auto &input = 
_operands.at(input_idx); - const auto output_idx = op.getOutputs().at(0); - ir::Operand &output = _operands.at(output_idx); - - // re-sizing output shape - // Permute is a special operation that layouts of input/output may be different on backend - // However, it is not applied here, so input/output have the same layout of frontend. Because - // "ExecutorFactory" would convert shape of input/output accoding to the layouts when registering - // operand info to "TensorBuilder" after calling "StaticShapeInferer" - const auto new_shape = input.info().shape(); - output.info().shape(new_shape); -} - -void StaticShapeInferer::visit(const ir::operation::Pow &op) -{ - handleBinaryArithmeticOp(op, op.getInputs().at(ir::operation::Pow::Input::LHS), - op.getInputs().at(ir::operation::Pow::Input::RHS)); -} - -void StaticShapeInferer::visit(const ir::operation::Range &op) -{ - const auto start_idx{op.getInputs().at(ir::operation::Range::Input::START)}; - const auto limit_idx{op.getInputs().at(ir::operation::Range::Input::LIMIT)}; - const auto delta_idx{op.getInputs().at(ir::operation::Range::Input::DELTA)}; - const auto &start_op = _operands.at(start_idx); - const auto &limit_op = _operands.at(limit_idx); - const auto &delta_op = _operands.at(delta_idx); - - // get mutable output operand - const auto output_idx = op.getOutputs().at(0); - ir::Operand &output = _operands.at(output_idx); - - ir::Shape new_shape; - if (start_op.isConstant() && limit_op.isConstant() && delta_op.isConstant()) - { - assert(start_op.typeInfo().type() == limit_op.typeInfo().type() && - start_op.typeInfo().type() == delta_op.typeInfo().type()); - if (output.typeInfo().type() == ir::DataType::FLOAT32) - { - new_shape = shape_inference::inferRangeShape<float>( - start_op.asScalar<float>(), limit_op.asScalar<float>(), delta_op.asScalar<float>()); - } - else if (output.typeInfo().type() == ir::DataType::INT32) - { - new_shape = shape_inference::inferRangeShape<int32_t>( - start_op.asScalar<int32_t>(), 
limit_op.asScalar<int32_t>(), delta_op.asScalar<int32_t>()); - } - assert(output.shape() == new_shape); - } - else - { - output.info().setDynamic(); - _return_has_dynamic_tensor = true; - } -} - -void StaticShapeInferer::visit(const ir::operation::Reduce &op) -{ - const auto input_idx{op.getInputs().at(ir::operation::Reduce::Input::INPUT)}; - const auto &input = _operands.at(input_idx); - - const auto axes_idx{op.getInputs().at(ir::operation::Reduce::Input::AXES)}; - const auto &axes = _operands.at(axes_idx); - - // get mutable output operand - const auto output_idx = op.getOutputs().at(0); - ir::Operand &output = _operands.at(output_idx); - - std::vector<int32_t> axes_vec; - for (size_t i = 0; i < axes.shape().num_elements(); ++i) - { - switch (axes.typeInfo().type()) - { - case ir::DataType::INT32: - { - axes_vec.emplace_back(reinterpret_cast<const int32_t *>(axes.data()->base())[i]); - break; - } - case ir::DataType::INT64: - { - axes_vec.emplace_back(reinterpret_cast<const int64_t *>(axes.data()->base())[i]); - break; - } - default: - throw std::runtime_error("StaticShapeInferer " + op.name() + ": Not supported data type"); - break; - } - } - const auto keep_dims = op.param().keep_dims; - - // re-sizing output shape - ir::Shape new_shape = - shape_inference::inferReduceShape(input.info().shape(), axes_vec, keep_dims); - output.info().shape(new_shape); -} - -void StaticShapeInferer::visit(const ir::operation::Reshape &op) -{ - const auto input_idx{op.getInputs().at(ir::operation::Reshape::Input::INPUT)}; - const auto &input = _operands.at(input_idx); - - // get mutable output operand - const auto output_idx = op.getOutputs().at(0); - ir::Operand &output = _operands.at(output_idx); - - // New shape is given by second input tensor - if (op.getInputs().size() == 2) - { - // Let's check the second input - const auto shape_idx{op.getInputs().at(ir::operation::Reshape::Input::SHAPE)}; - const auto &shape = _operands.at(shape_idx); - - if (shape.isConstant()) - { - 
const auto *shape_buf = reinterpret_cast<const int32_t *>(shape.data()->base()); - assert(shape_buf); - - ir::Shape new_shape = shape_inference::inferReshapeShape( - shape_buf, shape.shape().num_elements(), input.shape().num_elements()); - - // if shape is from Const, TFLC put the shape of output into tensor - if (new_shape != output.shape()) - { - // change on output shape - output.info().shape(new_shape); - } - } - else - { - // if shape is NOT Const, set output shape to be dynamic_ - output.info().setDynamic(); - _return_has_dynamic_tensor = true; - } - } - // New shape is given by option - else if (op.param().new_shape.size() != 0) - { - // Let's check the new_shape option - auto shape = op.param().new_shape; - ir::Shape new_shape = shape_inference::inferReshapeShape(shape.data(), shape.size(), - input.shape().num_elements()); - - if (new_shape != output.shape()) - { - // change on output shape - output.info().shape(new_shape); - } - } - else - { - throw std::runtime_error("Reshape: new shape is missing"); - } -} - -void StaticShapeInferer::visit(const ir::operation::ResizeBilinear &op) -{ - const auto input_idx{op.getInputs().at(ir::operation::ResizeBilinear::Input::INPUT)}; - const auto &input = _operands.at(input_idx); - - // get mutable output operand - const auto output_idx = op.getOutputs().at(0); - ir::Operand &output = _operands.at(output_idx); - - // Shape inferencing logic based on Params - ir::Shape new_shape = shape_inference::inferResizeBilinearShape( - input.shape(), op.param().height_out, op.param().width_out); - - // if size_op is from Const, TFLC put the shape of output into tensor - if (new_shape != output.shape()) - { - // change on output shape - output.info().shape(new_shape); - } -} - -void StaticShapeInferer::visit(const ir::operation::Reverse &op) -{ - handleSimpleUnaryOp(op, op.getInputs().at(ir::operation::Reverse::Input::INPUT)); -} - -void StaticShapeInferer::visit(const ir::operation::Select &op) -{ - const auto 
input_cond_idx{op.getInputs().at(ir::operation::Select::Input::CONDITION)}; - const auto &input_cond = _operands.at(input_cond_idx); - - const auto input_true_idx{op.getInputs().at(ir::operation::Select::Input::INPUT_TRUE)}; - const auto &input_true = _operands.at(input_true_idx); - - const auto input_false_idx{op.getInputs().at(ir::operation::Select::Input::INPUT_FALSE)}; - const auto &input_false = _operands.at(input_false_idx); - - auto output_idx = op.getOutputs().at(0); - ir::Operand &output = _operands.at(output_idx); - - // Select output shpae - ir::Shape new_shape = shape_inference::inferSelectShape( - input_cond.info().shape(), input_true.info().shape(), input_false.info().shape()); - output.info().shape(new_shape); -} - -void StaticShapeInferer::visit(const ir::operation::Shape &op) -{ - const auto input_idx{op.getInputs().at(0)}; - const auto &input = _operands.at(input_idx); - - // get mutable output operand - const auto output_idx = op.getOutputs().at(0); - ir::Operand &output = _operands.at(output_idx); - - // re-sizing output shape - ir::Shape output_shape; - output_shape.append(input.info().shape().rank()); - - output.info().shape(output_shape); -} - -void StaticShapeInferer::visit(const ir::operation::Slice &op) -{ - const auto input_index{op.getInputs().at(ir::operation::Slice::Input::INPUT)}; - const auto &input = _operands.at(input_index); - const auto begins_index{op.getInputs().at(ir::operation::Slice::Input::BEGINS)}; - const auto &begins = _operands.at(begins_index); - const auto sizes_index{op.getInputs().at(ir::operation::Slice::Input::SIZES)}; - const auto &sizes = _operands.at(sizes_index); - const auto output_index = op.getOutputs().at(0); - ir::Operand &output = _operands.at(output_index); - - // Whether input is constant or not does not affect whether output is dynamic or not - if (!(begins.isConstant() && sizes.isConstant())) - { - output.info().setDynamic(); - _return_has_dynamic_tensor = true; - return; - } - - auto begins_buf = 
reinterpret_cast<const int32_t *>(begins.data()->base()); - auto sizes_buf = reinterpret_cast<const int32_t *>(sizes.data()->base()); - - ir::Shape new_shape = - shape_inference::inferSliceShape(input.info().shape(), begins_buf, sizes_buf); - output.info().shape(new_shape); -} - -void StaticShapeInferer::visit(const ir::operation::Softmax &op) -{ - handleSimpleUnaryOp(op, op.getInputs().at(ir::operation::Softmax::Input::INPUT)); -} - -void StaticShapeInferer::visit(const ir::operation::SpaceToBatchND &op) -{ - const auto output_index = op.getOutputs().at(0); - const auto input_idx{op.getInputs().at(ir::operation::SpaceToBatchND::Input::INPUT)}; - const auto block_shape_idx{op.getInputs().at(ir::operation::SpaceToBatchND::Input::BLOCK_SIZE)}; - const auto padding_idx{op.getInputs().at(ir::operation::SpaceToBatchND::Input::PADDINGS)}; - - ir::Operand &output = _operands.at(output_index); - const auto &input = _operands.at(input_idx); - const auto &block_shape = _operands.at(block_shape_idx); - const auto &padding = _operands.at(padding_idx); - - // Whether input is constant or not does not affect whether output is dynamic or not - if (!(block_shape.isConstant() && padding.isConstant())) - { - output.info().setDynamic(); - _return_has_dynamic_tensor = true; - return; - } - - auto input_shape = input.info().shape(); - auto block_shape_shape = block_shape.info().shape(); - auto padding_shape = padding.info().shape(); - - auto block_shape_data = reinterpret_cast<const int32_t *>(block_shape.data()->base()); - auto padding_data = reinterpret_cast<const int32_t *>(padding.data()->base()); - - ir::Shape new_shape = shape_inference::inferSpaceToBatchNDShape( - input_shape, block_shape_shape, padding_shape, block_shape_data, padding_data); - - output.info().shape(new_shape); -} - -void StaticShapeInferer::visit(const ir::operation::Split &op) -{ - const auto input_idx{op.getInputs().at(0)}; - const auto &input = _operands.at(input_idx); - - const auto axis = op.param().axis; 
- const auto num_splits = op.param().num_splits; - - const auto rank = input.info().shape().rank(); - auto axis_resolved = axis < 0 ? axis + rank : axis; - - assert(0 <= axis_resolved && axis_resolved < rank); - - ir::Shape new_shape = - shape_inference::inferSplitShape(input.info().shape(), axis_resolved, num_splits); - auto output_tensors = op.getOutputs(); - for (auto output_idx : output_tensors) - { - ir::Operand &output = _operands.at(output_idx); - output.info().shape(new_shape); - } -} - -void StaticShapeInferer::visit(const ir::operation::SquaredDifference &op) -{ - handleBinaryArithmeticOp(op, op.getInputs().at(ir::operation::SquaredDifference::Input::LHS), - op.getInputs().at(ir::operation::SquaredDifference::Input::RHS)); -} - -void StaticShapeInferer::visit(const ir::operation::Squeeze &op) -{ - const auto input_idx{op.getInputs().at(ir::operation::Squeeze::Input::INPUT)}; - const auto &input = _operands.at(input_idx); - - const auto output_idx = op.getOutputs().at(0); - ir::Operand &output = _operands.at(output_idx); - - if (input.info().isDynamic()) - { - output.info().setDynamic(); - _return_has_dynamic_tensor = true; - return; - } - - // Squeeze output shpae - ir::Shape new_shape = shape_inference::inferSqueezeShape(input.info().shape(), op.param()); - output.info().shape(new_shape); -} - -void StaticShapeInferer::visit(const ir::operation::StridedSlice &op) -{ - const auto input_index{op.getInputs().at(ir::operation::StridedSlice::Input::INPUT)}; - const auto &input = _operands.at(input_index); - const auto starts_index{op.getInputs().at(ir::operation::StridedSlice::Input::STARTS)}; - const auto &starts = _operands.at(starts_index); - const auto ends_index{op.getInputs().at(ir::operation::StridedSlice::Input::ENDS)}; - const auto &ends = _operands.at(ends_index); - const auto strides_index{op.getInputs().at(ir::operation::StridedSlice::Input::STRIDES)}; - const auto &strides = _operands.at(strides_index); - const auto output_index = 
op.getOutputs().at(0); - ir::Operand &output = _operands.at(output_index); - - if (!(starts.isConstant() && ends.isConstant() && strides.isConstant())) - { - output.info().setDynamic(); - _return_has_dynamic_tensor = true; - return; - } - - const auto begin_mask = op.param().begin_mask; - const auto end_mask = op.param().end_mask; - const auto shrink_axis_mask = op.param().shrink_axis_mask; - const auto rank = input.info().shape().rank(); - - auto starts_buf = reinterpret_cast<const uint32_t *>(starts.data()->base()); - auto ends_buf = reinterpret_cast<const uint32_t *>(ends.data()->base()); - auto strides_buf = reinterpret_cast<const uint32_t *>(strides.data()->base()); - - auto op_params = shape_inference::buildStridedSliceParams( - starts_buf, ends_buf, strides_buf, begin_mask, end_mask, shrink_axis_mask, rank); - - ir::Shape new_shape = - shape_inference::inferStridedSliceShape(input.info().shape(), op_params, rank); - output.info().shape(new_shape); -} - -void StaticShapeInferer::visit(const ir::operation::Tile &op) -{ - const auto input_idx{op.getInputs().at(ir::operation::Tile::Input::INPUT)}; - const auto &input = _operands.at(input_idx); - - const auto multiplier_idx{op.getInputs().at(ir::operation::Tile::Input::MULTIPLES)}; - const auto &multiplier = _operands.at(multiplier_idx); - - const auto output_idx = op.getOutputs().at(0); - ir::Operand &output = _operands.at(output_idx); - - if (!multiplier.isConstant()) - { - output.info().setDynamic(); - _return_has_dynamic_tensor = true; - return; - } - - auto multiplier_buffer = reinterpret_cast<const int32_t *>(multiplier.data()->base()); - assert(multiplier_buffer); - - // re-sizing output shape - auto new_shape = shape_inference::inferTileShape(input.info().shape(), multiplier_buffer); - output.info().shape(new_shape); -} - -void StaticShapeInferer::visit(const ir::operation::Transpose &op) -{ - const auto input_idx{op.getInputs().at(ir::operation::Transpose::Input::INPUT)}; - const auto &input = 
_operands.at(input_idx); - - // get mutable output operand - const auto output_idx = op.getOutputs().at(0); - ir::Operand &output = _operands.at(output_idx); - const auto perm{op.param().perm}; - // const auto rank{op.param().rank}; - - // set output shape, based on input and params - ir::Shape new_shape = shape_inference::inferTransposeShape(input.info().shape(), perm); - output.info().shape(new_shape); -} - -void StaticShapeInferer::visit(const ir::operation::Unpack &op) -{ - const auto input_idx{op.getInputs().at(0)}; - const auto &input = _operands.at(input_idx); - const auto num = op.param().num; - const auto rank = input.shape().rank(); - const auto axis = ((op.param().axis < 0) ? rank + op.param().axis : op.param().axis); - - assert(axis < rank); - if (axis < 0) - { - for (int out_tensor_idx = 0; out_tensor_idx < num; out_tensor_idx++) - { - const auto output_idx = op.getOutputs().at(out_tensor_idx); - ir::Operand &output = _operands.at(output_idx); - output.info().setDynamic(); - } - _return_has_dynamic_tensor = true; - return; - } - - ir::Shape new_shape = shape_inference::inferUnpackShape(input.info().shape(), axis, rank); - - // re-sizing output shape - for (int out_tensor_idx = 0; out_tensor_idx < num; out_tensor_idx++) - { - const auto output_idx = op.getOutputs().at(out_tensor_idx); - ir::Operand &output = _operands.at(output_idx); - output.info().shape(new_shape); - } -} - -void StaticShapeInferer::visit(const ir::operation::While &op) -{ - auto &cond_graph = _lowered_subgs.at(op.param().cond_subg_index)->graph(); - auto &body_graph = _lowered_subgs.at(op.param().body_subg_index)->graph(); - const auto inputs = op.getInputs(); - const auto &outputs = op.getOutputs(); - - // re-sizing input shapes of then subgraph - const auto &cond_inputs = cond_graph.getInputs(); - assert(inputs.size() == cond_inputs.size()); - for (size_t i = 0; i < inputs.size(); ++i) - { - const auto &input = _operands.at(inputs.at(i)); - auto &cond_input = 
cond_graph.operands().at(cond_inputs.at(i)); - if (input.info().isDynamic()) - { - cond_input.info().setDynamic(); - } - else - { - auto new_shape = input.info().shape(); - cond_input.info().shape(new_shape); - } - } - - // re-sizing input shapes of body subgraph - const auto &body_inputs = body_graph.getInputs(); - assert(cond_inputs.size() == body_inputs.size()); - for (size_t i = 0; i < cond_inputs.size(); ++i) - { - const auto &cond_input = cond_graph.operands().at(cond_inputs.at(i)); - auto &body_input = body_graph.operands().at(body_inputs.at(i)); - if (cond_input.info().isDynamic()) - { - body_input.info().setDynamic(); - } - else - { - const auto &new_shape = cond_input.info().shape(); - body_input.info().shape(new_shape); - } - } - - // re-sizing operands of body subgraph - StaticShapeInferer body_inferer(op.param().body_subg_index, _lowered_subgs); - _lowered_subgs.at(op.param().body_subg_index) - ->iterateTopolOpSeqs([&](const ir::OpSequenceIndex &, ir::OpSequence &op_seq) { - bool has_dynamic_tensor = body_inferer.infer(op_seq); - op_seq.has_dynamic_tensor(has_dynamic_tensor); - }); - - // Check whether while operation's shapes are predictable - // If any of shape of body outputs and cond inputs are different, non-constant operands would be - // set to dynamic - bool check_unpredictable_dynamic = false; - const auto &body_outputs = body_graph.getOutputs(); - assert(body_outputs.size() == cond_inputs.size()); - for (size_t i = 0; i < body_outputs.size(); ++i) - { - const auto &body_output = body_graph.operands().at(body_outputs.at(i)); - auto &cond_input = cond_graph.operands().at(cond_inputs.at(i)); - if ((cond_input.info().isDynamic() != body_output.info().isDynamic()) || - (cond_input.shape() != body_output.shape())) - { - check_unpredictable_dynamic = true; - break; - } - } - - if (check_unpredictable_dynamic) - { - // Set inputs of body subgraph - for (const auto &input_index : body_inputs) - { - auto &input = body_graph.operands().at(input_index); 
- if (!input.isConstant()) - { - input.info().setDynamic(); - } - } - - // Set inputs of cond subgraph - for (const auto &input_index : cond_inputs) - { - auto &input = cond_graph.operands().at(input_index); - if (!input.isConstant()) - { - input.info().setDynamic(); - } - } - - // Set non-constant operands of body subgraph to dynamic - StaticShapeInferer body_inferer(op.param().body_subg_index, _lowered_subgs); - _lowered_subgs.at(op.param().body_subg_index) - ->iterateTopolOpSeqs([&](const ir::OpSequenceIndex &, ir::OpSequence &op_seq) { - bool has_dynamic_tensor = body_inferer.infer(op_seq); - op_seq.has_dynamic_tensor(has_dynamic_tensor); - }); - } - - // re-sizing operands of cond subgraph - // If check_unpredictable_dynamic is true, non-constant operands of cond subgraph would be set to - // dynamic - StaticShapeInferer cond_inferer(op.param().cond_subg_index, _lowered_subgs); - _lowered_subgs.at(op.param().cond_subg_index) - ->iterateTopolOpSeqs([&](const ir::OpSequenceIndex &, ir::OpSequence &op_seq) { - bool has_dynamic_tensor = cond_inferer.infer(op_seq); - op_seq.has_dynamic_tensor(has_dynamic_tensor); - }); - - // re-sizing outputs of while operation - // If check_unpredictable_dynamic is true, outputs of while operation would be set to dynamic - assert(cond_inputs.size() == outputs.size()); - for (size_t i = 0; i < cond_inputs.size(); ++i) - { - const auto &cond_input = cond_graph.operands().at(cond_inputs.at(i)); - auto &output = _operands.at(outputs.at(i)); - if (cond_input.info().isDynamic()) - { - output.info().setDynamic(); - _return_has_dynamic_tensor = true; - } - else - { - const auto new_shape = cond_input.info().shape(); - output.info().shape(new_shape); - } - } -} - -} // namespace compiler - -} // namespace onert diff --git a/runtime/onert/core/src/compiler/StaticShapeInferer.cc b/runtime/onert/core/src/compiler/StaticShapeInferer.cc new file mode 100644 index 000000000..622edbab4 --- /dev/null +++ 
b/runtime/onert/core/src/compiler/StaticShapeInferer.cc @@ -0,0 +1,1487 @@ +/* + * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "compiler/StaticShapeInferer.h" +#include "util/ShapeInference.h" +#include "util/logging.h" + +#include <misc/polymorphic_downcast.h> + +#include <sstream> +#include <stdexcept> + +namespace onert +{ +namespace compiler +{ +void OperandObserver::updateShapes(const std::vector<ir::OperandInfo> &changed_operands_info, + bool unpredictable) +{ + assert(changed_operands_info.size() == _operands.size()); + for (size_t i = 0; i < changed_operands_info.size(); ++i) + { + const auto &changed_operand_info = changed_operands_info.at(i); + auto &operand = _operands.at(i); + // assert(changed_operand_info.typeInfo() == operand->typeInfo()); + // assert(changed_operand_info.typeInfo() == operand->typeInfo()); + // This error check may by replaced by an assertion if this function is called after the + // validation of models are completed. 
+ if (changed_operand_info.typeInfo() != operand->typeInfo()) + { + throw std::runtime_error("OperandObserver: The types of operands are mismatched"); + } + if (!operand->info().isConstant() && (changed_operand_info.isDynamic() || unpredictable)) + { + operand->info().setDynamic(); + } + else + { + const auto &new_shape = changed_operands_info.at(i).shape(); + operand->info().shape(new_shape); + } + } +} + +void StaticShapeInferer::infer() +{ + for (const auto &op_idx : _lowered_subg->graph().topolSortOperations()) + { + const auto &op = _lowered_subg->graph().operations().at(op_idx); + bool has_dynamic_tensor = false; + const auto opcode = op.opcode(); + // IF: requires shape inference for then, else + // While: requires shape inference for condition, body + if (opcode == ir::OpCode::If || opcode == ir::OpCode::While) + { + op.accept(*this); + } + else + { + has_dynamic_tensor = checkDynamicInput(op); + if (has_dynamic_tensor) + { + setDynamicOutput(op); + } + else + { + op.accept(*this); + } + } + has_dynamic_tensor = has_dynamic_tensor || checkDynamicOutput(op); + _lowered_subg->setHasDynamicTensor(op_idx, has_dynamic_tensor); + } + + if (_controlflow_output_observer != nullptr) + { + // re-sizing output shapes of the controflow operation branching to this subgraph + std::vector<ir::OperandInfo> outputs_info; + const auto &graph = _lowered_subg->graph(); + const auto &outputs = graph.getOutputs(); + for (size_t i = 0; i < outputs.size(); ++i) + { + const auto &operand_info = graph.operands().at(outputs.at(i)).info(); + outputs_info.emplace_back(operand_info); + } + _controlflow_output_observer->updateShapes(outputs_info); + } +} + +bool StaticShapeInferer::checkDynamicInput(const ir::IOperation &op) +{ + const auto &operands = _lowered_subg->graph().operands(); + for (auto &&input_idx : op.getInputs() | ir::Remove::UNDEFINED | ir::Remove::DUPLICATED) + { + if (operands.at(input_idx).info().isDynamic()) + { + return true; + } + } + + return false; +} + +bool 
StaticShapeInferer::checkDynamicOutput(const ir::IOperation &op) +{ + auto &operands = _lowered_subg->graph().operands(); + for (auto &&output_idx : op.getOutputs() | ir::Remove::UNDEFINED) + { + if (operands.at(output_idx).info().isDynamic()) + { + return true; + } + } + return false; +} + +void StaticShapeInferer::setDynamicOutput(const ir::IOperation &op) +{ + auto &operands = _lowered_subg->graph().operands(); + for (auto &&output_idx : op.getOutputs() | ir::Remove::UNDEFINED) + { + operands.at(output_idx).info().setDynamic(); + } +} + +void StaticShapeInferer::handleBinaryArithmeticOp(const ir::Operation &op, + const ir::OperandIndex lhs_idx, + const ir::OperandIndex rhs_idx) +{ + auto &operands = _lowered_subg->graph().operands(); + const auto &lhs = operands.at(lhs_idx); + const auto &rhs = operands.at(rhs_idx); + + const auto output_idx = op.getOutputs().at(0); + ir::Operand &output = operands.at(output_idx); + + // re-sizing output shape + ir::Shape new_shape = shape_inference::inferEltwiseShape(lhs.info().shape(), rhs.info().shape()); + output.info().shape(new_shape); +} + +void StaticShapeInferer::handleSimpleUnaryOp(const ir::Operation &op, + const ir::OperandIndex input_idx) +{ + auto &operands = _lowered_subg->graph().operands(); + const auto &input = operands.at(input_idx); + + // get mutable output operand + const auto output_idx = op.getOutputs().at(0); + ir::Operand &output = operands.at(output_idx); + + // re-sizing output shape + ir::Shape new_shape = input.info().shape(); + output.info().shape(new_shape); +} + +void StaticShapeInferer::dump() +{ + auto get_shape_str = [](const ir::Shape &shape) { + std::stringstream sstream; + sstream << "shape : {"; + for (int i = 0; i < shape.rank(); i++) + { + if (i == 0) + sstream << shape.dim(i); + else + sstream << " " << shape.dim(i); + } + sstream << "}"; + return sstream.str(); + }; + + _lowered_subg->graph().operands().iterate( + [&](const ir::OperandIndex &ind, const ir::Operand &operand) { + 
VERBOSE(StaticShapeInferer) << " " << ind << ", " + << (operand.info().isDynamic() ? "Dynamic" : "Static") << ", " + << get_shape_str(operand.info().shape()) << std::endl; + }); +} + +std::unordered_map<ir::SubgraphIndex, std::unique_ptr<StaticShapeInferer>> +StaticShapeInferer::createStaticShapeInferers( + const std::unordered_map<ir::SubgraphIndex, ILoweredGraph *> &lowered_subgs) +{ + // Allocate StaticShapeInferer per each subgraph + std::unordered_map<ir::SubgraphIndex, std::unique_ptr<StaticShapeInferer>> inferers; + for (auto &&pair : lowered_subgs) + { + const auto &subg_index = pair.first; + auto &lowered_subg = pair.second; + inferers[subg_index] = std::make_unique<StaticShapeInferer>(lowered_subg); + } + + // Append observers in all StaticShapeInferers + for (auto &&pair : lowered_subgs) + { + const auto &subg_index = pair.first; + auto &lowered_subg = pair.second; + + // TODO: Change this iteration for all to controlflow iteration + lowered_subg->graph().operations().iterate( + [&](const ir::OperationIndex &, const ir::IOperation &op) { + // A Function to append child inferers. These make it possible for a StaticShapeInferer to + // call StaticShapeInferes of child subgraphs recursively + auto appendChildInferer = [&](const ir::SubgraphIndex &child_subg_idx) { + auto *child_inferer = inferers.at(child_subg_idx).get(); + inferers.at(subg_index)->appendChildInferer(child_subg_idx, child_inferer); + }; + + // A Function to appaend subg input observers. 
This makes it possible for a + // StaticShapeInferer to update inputs of child subgraphs + auto appendSubgraphInputObserver = [&](const ir::SubgraphIndex &child_subg_idx) { + std::vector<ir::Operand *> child_subg_inputs; + auto &child_subg = lowered_subgs.at(child_subg_idx)->graph(); + for (const auto &input_idx : child_subg.getInputs()) + { + auto operand_ptr = child_subg.operands().getRawPtr(input_idx); + child_subg_inputs.emplace_back(operand_ptr); + } + inferers.at(subg_index) + ->appendSubgInputObserver(child_subg_idx, + std::make_unique<OperandObserver>(child_subg_inputs)); + }; + + // A Function to set controlflow output observers. This makes it possible for a + // StaticShapeInferer to update outputs of parent controlflow opeerations + auto setControlFlowOutputObserver = [&](const ir::SubgraphIndex &child_subg_idx) { + std::vector<ir::Operand *> cf_outputs; + auto &subg = lowered_subg->graph(); + for (const auto &output_idx : op.getOutputs()) + { + auto operand_ptr = subg.operands().getRawPtr(output_idx); + cf_outputs.emplace_back(operand_ptr); + } + inferers.at(child_subg_idx) + ->setControlflowOutputObserver(std::make_unique<OperandObserver>(cf_outputs)); + }; + + // Append Observers in a StaticShapeInferer + if (op.opcode() == ir::OpCode::If) + { + // TODO Remove dynamic_cast + // An virtual base class cannot be downcasted by static_cast + try + { + const auto &if_op = dynamic_cast<const ir::operation::If &>(op); + + appendChildInferer(if_op.param().then_subg_index); + appendChildInferer(if_op.param().else_subg_index); + + appendSubgraphInputObserver(if_op.param().then_subg_index); + appendSubgraphInputObserver(if_op.param().else_subg_index); + + setControlFlowOutputObserver(if_op.param().then_subg_index); + } + catch (const std::bad_cast &) + { + throw std::runtime_error("StaticShapeInferer: Invalid If operation"); + } + } + else if (op.opcode() == ir::OpCode::While) + { + // TODO Remove dynamic_cast + try + { + const auto &while_op = dynamic_cast<const 
ir::operation::While &>(op); + + appendChildInferer(while_op.param().cond_subg_index); + appendChildInferer(while_op.param().body_subg_index); + + appendSubgraphInputObserver(while_op.param().cond_subg_index); + appendSubgraphInputObserver(while_op.param().body_subg_index); + + setControlFlowOutputObserver(while_op.param().body_subg_index); + } + catch (const std::bad_cast &) + { + throw std::runtime_error("StaticShapeInferer: Invalid While operation"); + } + } + }); + } + + return inferers; +} + +void StaticShapeInferer::visit(const ir::operation::ArgMinMax &op) +{ + auto &operands = _lowered_subg->graph().operands(); + + const auto input_idx{op.getInputs().at(ir::operation::ArgMinMax::Input::INPUT)}; + const auto &input = operands.at(input_idx); + + const auto axis_idx{op.getInputs().at(ir::operation::ArgMinMax::Input::AXIS)}; + const auto &axis = operands.at(axis_idx); + + // get mutable output operand + const auto output_idx = op.getOutputs().at(0); + ir::Operand &output = operands.at(output_idx); + + if (!axis.isConstant()) + { + output.info().setDynamic(); + return; + } + + const auto rank = input.info().shape().rank(); + auto axis_value = axis.asScalar<int32_t>(); + axis_value = axis_value < 0 ? 
axis_value + rank : axis_value; + + // re-sizing output shape + ir::Shape new_shape = + shape_inference::inferArgMinMaxShape(input.info().shape(), axis_value, rank); + output.info().shape(new_shape); +} + +void StaticShapeInferer::visit(const ir::operation::BatchMatMul &op) +{ + auto &operands = _lowered_subg->graph().operands(); + + const auto lhs_index = op.getInputs().at(ir::operation::BatchMatMul::Input::LHS); + const auto rhs_index = op.getInputs().at(ir::operation::BatchMatMul::Input::RHS); + const auto output_index = op.getOutputs().at(0); + const auto &lhs = operands.at(lhs_index); + const auto &rhs = operands.at(rhs_index); + auto &output = operands.at(output_index); + auto new_shape = shape_inference::inferBatchMatMulShape(lhs.shape(), rhs.shape(), op.param()); + output.info().shape(new_shape); +} + +void StaticShapeInferer::visit(const ir::operation::BCQFullyConnected &op) +{ + auto &operands = _lowered_subg->graph().operands(); + + const auto input_idx{op.getInputs().at(ir::operation::BCQFullyConnected::Input::INPUT)}; + const auto &input = operands.at(input_idx); + + const auto cluster_idx{ + op.getInputs().at(ir::operation::BCQFullyConnected::Input::WEIGHTS_CLUSTERS)}; + const auto &cluster = operands.at(cluster_idx); + + const auto output_idx = op.getOutputs().at(0); + ir::Operand &output = operands.at(output_idx); + + auto cluster_buf = reinterpret_cast<const int32_t *>(cluster.data()->base()); + assert(cluster_buf); + + // re-sizing output shape + ir::Shape new_shape = shape_inference::inferBCQFullyConnectedShape( + input.info().shape(), cluster.info().shape(), cluster_buf); + output.info().shape(new_shape); +} + +void StaticShapeInferer::visit(const ir::operation::BCQGather &op) +{ + auto &operands = _lowered_subg->graph().operands(); + + const auto indices_idx{op.getInputs().at(ir::operation::BCQGather::Input::INDICES)}; + const auto &indices = operands.at(indices_idx); + + const auto 
input_binary_idx{op.getInputs().at(ir::operation::BCQGather::Input::INPUT_BINARY)}; + const auto &input_binary = operands.at(input_binary_idx); + + const auto cluster_idx{op.getInputs().at(ir::operation::BCQGather::Input::INPUT_CLUSTERS)}; + const auto &cluster = operands.at(cluster_idx); + + const auto output_idx = op.getOutputs().at(0); + ir::Operand &output = operands.at(output_idx); + + auto cluster_buf = reinterpret_cast<const int32_t *>(cluster.data()->base()); + assert(cluster_buf); + + auto rank = input_binary.shape().rank(); + + // re-sizing output shape + ir::Shape new_shape = shape_inference::inferBCQGatherShape( + indices.info().shape(), cluster.info().shape(), cluster_buf, rank, op.param()); + + output.info().shape(new_shape); +} + +void StaticShapeInferer::visit(const ir::operation::BinaryArithmetic &op) +{ + handleBinaryArithmeticOp(op, op.getInputs().at(ir::operation::BinaryArithmetic::Input::LHS), + op.getInputs().at(ir::operation::BinaryArithmetic::Input::RHS)); +} + +void StaticShapeInferer::visit(const ir::operation::BroadcastTo &op) +{ + // get mutable output operand + auto &operands = _lowered_subg->graph().operands(); + const auto output_idx = op.getOutputs().at(0); + ir::Operand &output = operands.at(output_idx); + + const auto shape_idx{op.getInputs().at(ir::operation::BroadcastTo::Input::SHAPE)}; + const auto &shape = operands.at(shape_idx); + + if (!shape.isConstant()) + { + output.info().setDynamic(); + return; + } + + // assert(shape.typeInfo().type() == ir::DataType::INT32); + auto shape_buffer = reinterpret_cast<const int32_t *>(shape.data()->base()); + + // re-sizing output shape + ir::Shape new_shape = shape_inference::inferBroadcastToShape(shape.info().shape(), shape_buffer); + output.info().shape(new_shape); +} + +void StaticShapeInferer::visit(const ir::operation::Comparison &op) +{ + handleBinaryArithmeticOp(op, op.getInputs().at(ir::operation::Comparison::Input::INPUT0), + 
op.getInputs().at(ir::operation::Comparison::Input::INPUT1)); +} + +void StaticShapeInferer::visit(const ir::operation::Concat &op) +{ + auto &operands = _lowered_subg->graph().operands(); + + const auto input_count = op.getInputs().size(); + + const auto output_idx = op.getOutputs().at(0); + ir::Operand &output = operands.at(output_idx); + + shape_inference::Shapes input_shapes; + for (uint32_t i = 0; i < input_count; i++) + { + const auto input_idx{op.getInputs().at(i)}; + const auto &input = operands.at(input_idx); + input_shapes.emplace_back(input.shape()); + } + + ir::Shape out_shape = shape_inference::inferConcatShape(input_shapes, op.param()); + + // re-sizing output shape + output.info().shape(out_shape); +} + +void StaticShapeInferer::visit(const ir::operation::Conv2D &op) +{ + auto &operands = _lowered_subg->graph().operands(); + + const auto input_idx{op.getInputs().at(ir::operation::Conv2D::Input::INPUT)}; + const auto &input = operands.at(input_idx); + const auto ker_idx{op.getInputs().at(ir::operation::Conv2D::Input::KERNEL)}; + const auto &ker = operands.at(ker_idx); + const auto output_idx = op.getOutputs().at(0); + ir::Operand &output = operands.at(output_idx); + + // re-sizing output shape + ir::Shape new_shape = + shape_inference::inferConv2DShape(input.info().shape(), ker.info().shape(), op.param()); + output.info().shape(new_shape); +} + +void StaticShapeInferer::visit(const ir::operation::DepthwiseConv2D &op) +{ + auto &operands = _lowered_subg->graph().operands(); + + const auto input_idx{op.getInputs().at(ir::operation::DepthwiseConv2D::Input::INPUT)}; + const auto &input = operands.at(input_idx); + const auto ker_idx{op.getInputs().at(ir::operation::DepthwiseConv2D::Input::KERNEL)}; + const auto &ker = operands.at(ker_idx); + const auto output_idx = op.getOutputs().at(0); + ir::Operand &output = operands.at(output_idx); + + // re-sizing output shape + ir::Shape new_shape = shape_inference::inferDepthwiseConv2DShape(input.info().shape(), + 
ker.info().shape(), op.param()); + output.info().shape(new_shape); +} + +void StaticShapeInferer::visit(const ir::operation::ElementwiseActivation &op) +{ + handleSimpleUnaryOp(op, op.getInputs().at(ir::operation::ElementwiseActivation::Input::INPUT)); +} + +void StaticShapeInferer::visit(const ir::operation::ElementwiseBinary &op) +{ + handleBinaryArithmeticOp(op, op.getInputs().at(ir::operation::ElementwiseBinary::Input::LHS), + op.getInputs().at(ir::operation::ElementwiseBinary::Input::RHS)); +} + +void StaticShapeInferer::visit(const ir::operation::ElementwiseUnary &op) +{ + handleSimpleUnaryOp(op, op.getInputs().at(ir::operation::ElementwiseUnary::Input::INPUT)); +} + +void StaticShapeInferer::visit(const ir::operation::ExpandDims &op) +{ + auto &operands = _lowered_subg->graph().operands(); + + const auto input_idx{op.getInputs().at(ir::operation::ExpandDims::Input::INPUT)}; + const auto &input = operands.at(input_idx); + const auto axis_idx{op.getInputs().at(ir::operation::ExpandDims::Input::AXIS)}; + const auto &axis = operands.at(axis_idx); + const auto output_idx = op.getOutputs().at(0); + ir::Operand &output = operands.at(output_idx); + + if (!axis.isConstant()) + { + output.info().setDynamic(); + return; + } + + // even when axis is constant, output shape should be recalculated since user might call + // nnfw_set_input_tensorinfo(input, some_new_shape) + auto axis_type = axis.typeInfo().type(); + assert(axis_type == ir::DataType::INT32 || axis_type == ir::DataType::INT64); + + assert(axis.data()->base()); + int32_t axis_value = + (axis_type == ir::DataType::INT32) + ? 
reinterpret_cast<const int32_t *>(axis.data()->base())[0] + : static_cast<int32_t>(reinterpret_cast<const int64_t *>(axis.data()->base())[0]); + + // re-sizing output shape + ir::Shape new_shape = shape_inference::inferExpandDimsShape(input.info().shape(), axis_value); + output.info().shape(new_shape); +} + +void StaticShapeInferer::visit(const ir::operation::Fill &op) +{ + auto &operands = _lowered_subg->graph().operands(); + + const auto shape_idx{op.getInputs().at(ir::operation::Fill::Input::SHAPE)}; + const auto &shape = operands.at(shape_idx); + const auto output_idx = op.getOutputs().at(0); + ir::Operand &output = operands.at(output_idx); + + if (!shape.isConstant()) + { + output.info().setDynamic(); + return; + } + + const auto dims_type = shape.typeInfo().type(); + assert(dims_type == ir::DataType::INT32 || dims_type == ir::DataType::INT64); + + auto dims_buf = shape.data()->base(); + assert(dims_buf); + + const auto &dims_shape = shape.info().shape(); + const auto &new_shape = ((dims_type == ir::DataType::INT32) + ? 
shape_inference::inferFillShape<int32_t>( + dims_shape, reinterpret_cast<const int32_t *>(dims_buf)) + : shape_inference::inferFillShape<int64_t>( + dims_shape, reinterpret_cast<const int64_t *>(dims_buf))); + + output.info().shape(new_shape); +} + +void StaticShapeInferer::visit(const ir::operation::FullyConnected &op) +{ + auto &operands = _lowered_subg->graph().operands(); + + const auto input_idx{op.getInputs().at(ir::operation::FullyConnected::Input::INPUT)}; + const auto &input = operands.at(input_idx); + + const auto ker_idx{op.getInputs().at(ir::operation::FullyConnected::Input::WEIGHT)}; + const auto &ker = operands.at(ker_idx); + + // get mutable output operand + const auto output_idx = op.getOutputs().at(0); + ir::Operand &output = operands.at(output_idx); + // re-sizing output shape + ir::Shape new_shape = + shape_inference::inferFullyConnectedShape(input.info().shape(), ker.info().shape()); + output.info().shape(new_shape); +} + +void StaticShapeInferer::visit(const ir::operation::FusedBatchNorm &op) +{ + handleSimpleUnaryOp(op, op.getInputs().at(ir::operation::FusedBatchNorm::Input::INPUT)); +} + +void StaticShapeInferer::visit(const ir::operation::Gather &op) +{ + auto &operands = _lowered_subg->graph().operands(); + + const auto input_idx{op.getInputs().at(ir::operation::Gather::Input::INPUT)}; + const auto &input = operands.at(input_idx); + + // get mutable output operand + const auto output_idx = op.getOutputs().at(0); + ir::Operand &output = operands.at(output_idx); + + const auto indices_idx{op.getInputs().at(ir::operation::Gather::Input::INDICES)}; + const auto &indices = operands.at(indices_idx); + const auto rank = input.info().shape().rank(); + const auto axis = ((op.param().axis < 0) ? 
rank + op.param().axis : op.param().axis); + + assert(0 <= axis && axis < rank); + + // re-sizing output shape + ir::Shape new_shape = + shape_inference::inferGatherShape(input.info().shape(), indices.info().shape(), axis, rank); + output.info().shape(new_shape); +} + +void StaticShapeInferer::visit(const ir::operation::If &op) +{ + // re-sizing input shapes of then/else subgraph + const std::vector<ir::OperandIndex> inputs{op.getInputs().begin() + 1, op.getInputs().end()}; + + std::vector<ir::OperandInfo> inputs_info; + const auto &graph = _lowered_subg->graph(); + for (size_t i = 0; i < inputs.size(); ++i) + { + const auto &operand_info = graph.operands().at(inputs.at(i)).info(); + inputs_info.emplace_back(operand_info); + } + _subg_input_observers.at(op.param().then_subg_index)->updateShapes(inputs_info); + _child_inferers.at(op.param().then_subg_index)->infer(); + + _subg_input_observers.at(op.param().else_subg_index)->updateShapes(inputs_info); + _child_inferers.at(op.param().else_subg_index)->infer(); +} + +void StaticShapeInferer::visit(const ir::operation::L2Normalization &op) +{ + handleSimpleUnaryOp(op, op.getInputs().at(ir::operation::L2Normalization::Input::INPUT)); +} + +void StaticShapeInferer::visit(const ir::operation::Loss &op) +{ + // TODO Consider SparseCategoricalCrossentropy case + + auto &operands = _lowered_subg->graph().operands(); + + const auto input_index{op.getInputs().at(ir::operation::Loss::Input::Y_PRED)}; + auto &input = operands.at(input_index); + + const auto output_index{op.getOutputs().at(0)}; + auto &output = operands.at(output_index); + + ir::Shape new_shape = output.info().shape(); + new_shape.dim(0) = input.info().shape().dim(0); + + output.info().shape(new_shape); +} + +void StaticShapeInferer::visit(const ir::operation::LSTM &op) +{ + auto &operands = _lowered_subg->graph().operands(); + + const auto output_index{op.getOutputs().at(ir::operation::LSTM::Output::OUTPUT)}; + auto &output = operands.at(output_index); + + const 
auto output_state_out_index{ + op.getOutputs().at(ir::operation::LSTM::Output::OUTPUT_STATE_OUT)}; + + const auto cell_state_out_index{op.getOutputs().at(ir::operation::LSTM::Output::CELL_STATE_OUT)}; + + const auto scratch_buffer_index{op.getOutputs().at(ir::operation::LSTM::Output::SCRATCH_BUFFER)}; + + if (output.info().isDynamic() || + (operands.exist(output_state_out_index) && + operands.at(output_state_out_index).info().isDynamic()) || + (operands.exist(cell_state_out_index) && + operands.at(cell_state_out_index).info().isDynamic()) || + (operands.exist(scratch_buffer_index) && + operands.at(scratch_buffer_index).info().isDynamic())) + return; + + const auto input_index{op.getInputs().at(ir::operation::LSTM::Input::INPUT)}; + const auto &input = operands.at(input_index); + + const auto input_to_output_weights_index{ + op.getInputs().at(ir::operation::LSTM::Input::INPUT_TO_OUTPUT_WEIGHTS)}; + const auto &input_to_output_weights = operands.at(input_to_output_weights_index); + + const auto recurrent_to_output_weights_index{ + op.getInputs().at(ir::operation::LSTM::Input::RECURRENT_TO_OUTPUT_WEIGHTS)}; + const auto &recurrent_to_output_weights = operands.at(recurrent_to_output_weights_index); + + // re-sizing outputs + const int n_batch = (input.shape().rank() == 3 && op.param().time_major) ? 
input.shape().dim(1) + : input.shape().dim(0); + const int n_cell = input_to_output_weights.shape().dim(0); + const int n_output = recurrent_to_output_weights.shape().dim(1); + if (input.shape().rank() == 3) + { + if (op.param().time_major) + output.info().shape(ir::Shape{input.shape().dim(0), n_batch, n_output}); + else + output.info().shape(ir::Shape{n_batch, input.shape().dim(1), n_output}); + } + else + { + assert(input.shape().rank() == 2); + output.info().shape(ir::Shape{n_batch, n_output}); + } + + if (operands.exist(output_state_out_index)) + { + auto &output_state_out = operands.at(output_state_out_index); + output_state_out.info().shape(ir::Shape{n_batch, n_output}); + } + + if (operands.exist(cell_state_out_index)) + { + auto &cell_state_out = operands.at(cell_state_out_index); + cell_state_out.info().shape(ir::Shape{n_batch, n_cell}); + } + + if (operands.exist(scratch_buffer_index)) + { + auto &scratch_buffer = operands.at(scratch_buffer_index); + + const auto input_to_input_weights_index{ + op.getInputs().at(ir::operation::LSTM::Input::INPUT_TO_INPUT_WEIGHTS)}; + const auto recurrent_to_input_weights_index{ + op.getInputs().at(ir::operation::LSTM::Input::RECURRENT_TO_INPUT_WEIGHTS)}; + + bool has_input_to_input_weights = + operands.at(input_to_input_weights_index).shape().dim(0) != 0 && + operands.at(input_to_input_weights_index).shape().dim(1) != 0; + bool has_recurrent_to_input_weights = + operands.at(recurrent_to_input_weights_index).shape().dim(0) != 0 && + operands.at(recurrent_to_input_weights_index).shape().dim(1) != 0; + + // NOTE The cell_to_input_weights do not exist in non-peephole although regular LSTM(non-CIFG). 
+ // true: no CIFG + // false: CIFG + bool has_cifg_param = has_input_to_input_weights && has_recurrent_to_input_weights; + if (has_cifg_param) + { + scratch_buffer.info().shape(ir::Shape{n_batch, n_cell * 4}); + } + else + { + scratch_buffer.info().shape(ir::Shape{n_batch, n_cell * 3}); + } + } +} + +void StaticShapeInferer::visit(const ir::operation::MatrixBandPart &op) +{ + handleSimpleUnaryOp(op, op.getInputs().at(ir::operation::MatrixBandPart::Input::INPUT)); +} + +void StaticShapeInferer::visit(const ir::operation::OneHot &op) +{ + auto &operands = _lowered_subg->graph().operands(); + + const auto indice_idx{op.getInputs().at(ir::operation::OneHot::Input::INDICES)}; + const auto &indice = operands.at(indice_idx); + const auto depth_idx{op.getInputs().at(ir::operation::OneHot::Input::DEPTH)}; + const auto &depth = operands.at(depth_idx); + + const auto axis = op.param().axis; + + auto output_idx = op.getOutputs().at(0); + ir::Operand &output = operands.at(output_idx); + + if (!depth.isConstant()) + { + output.info().setDynamic(); + return; + } + + const auto *depth_buf = reinterpret_cast<const int32_t *>(depth.data()->base()); + assert(depth_buf); + // re-sizing output shape + ir::Shape new_shape = shape_inference::inferOnehotShape(indice.info().shape(), *depth_buf, axis); + output.info().shape(new_shape); +} + +void StaticShapeInferer::visit(const ir::operation::Pack &op) +{ + auto &operands = _lowered_subg->graph().operands(); + + const auto input_idx{op.getInputs().at(0)}; + const auto &input = operands.at(input_idx); + + // get mutable output operand + const auto output_idx = op.getOutputs().at(0); + ir::Operand &output = operands.at(output_idx); + + const auto rank = input.shape().rank() + 1; + const auto axis = ((op.param().axis < 0) ? 
rank + op.param().axis : op.param().axis); + const auto num = op.param().num; + + assert(0 <= axis && axis < rank); + + // re-sizing output shape + ir::Shape new_shape = shape_inference::inferPackShape(input.info().shape(), axis, rank, num); + output.info().shape(new_shape); +} + +void StaticShapeInferer::visit(const ir::operation::Pad &op) +{ + auto &operands = _lowered_subg->graph().operands(); + + const auto input_idx{op.getInputs().at(ir::operation::Pad::Input::INPUT)}; + const auto &input = operands.at(input_idx); + + const auto pad_idx{op.getInputs().at(ir::operation::Pad::Input::PAD)}; + const auto &pad = operands.at(pad_idx); + + // get mutable output operand + const auto output_idx = op.getOutputs().at(0); + ir::Operand &output = operands.at(output_idx); + + // if pad is not constant, output also becomes dynamic + if (!pad.isConstant()) + { + output.info().setDynamic(); + return; + } + + // re-sizing output shape + const auto &new_shape = shape_inference::inferPadShape( + input.shape(), reinterpret_cast<const int32_t *>(pad.data()->base()), + pad.shape().num_elements()); + output.info().shape(new_shape); +} + +void StaticShapeInferer::visit(const ir::operation::Permute &op) +{ + auto &operands = _lowered_subg->graph().operands(); + + const auto input_idx{op.getInputs().at(0)}; + const auto &input = operands.at(input_idx); + const auto output_idx = op.getOutputs().at(0); + ir::Operand &output = operands.at(output_idx); + + // re-sizing output shape + // Permute is a special operation that layouts of input/output may be different on backend + // However, it is not applied here, so input/output have the same layout of frontend. 
Because + // "ExecutorFactory" would convert shape of input/output accoding to the layouts when registering + // operand info to "TensorBuilder" after calling "StaticShapeInferer" + const auto &new_shape = input.info().shape(); + output.info().shape(new_shape); +} + +void StaticShapeInferer::visit(const ir::operation::Pool2D &op) +{ + auto &operands = _lowered_subg->graph().operands(); + + const auto layout = _lowered_subg->graph().layout(); + + const auto input_idx{op.getInputs().at(ir::operation::Pool2D::Input::INPUT)}; + const auto &input = operands.at(input_idx); + if (input.info().shape().rank() != 4) + { + throw std::runtime_error(op.name() + ": supports only 4D tensor as input"); + } + + const auto output_idx = op.getOutputs().at(0); + ir::Operand &output = operands.at(output_idx); + + ir::Shape new_shape = shape_inference::inferPoolShape(input.info().shape(), op.param(), layout); + output.info().shape(new_shape); +} + +void StaticShapeInferer::visit(const ir::operation::Pow &op) +{ + handleBinaryArithmeticOp(op, op.getInputs().at(ir::operation::Pow::Input::LHS), + op.getInputs().at(ir::operation::Pow::Input::RHS)); +} + +void StaticShapeInferer::visit(const ir::operation::Range &op) +{ + auto &operands = _lowered_subg->graph().operands(); + + const auto start_idx{op.getInputs().at(ir::operation::Range::Input::START)}; + const auto limit_idx{op.getInputs().at(ir::operation::Range::Input::LIMIT)}; + const auto delta_idx{op.getInputs().at(ir::operation::Range::Input::DELTA)}; + const auto &start_op = operands.at(start_idx); + const auto &limit_op = operands.at(limit_idx); + const auto &delta_op = operands.at(delta_idx); + + // get mutable output operand + const auto output_idx = op.getOutputs().at(0); + ir::Operand &output = operands.at(output_idx); + + ir::Shape new_shape; + if (start_op.isConstant() && limit_op.isConstant() && delta_op.isConstant()) + { + assert(start_op.typeInfo().type() == limit_op.typeInfo().type() && + start_op.typeInfo().type() == 
delta_op.typeInfo().type()); + if (output.typeInfo().type() == ir::DataType::FLOAT32) + { + new_shape = shape_inference::inferRangeShape<float>( + start_op.asScalar<float>(), limit_op.asScalar<float>(), delta_op.asScalar<float>()); + } + else if (output.typeInfo().type() == ir::DataType::INT32) + { + new_shape = shape_inference::inferRangeShape<int32_t>( + start_op.asScalar<int32_t>(), limit_op.asScalar<int32_t>(), delta_op.asScalar<int32_t>()); + } + assert(output.shape() == new_shape); + } + else + { + output.info().setDynamic(); + } +} + +void StaticShapeInferer::visit(const ir::operation::Reduce &op) +{ + auto &operands = _lowered_subg->graph().operands(); + + const auto input_idx{op.getInputs().at(ir::operation::Reduce::Input::INPUT)}; + const auto &input = operands.at(input_idx); + + const auto axes_idx{op.getInputs().at(ir::operation::Reduce::Input::AXES)}; + const auto &axes = operands.at(axes_idx); + + // get mutable output operand + const auto output_idx = op.getOutputs().at(0); + ir::Operand &output = operands.at(output_idx); + + std::vector<int32_t> axes_vec; + for (size_t i = 0; i < axes.shape().num_elements(); ++i) + { + switch (axes.typeInfo().type()) + { + case ir::DataType::INT32: + { + axes_vec.emplace_back(reinterpret_cast<const int32_t *>(axes.data()->base())[i]); + break; + } + case ir::DataType::INT64: + { + axes_vec.emplace_back(reinterpret_cast<const int64_t *>(axes.data()->base())[i]); + break; + } + default: + throw std::runtime_error("StaticShapeInferer " + op.name() + ": Not supported data type"); + break; + } + } + const auto keep_dims = op.param().keep_dims; + + // re-sizing output shape + ir::Shape new_shape = + shape_inference::inferReduceShape(input.info().shape(), axes_vec, keep_dims); + output.info().shape(new_shape); +} + +void StaticShapeInferer::visit(const ir::operation::Reshape &op) +{ + auto &operands = _lowered_subg->graph().operands(); + + const auto input_idx{op.getInputs().at(ir::operation::Reshape::Input::INPUT)}; + 
const auto &input = operands.at(input_idx); + + // get mutable output operand + const auto output_idx = op.getOutputs().at(0); + ir::Operand &output = operands.at(output_idx); + + // New shape is given by second input tensor + if (op.getInputs().size() == 2) + { + // Let's check the second input + const auto shape_idx{op.getInputs().at(ir::operation::Reshape::Input::SHAPE)}; + const auto &shape = operands.at(shape_idx); + + if (shape.isConstant()) + { + const auto *shape_buf = reinterpret_cast<const int32_t *>(shape.data()->base()); + assert(shape_buf); + + ir::Shape new_shape = + shape_inference::inferReshapeShape(input.shape(), shape_buf, shape.shape().num_elements()); + + // if shape is from Const, TFLC put the shape of output into tensor + if (new_shape != output.shape()) + { + // change on output shape + output.info().shape(new_shape); + } + } + else + { + // if shape is NOT Const, set output shape to be dynamic_ + output.info().setDynamic(); + } + } + // New shape is given by option + else if (op.param().new_shape.size() != 0) + { + // Let's check the new_shape option + auto shape = op.param().new_shape; + ir::Shape new_shape = + shape_inference::inferReshapeShape(input.shape(), shape.data(), shape.size()); + + if (new_shape != output.shape()) + { + // change on output shape + output.info().shape(new_shape); + } + } + else + { + throw std::runtime_error("Reshape: new shape is missing"); + } +} + +void StaticShapeInferer::visit(const ir::operation::ResizeBilinear &op) +{ + auto &operands = _lowered_subg->graph().operands(); + + const auto input_idx{op.getInputs().at(ir::operation::ResizeBilinear::Input::INPUT)}; + const auto &input = operands.at(input_idx); + + // get mutable output operand + const auto output_idx = op.getOutputs().at(0); + ir::Operand &output = operands.at(output_idx); + + int32_t height_out, width_out; + if (op.getInputs().size() == 2) + { + auto &size = operands.at(op.getInputs().at(ir::operation::ResizeBilinear::Input::SIZE)); + if 
(!size.isConstant()) + { + output.info().setDynamic(); + return; + } + const auto size_v = size.asVector<std::int32_t>(); + height_out = size_v[0]; + width_out = size_v[1]; + } + else + { + height_out = op.param().height_out; + width_out = op.param().width_out; + } + + // Shape inferencing logic based on Params + ir::Shape new_shape = + shape_inference::inferResizeBilinearShape(input.shape(), height_out, width_out); + + // if size_op is from Const, TFLC put the shape of output into tensor + if (new_shape != output.shape()) + { + // change on output shape + output.info().shape(new_shape); + } +} + +void StaticShapeInferer::visit(const ir::operation::Reverse &op) +{ + handleSimpleUnaryOp(op, op.getInputs().at(ir::operation::Reverse::Input::INPUT)); +} + +void StaticShapeInferer::visit(const ir::operation::Select &op) +{ + auto &operands = _lowered_subg->graph().operands(); + + const auto input_cond_idx{op.getInputs().at(ir::operation::Select::Input::CONDITION)}; + const auto &input_cond = operands.at(input_cond_idx); + + const auto input_true_idx{op.getInputs().at(ir::operation::Select::Input::INPUT_TRUE)}; + const auto &input_true = operands.at(input_true_idx); + + const auto input_false_idx{op.getInputs().at(ir::operation::Select::Input::INPUT_FALSE)}; + const auto &input_false = operands.at(input_false_idx); + + auto output_idx = op.getOutputs().at(0); + ir::Operand &output = operands.at(output_idx); + + // Select output shpae + ir::Shape new_shape = shape_inference::inferSelectShape( + input_cond.info().shape(), input_true.info().shape(), input_false.info().shape()); + output.info().shape(new_shape); +} + +void StaticShapeInferer::visit(const ir::operation::Shape &op) +{ + auto &operands = _lowered_subg->graph().operands(); + + const auto input_idx{op.getInputs().at(0)}; + const auto &input = operands.at(input_idx); + + // get mutable output operand + const auto output_idx = op.getOutputs().at(0); + ir::Operand &output = operands.at(output_idx); + + // re-sizing 
output shape + ir::Shape output_shape; + output_shape.append(input.info().shape().rank()); + + output.info().shape(output_shape); +} + +void StaticShapeInferer::visit(const ir::operation::Slice &op) +{ + auto &operands = _lowered_subg->graph().operands(); + + const auto input_index{op.getInputs().at(ir::operation::Slice::Input::INPUT)}; + const auto &input = operands.at(input_index); + const auto begins_index{op.getInputs().at(ir::operation::Slice::Input::BEGINS)}; + const auto &begins = operands.at(begins_index); + const auto sizes_index{op.getInputs().at(ir::operation::Slice::Input::SIZES)}; + const auto &sizes = operands.at(sizes_index); + const auto output_index = op.getOutputs().at(0); + ir::Operand &output = operands.at(output_index); + + // Whether input is constant or not does not affect whether output is dynamic or not + if (!(begins.isConstant() && sizes.isConstant())) + { + output.info().setDynamic(); + return; + } + + auto begins_buf = begins.data()->base(); + auto sizes_buf = sizes.data()->base(); + + const auto begins_type = begins.typeInfo().type(); + assert(begins_type == ir::DataType::INT32 || begins_type == ir::DataType::INT64); + assert(begins_type == sizes.typeInfo().type()); + + ir::Shape new_shape = + (begins_type == ir::DataType::INT32) + ? 
shape_inference::inferSliceShape<int32_t>(input.info().shape(), + reinterpret_cast<const int32_t *>(begins_buf), + reinterpret_cast<const int32_t *>(sizes_buf)) + : shape_inference::inferSliceShape<int64_t>(input.info().shape(), + reinterpret_cast<const int64_t *>(begins_buf), + reinterpret_cast<const int64_t *>(sizes_buf)); + output.info().shape(new_shape); +} + +void StaticShapeInferer::visit(const ir::operation::Softmax &op) +{ + handleSimpleUnaryOp(op, op.getInputs().at(ir::operation::Softmax::Input::INPUT)); +} + +void StaticShapeInferer::visit(const ir::operation::SpaceToBatchND &op) +{ + auto &operands = _lowered_subg->graph().operands(); + + const auto output_index = op.getOutputs().at(0); + const auto input_idx{op.getInputs().at(ir::operation::SpaceToBatchND::Input::INPUT)}; + const auto &block_shape_idx{op.getInputs().at(ir::operation::SpaceToBatchND::Input::BLOCK_SIZE)}; + const auto &padding_idx{op.getInputs().at(ir::operation::SpaceToBatchND::Input::PADDINGS)}; + + ir::Operand &output = operands.at(output_index); + const auto &input = operands.at(input_idx); + const auto &block_shape = operands.at(block_shape_idx); + const auto &padding = operands.at(padding_idx); + + // Whether input is constant or not does not affect whether output is dynamic or not + if (!(block_shape.isConstant() && padding.isConstant())) + { + output.info().setDynamic(); + return; + } + + const auto &input_shape = input.info().shape(); + const auto &block_shape_shape = block_shape.info().shape(); + const auto &padding_shape = padding.info().shape(); + + auto block_shape_data = reinterpret_cast<const int32_t *>(block_shape.data()->base()); + auto padding_data = reinterpret_cast<const int32_t *>(padding.data()->base()); + + ir::Shape new_shape = shape_inference::inferSpaceToBatchNDShape( + input_shape, block_shape_shape, padding_shape, block_shape_data, padding_data); + + output.info().shape(new_shape); +} + +void StaticShapeInferer::visit(const ir::operation::Split &op) +{ + auto 
&operands = _lowered_subg->graph().operands(); + + const auto input_idx{op.getInputs().at(ir::operation::Split::Input::INPUT)}; + const auto &input = operands.at(input_idx); + + const auto axis_idx{op.getInputs().at(ir::operation::Split::Input::AXIS)}; + const auto &axis = operands.at(axis_idx); + + auto outputs = op.getOutputs(); + if (!axis.isConstant()) + { + for (auto &&output_idx : outputs) + { + ir::Operand &output = operands.at(output_idx); + output.info().setDynamic(); + } + return; + } + + const auto num_splits = op.param().num_splits; + + const auto rank = input.info().shape().rank(); + auto axis_value = axis.asScalar<int32_t>(); + axis_value = axis_value < 0 ? axis_value + rank : axis_value; + + assert(0 <= axis_value && axis_value < rank); + + ir::Shape new_shape = + shape_inference::inferSplitShape(input.info().shape(), axis_value, num_splits); + for (auto &&output_idx : outputs) + { + ir::Operand &output = operands.at(output_idx); + output.info().shape(new_shape); + } +} + +void StaticShapeInferer::visit(const ir::operation::SquaredDifference &op) +{ + handleBinaryArithmeticOp(op, op.getInputs().at(ir::operation::SquaredDifference::Input::LHS), + op.getInputs().at(ir::operation::SquaredDifference::Input::RHS)); +} + +void StaticShapeInferer::visit(const ir::operation::Squeeze &op) +{ + auto &operands = _lowered_subg->graph().operands(); + + const auto input_idx{op.getInputs().at(ir::operation::Squeeze::Input::INPUT)}; + const auto &input = operands.at(input_idx); + + const auto output_idx = op.getOutputs().at(0); + ir::Operand &output = operands.at(output_idx); + + // Squeeze output shpae + ir::Shape new_shape = shape_inference::inferSqueezeShape(input.info().shape(), op.param()); + output.info().shape(new_shape); +} + +void StaticShapeInferer::visit(const ir::operation::StridedSlice &op) +{ + auto &operands = _lowered_subg->graph().operands(); + + const auto input_index{op.getInputs().at(ir::operation::StridedSlice::Input::INPUT)}; + const auto 
&input = operands.at(input_index); + const auto starts_index{op.getInputs().at(ir::operation::StridedSlice::Input::STARTS)}; + const auto &starts = operands.at(starts_index); + const auto ends_index{op.getInputs().at(ir::operation::StridedSlice::Input::ENDS)}; + const auto &ends = operands.at(ends_index); + const auto strides_index{op.getInputs().at(ir::operation::StridedSlice::Input::STRIDES)}; + const auto &strides = operands.at(strides_index); + const auto output_index = op.getOutputs().at(0); + ir::Operand &output = operands.at(output_index); + + if (!(starts.isConstant() && ends.isConstant() && strides.isConstant())) + { + output.info().setDynamic(); + return; + } + + const auto begin_mask = op.param().begin_mask; + const auto end_mask = op.param().end_mask; + const auto shrink_axis_mask = op.param().shrink_axis_mask; + const auto rank = input.info().shape().rank(); + + auto starts_buf = reinterpret_cast<const uint32_t *>(starts.data()->base()); + auto ends_buf = reinterpret_cast<const uint32_t *>(ends.data()->base()); + auto strides_buf = reinterpret_cast<const uint32_t *>(strides.data()->base()); + + auto op_params = shape_inference::buildStridedSliceParams( + starts_buf, ends_buf, strides_buf, begin_mask, end_mask, shrink_axis_mask, rank); + + ir::Shape new_shape = + shape_inference::inferStridedSliceShape(input.info().shape(), op_params, rank); + output.info().shape(new_shape); +} + +void StaticShapeInferer::visit(const ir::operation::Tile &op) +{ + auto &operands = _lowered_subg->graph().operands(); + + const auto input_idx{op.getInputs().at(ir::operation::Tile::Input::INPUT)}; + const auto &input = operands.at(input_idx); + + const auto multiplier_idx{op.getInputs().at(ir::operation::Tile::Input::MULTIPLES)}; + const auto &multiplier = operands.at(multiplier_idx); + + const auto output_idx = op.getOutputs().at(0); + ir::Operand &output = operands.at(output_idx); + + if (!multiplier.isConstant()) + { + output.info().setDynamic(); + return; + } + + auto 
multiplier_buffer = reinterpret_cast<const int32_t *>(multiplier.data()->base()); + assert(multiplier_buffer); + + // re-sizing output shape + auto new_shape = shape_inference::inferTileShape(input.info().shape(), multiplier_buffer, + multiplier.shape().num_elements()); + output.info().shape(new_shape); +} + +void StaticShapeInferer::visit(const ir::operation::Transpose &op) +{ + auto &operands = _lowered_subg->graph().operands(); + + const auto input_idx{op.getInputs().at(ir::operation::Transpose::Input::INPUT)}; + const auto &input = operands.at(input_idx); + + const auto perm_idx{op.getInputs().at(ir::operation::Transpose::Input::PERMUTATION)}; + const auto &perm = operands.at(perm_idx); + + // perm.shape() != ir::Shape{0} means that perm is (n-1...0) + // TODO This condition changes to perm.num_elements() == 0 + const auto is_regular_transpose = perm.shape() == ir::Shape{0}; + + // get mutable output operand + const auto output_idx = op.getOutputs().at(0); + auto &output = operands.at(output_idx); + if (!perm.isConstant() && !is_regular_transpose) + { + output.info().setDynamic(); + return; + } + + ir::Shape new_shape; + if (is_regular_transpose) + { + // Call by (n-1...0) + new_shape = shape_inference::inferTransposeShape(input.info().shape(), nullptr, 0); + } + else + { + // Check rank + if (input.info().shape().rank() != static_cast<int>(perm.info().shape().num_elements())) + { + throw std::runtime_error("StaticShapeInferer failed, bad rank size: " + + std::to_string(perm.info().shape().num_elements())); + } + + // set output shape, based on input and params + const auto perm_buf = reinterpret_cast<const int32_t *>(perm.data()->base()); + new_shape = shape_inference::inferTransposeShape(input.info().shape(), perm_buf, + perm.shape().num_elements()); + } + output.info().shape(new_shape); +} + +void StaticShapeInferer::visit(const ir::operation::Unpack &op) +{ + auto &operands = _lowered_subg->graph().operands(); + + const auto input_idx{op.getInputs().at(0)}; 
+ const auto &input = operands.at(input_idx); + const auto num = op.param().num; + const auto rank = input.shape().rank(); + const auto axis = ((op.param().axis < 0) ? rank + op.param().axis : op.param().axis); + + assert(axis < rank); + if (axis < 0) + { + for (int out_tensor_idx = 0; out_tensor_idx < num; out_tensor_idx++) + { + const auto output_idx = op.getOutputs().at(out_tensor_idx); + ir::Operand &output = operands.at(output_idx); + output.info().setDynamic(); + } + return; + } + + ir::Shape new_shape = shape_inference::inferUnpackShape(input.info().shape(), axis, rank); + + // re-sizing output shape + for (int out_tensor_idx = 0; out_tensor_idx < num; out_tensor_idx++) + { + const auto output_idx = op.getOutputs().at(out_tensor_idx); + ir::Operand &output = operands.at(output_idx); + output.info().shape(new_shape); + } +} + +void StaticShapeInferer::visit(const ir::operation::While &op) +{ + auto body_input_observer = _subg_input_observers.at(op.param().body_subg_index).get(); + auto cond_input_observer = _subg_input_observers.at(op.param().cond_subg_index).get(); + // re-sizing input shapes of body subgraph + const auto &inputs = op.getInputs(); + std::vector<ir::OperandInfo> inputs_info; + const auto &graph = _lowered_subg->graph(); + for (size_t i = 0; i < inputs.size(); ++i) + { + const auto &operand_info = graph.operands().at(inputs.at(i)).info(); + inputs_info.emplace_back(operand_info); + } + + body_input_observer->updateShapes(inputs_info); + _child_inferers.at(op.param().body_subg_index)->infer(); + + // Check whether while operation's shapes are predictable + // This while op's outputs are also updated in the above function + // "_child_inferers.at(op.param().body_subg_index)->update()". That means that body's outputs and + // thils op's outputs must have the same shape. So we can predict whether body subgraphs will + // change at every step by comparing the shapes of inputs/outputs. 
If any of shape of body outputs + // and inputs are different Non-constant operands will be set to dynamic. + bool check_unpredictable_dynamic = false; + const auto &updated_outputs = op.getOutputs(); + assert(inputs_info.size() == updated_outputs.size()); + for (size_t i = 0; i < updated_outputs.size(); ++i) + { + const auto &input_info = inputs_info.at(i); + const auto &output_info = graph.operands().at(updated_outputs.at(i)).info(); + if (input_info.isDynamic() != output_info.isDynamic() || + input_info.shape() != output_info.shape()) + { + check_unpredictable_dynamic = true; + break; + } + } + + if (check_unpredictable_dynamic) + { + body_input_observer->updateShapes(inputs_info, check_unpredictable_dynamic); + _child_inferers.at(op.param().body_subg_index)->infer(); + } + cond_input_observer->updateShapes(inputs_info, check_unpredictable_dynamic); + _child_inferers.at(op.param().cond_subg_index)->infer(); +} + +void StaticShapeInferer::visit(const ir::operation::DetectionPostProcess &op) +{ + // TODO: NMS supports very limited input/output size. 
+ ir::operation::DetectionPostProcess::Param param = op.param(); + + auto &operands = _lowered_subg->graph().operands(); + const int num_detected_boxes = param.max_detections * param.max_classes_per_detection; + + const auto output_idx1 = op.getOutputs().at(0); + auto &output1 = operands.at(output_idx1); + output1.info().shape({1, num_detected_boxes, 4}); + + const auto output_idx2 = op.getOutputs().at(1); + auto &output2 = operands.at(output_idx2); + output2.info().shape({1, num_detected_boxes}); + + const auto output_idx3 = op.getOutputs().at(2); + auto &output3 = operands.at(output_idx3); + output3.info().shape({1, num_detected_boxes}); + + const auto output_idx4 = op.getOutputs().at(3); + auto &output4 = operands.at(output_idx4); + output4.info().shape({1}); +} +void StaticShapeInferer::visit(const ir::operation::Bulk &op) +{ + auto &operands = _lowered_subg->graph().operands(); + + // TODO: support multiple inputs/outputs + const auto input_idx{op.getInputs().at(0)}; + const auto &input = operands.at(input_idx); + const auto output_idx = op.getOutputs().at(0); + ir::Operand &output = operands.at(output_idx); + + const auto &cur_input_shape = input.info().shape(); + auto origin_output_shape = op.param().origin_output_shapes[0]; + + // TODO: more check for valid batch request + if ((cur_input_shape.dim(0) < origin_output_shape.dim(0)) || + (cur_input_shape.dim(0) % origin_output_shape.dim(0) != 0)) + { + throw std::runtime_error("StaticShapeInferer " + op.name() + ": Not supported batch size"); + } + size_t batch_multiplier = cur_input_shape.dim(0) / origin_output_shape.dim(0); + + ir::Shape new_shape; + new_shape.append(origin_output_shape.dim(0) * batch_multiplier); + for (int32_t d = 1; d < origin_output_shape.rank(); ++d) + new_shape.append(origin_output_shape.dim(d)); + + output.info().shape(new_shape); +} + +} // namespace compiler + +} // namespace onert diff --git a/runtime/onert/core/src/compiler/TensorBuilders.h 
b/runtime/onert/core/src/compiler/TensorBuilders.h deleted file mode 100644 index 3b0360b4b..000000000 --- a/runtime/onert/core/src/compiler/TensorBuilders.h +++ /dev/null @@ -1,78 +0,0 @@ -/* - * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef __ONERT_COMPILER_TENSOR_BUILDERS_H__ -#define __ONERT_COMPILER_TENSOR_BUILDERS_H__ - -#include <unordered_set> -#include <memory> -#include "backend/BackendContext.h" -#include "backend/Backend.h" -#include "backend/controlflow/Config.h" -#include "backend/controlflow/TensorBuilder.h" -#include "util/logging.h" - -namespace onert -{ -namespace compiler -{ - -class TensorBuilders -{ -public: - TensorBuilders() = default; - - TensorBuilders(const onert::backend::BackendContexts &backend_contexts, bool include_controlflow) - { - for (const auto &e : backend_contexts) - { - if (e.first->config()->id() == backend::controlflow::Config::ID) - { - _cf_tensor_builder = std::dynamic_pointer_cast<backend::controlflow::TensorBuilder>( - e.second->tensor_builder); - if (include_controlflow) - _tensor_builders.insert(e.second->tensor_builder); - } - else - { - _tensor_builders.insert(e.second->tensor_builder); - } - } - } - - std::unordered_set<std::shared_ptr<onert::backend::ITensorBuilder>>::const_iterator begin() const - { - return _tensor_builders.cbegin(); - } - 
std::unordered_set<std::shared_ptr<onert::backend::ITensorBuilder>>::const_iterator end() const - { - return _tensor_builders.cend(); - } - - std::shared_ptr<backend::controlflow::TensorBuilder> getControlflowTensorBuilder() const - { - return _cf_tensor_builder; - } - -private: - std::unordered_set<std::shared_ptr<backend::ITensorBuilder>> _tensor_builders; - std::shared_ptr<backend::controlflow::TensorBuilder> _cf_tensor_builder; -}; - -} // namespace compiler -} // namespace onert - -#endif // __ONERT_COMPILER_TENSOR_BUILDERS_H__ diff --git a/runtime/onert/core/src/compiler/TensorRegistries.h b/runtime/onert/core/src/compiler/TensorRegistries.h index 8be87b081..4c30785df 100644 --- a/runtime/onert/core/src/compiler/TensorRegistries.h +++ b/runtime/onert/core/src/compiler/TensorRegistries.h @@ -17,13 +17,14 @@ #ifndef __ONERT_COMPILER_TENSOR_REGISTRIES_H__ #define __ONERT_COMPILER_TENSOR_REGISTRIES_H__ -#include <unordered_set> -#include <memory> -#include "backend/BackendContext.h" +#include "../backend/builtin/Config.h" +#include "../backend/builtin/TensorRegistry.h" + #include "backend/Backend.h" -#include "backend/controlflow/Config.h" -#include "backend/controlflow/TensorBuilder.h" -#include "backend/controlflow/TensorRegistry.h" +#include "backend/BackendContext.h" + +#include <memory> +#include <unordered_set> namespace onert { @@ -35,17 +36,16 @@ class TensorRegistries public: TensorRegistries() = default; - TensorRegistries(const onert::backend::BackendContexts &backend_contexts, - bool include_controlflow) + TensorRegistries(const onert::backend::BackendContexts &backend_contexts, bool include_builtin) { for (const auto &e : backend_contexts) { auto tensor_reg = e.second->tensor_registry; - if (e.first->config()->id() == backend::controlflow::Config::ID) + if (e.first->config()->id() == backend::builtin::Config::ID) { - _cf_tensor_reg = - std::dynamic_pointer_cast<backend::controlflow::TensorRegistry>(tensor_reg); - if (include_controlflow) + 
_builtin_tensor_reg = + std::dynamic_pointer_cast<backend::builtin::TensorRegistry>(tensor_reg); + if (include_builtin) _tensor_regs.insert(tensor_reg); } else @@ -64,14 +64,14 @@ public: return _tensor_regs.cend(); } - std::shared_ptr<backend::controlflow::TensorRegistry> getControlflowTensorRegistry() const + std::shared_ptr<backend::builtin::TensorRegistry> getBuiltinTensorRegistry() const { - return _cf_tensor_reg; + return _builtin_tensor_reg; } - std::shared_ptr<backend::ITensor> getITensor(ir::OperandIndex ind) const + backend::ITensor *getITensor(ir::OperandIndex ind) const { - for (auto &tensor_reg : _tensor_regs) + for (const auto &tensor_reg : _tensor_regs) { auto tensor = tensor_reg->getITensor(ind); if (tensor) @@ -82,7 +82,7 @@ public: private: std::unordered_set<std::shared_ptr<backend::ITensorRegistry>> _tensor_regs; - std::shared_ptr<backend::controlflow::TensorRegistry> _cf_tensor_reg; + std::shared_ptr<backend::builtin::TensorRegistry> _builtin_tensor_reg; }; } // namespace compiler diff --git a/runtime/onert/core/src/compiler/pass/ConstantInsertionPass.cc b/runtime/onert/core/src/compiler/pass/ConstantInsertionPass.cc index 647669e46..ac131803f 100644 --- a/runtime/onert/core/src/compiler/pass/ConstantInsertionPass.cc +++ b/runtime/onert/core/src/compiler/pass/ConstantInsertionPass.cc @@ -17,8 +17,9 @@ #include "ConstantInsertionPass.h" #include "backend/Backend.h" -#include <ir/Graph.h> -#include <util/Utils.h> +#include "ir/Graph.h" +#include "util/Utils.h" +#include "util/logging.h" namespace onert { @@ -27,39 +28,30 @@ namespace compiler namespace pass { -void ConstantInsertionPass::callback(const ir::OperationIndex &node_index, ir::Operation &node) +void ConstantInsertionPass::callback(const ir::OperationIndex &node_index, ir::IOperation &node) { - const auto &op_sequence_index = _lowered_graph.op_seqs().getOperation(node_index); - const auto op_seq_lower_info = _lowered_graph.getLowerInfo(op_sequence_index); - const auto backend = 
op_seq_lower_info->backend(); - const auto layout = op_seq_lower_info->layout(); - const auto factor = ir::operand::PermuteFactor{backend, layout}; + const auto op_lower_info = _lowered_graph.lower_info().operation.getRawPtr(node_index); + const auto backend = op_lower_info->backend(); + const auto layout = op_lower_info->layout(); + const auto factor = PermuteFactor{backend, layout}; - for (const auto input : node.getInputs() | ir::Remove::DUPLICATED | ir::Remove::UNDEFINED) + for (const auto &input : node.getInputs() | ir::Remove::DUPLICATED | ir::Remove::UNDEFINED) { auto &object = _graph.operands().at(input); - if (object.isConstant()) + const auto key = ReplaceKey{input, factor}; + if (object.isConstant() && (object.getUses().size() >= 2 || + _replace_operands_map.find(key) != _replace_operands_map.end())) { - const auto key = ReplaceKey{input, factor}; if (_replace_operands_map.count(key) == 0) { - auto new_object = object; - new_object.unsetDef(); - // TODO Remove const_case - const_cast<ir::OperationIndexSet &>(new_object.getUses()).clear(); + ir::Operand new_object(object); + new_object.clearDefUse(); const auto new_index = _graph.operands().emplace(new_object); _replace_operands_map[key] = new_index; } const auto replaced_input = _replace_operands_map[key]; - // Update op_seq - if (_lowered_graph.op_seqs().at(op_sequence_index).getInputs().contains(input)) - { - // All inputs of op_seq have the same PermuteFactor because those inputs are inputs of first - // operation - _lowered_graph.op_seqs().at(op_sequence_index).replaceInputs(input, replaced_input); - } // Update the same inputs of a node at once because inputs of an operation have the same // PermuteFactor @@ -69,6 +61,8 @@ void ConstantInsertionPass::callback(const ir::OperationIndex &node_index, ir::O auto &replaced_object = _graph.operands().at(replaced_input); replaced_object.insertUse(node_index); + VERBOSE(ConstInsertPass) << "New operand " << replaced_input << " added(copy of " << input + << 
") for " << factor << std::endl; // Remove this node from uses of origin operand // Constant operand has no def. assert(!object.getDef().valid()); @@ -76,12 +70,16 @@ void ConstantInsertionPass::callback(const ir::OperationIndex &node_index, ir::O // Remove origin operand if (object.getUses().size() == 0) + { _graph.removeOperand(input); + VERBOSE(ConstInsertPass) << "Original operand " << input << " removed - no uses" + << std::endl; + } } } // Now this runtime does not support the node making output as constant - for (const auto &output : node.getOutputs()) + for (const auto &output : node.getOutputs() | ir::Remove::DUPLICATED | ir::Remove::UNDEFINED) { UNUSED_RELEASE(output); assert(!_graph.operands().at(output).isConstant()); diff --git a/runtime/onert/core/src/compiler/pass/ConstantInsertionPass.h b/runtime/onert/core/src/compiler/pass/ConstantInsertionPass.h index 052883c92..d5b9aa14e 100644 --- a/runtime/onert/core/src/compiler/pass/ConstantInsertionPass.h +++ b/runtime/onert/core/src/compiler/pass/ConstantInsertionPass.h @@ -17,7 +17,7 @@ #ifndef __ONERT_COMPILER_PASS_CONSTANT_INSERTION_PASS_H__ #define __ONERT_COMPILER_PASS_CONSTANT_INSERTION_PASS_H__ -#include <ir/operand/PermuteFactor.h> +#include <compiler/PermuteFactor.h> #include <ir/Index.h> #include "LoweredOperationPass.h" #include <unordered_map> @@ -39,13 +39,13 @@ public: std::string id() final { return "ConstantInsertionPass"; } public: - void callback(const ir::OperationIndex &index, ir::Operation &node) final; + void callback(const ir::OperationIndex &index, ir::IOperation &node) final; private: struct ReplaceKey { ir::OperandIndex index; - ir::operand::PermuteFactor factor; + PermuteFactor factor; bool operator==(const ReplaceKey &other) const { @@ -61,8 +61,7 @@ private: std::size_t operator()(const ReplaceKey &key) const noexcept { using std::hash; - return hash<ir::OperandIndex>()(key.index) ^ - (hash<ir::operand::PermuteFactor>()(key.factor) << 1); + return 
hash<ir::OperandIndex>()(key.index) ^ (hash<PermuteFactor>()(key.factor) << 1); } }; diff --git a/runtime/onert/core/src/compiler/pass/ConstantLoweringPass.cc b/runtime/onert/core/src/compiler/pass/ConstantLoweringPass.cc index 1c1dbe0ee..32e32d0ef 100644 --- a/runtime/onert/core/src/compiler/pass/ConstantLoweringPass.cc +++ b/runtime/onert/core/src/compiler/pass/ConstantLoweringPass.cc @@ -18,8 +18,9 @@ #include "backend/Backend.h" #include <ir/Graph.h> -#include <ir/operand/PermuteFactor.h> +#include <compiler/PermuteFactor.h> #include <util/Utils.h> +#include "util/logging.h" namespace onert { @@ -28,25 +29,25 @@ namespace compiler namespace pass { -void ConstantLoweringPass::callback(const ir::OperationIndex &node_index, ir::Operation &node) +void ConstantLoweringPass::callback(const ir::OperationIndex &node_index, ir::IOperation &node) { - const auto &op_sequence_index = _lowered_graph.op_seqs().getOperation(node_index); - const auto op_seq_lower_info = _lowered_graph.getLowerInfo(op_sequence_index); - const auto backend = op_seq_lower_info->backend(); - const auto layout = op_seq_lower_info->layout(); - const auto factor = ir::operand::PermuteFactor{backend, layout}; + const auto op_lower_info = _lowered_graph.lower_info().operation.getRawPtr(node_index); + const auto backend = op_lower_info->backend(); + const auto layout = op_lower_info->layout(); + const auto factor = PermuteFactor{backend, layout}; // Now this runtime does not support the node making output of operation as constant - for (const auto input : node.getInputs() | ir::Remove::DUPLICATED | ir::Remove::UNDEFINED) + for (const auto &input : node.getInputs() | ir::Remove::DUPLICATED | ir::Remove::UNDEFINED) { auto &object = _graph.operands().at(input); if (object.isConstant()) { // All constant operand are already assinged at each backend by ContantInsertionPass. 
So a // constant has `def` and `use` as the same PermuteFactor - _lowered_graph.setLowerInfo(input, std::make_unique<ir::operand::LowerInfo>()); - _lowered_graph.getLowerInfo(input)->addDefPermuteFactor(factor); - _lowered_graph.getLowerInfo(input)->addUsePermuteFactor(factor); + auto operand_li = std::make_unique<compiler::OperandLowerInfo>(); + operand_li->addDefPermuteFactor(factor); + operand_li->addUsePermuteFactor(factor); + _lowered_graph.lower_info().operand.set(input, std::move(operand_li)); } } } diff --git a/runtime/onert/core/src/compiler/pass/ConstantLoweringPass.h b/runtime/onert/core/src/compiler/pass/ConstantLoweringPass.h index e17d776d1..d60a1033f 100644 --- a/runtime/onert/core/src/compiler/pass/ConstantLoweringPass.h +++ b/runtime/onert/core/src/compiler/pass/ConstantLoweringPass.h @@ -36,7 +36,7 @@ public: std::string id() final { return "ConstantLoweringPass"; } public: - void callback(const ir::OperationIndex &index, ir::Operation &node) final; + void callback(const ir::OperationIndex &index, ir::IOperation &node) final; }; } // namespace pass diff --git a/runtime/onert/core/src/compiler/pass/ConstantOutputPass.cc b/runtime/onert/core/src/compiler/pass/ConstantOutputPass.cc new file mode 100644 index 000000000..1448de473 --- /dev/null +++ b/runtime/onert/core/src/compiler/pass/ConstantOutputPass.cc @@ -0,0 +1,68 @@ +/* + * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConstantOutputPass.h" + +#include "ir/Graph.h" +#include "ir/operation/Permute.h" +#include "util/logging.h" + +namespace onert +{ +namespace compiler +{ +namespace pass +{ + +void ConstantOutputPass::callback(const ir::OperandIndex &ind, ir::Operand &obj) +{ + if (!_graph.getOutputs().contains(ind) || !obj.isConstant()) + return; + + auto permute_input_ind = _graph.addOperand(obj.shape(), obj.typeInfo()); + auto &permute_input_obj = _graph.operands().at(permute_input_ind); + + // Move the const data + permute_input_obj.data(obj.shareData()); + obj.releaseData(); + obj.info().setAsNonConst(); + + using ir::operation::Permute; + auto permute_obj = std::make_unique<Permute>(permute_input_ind, ind, Permute::Type::COPY); + auto permute_ind = _graph.operations().push(std::move(permute_obj)); + + permute_input_obj.insertUse(permute_ind); + obj.setDef(permute_ind); + + // Make the operations that uses this operand to use the generated operand + auto orig_uses = obj.getUses(); + for (auto &&use : orig_uses) + { + permute_input_obj.insertUse(use); + obj.removeUse(use); + _graph.operations().at(use).replaceInputs(ind, permute_input_ind); + } + + VERBOSE(ConstantOutputPass) << "Permute Op inserted for a constant ouput, node index : " + << permute_ind << std::endl; + VERBOSE(ConstantOutputPass) << " - Input (inserted) Operand : " << permute_input_ind + << std::endl; + VERBOSE(ConstantOutputPass) << " - Output(original) Operand : " << ind << std::endl; +} + +} // namespace pass +} // namespace compiler +} // namespace onert diff --git a/runtime/onert/core/src/compiler/pass/ConstantOutputPass.h b/runtime/onert/core/src/compiler/pass/ConstantOutputPass.h new file mode 100644 index 000000000..193dd3a68 --- /dev/null +++ b/runtime/onert/core/src/compiler/pass/ConstantOutputPass.h @@ -0,0 +1,63 @@ +/* + * Copyright (c) 2020 Samsung Electronics Co., Ltd. 
All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __ONERT_COMPILER_PASS_CONSTANT_OUTPUT_PASS_H__ +#define __ONERT_COMPILER_PASS_CONSTANT_OUTPUT_PASS_H__ + +#include "OperandPass.h" + +namespace onert +{ +namespace compiler +{ +namespace pass +{ + +/** + * @brief Pass to specially handle constant model outputs + * + * As an output buffer is given right before an execution but constant initialization is done at + * prepare phase, the current runtime structure cannot handle when an output is constant. + * To resolve this problem, this pass inserts a Permute layer with a const input and make the model + * output tensor to be its output. + * + * e.g.) + * + * ((Const Output)) + * + * becomes + * + * (Const) -> [Permute] -> ((Output)) + * + * Note that this is a mandatory pass for Graph. 
+ */ +class ConstantOutputPass : public OperandPass +{ +public: + using OperandPass::OperandPass; + +public: + std::string id() final { return "ConstantOutputPass"; } + +public: + void callback(const ir::OperandIndex &i, ir::Operand &o) final; +}; + +} // namespace pass +} // namespace compiler +} // namespace onert + +#endif // __ONERT_COMPILER_PASS_CONSTANT_INSERTION_PASS_H__ diff --git a/runtime/onert/core/src/compiler/pass/IPass.h b/runtime/onert/core/src/compiler/pass/IPass.h new file mode 100644 index 000000000..77f5916fd --- /dev/null +++ b/runtime/onert/core/src/compiler/pass/IPass.h @@ -0,0 +1,41 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef __ONERT_COMPILER_PASS_IPASS_H__ +#define __ONERT_COMPILER_PASS_IPASS_H__ + +#include <string> + +namespace onert +{ +namespace compiler +{ +namespace pass +{ + +struct IPass +{ + virtual ~IPass() = default; + + virtual std::string id() = 0; + virtual void run() = 0; +}; + +} // namespace pass +} // namespace compiler +} // namespace onert + +#endif // __ONERT_COMPILER_PASS_IPASS_H__ diff --git a/runtime/onert/core/src/compiler/pass/LoweredOperandPass.h b/runtime/onert/core/src/compiler/pass/LoweredOperandPass.h index 0c5f7d745..64831a0ac 100644 --- a/runtime/onert/core/src/compiler/pass/LoweredOperandPass.h +++ b/runtime/onert/core/src/compiler/pass/LoweredOperandPass.h @@ -18,7 +18,7 @@ #define __ONERT_IR_PASS_LOWERED_OPERAND_PASS_H__ #include "OperandPass.h" -#include "compiler/LoweredGraph.h" +#include "compiler/ILoweredGraph.h" namespace onert { @@ -30,8 +30,8 @@ namespace pass class LoweredOperandPass : public OperandPass { public: - LoweredOperandPass(compiler::LoweredGraph &lowered_graph) - : OperandPass{lowered_graph.graph()}, _lowered_graph{lowered_graph} + LoweredOperandPass(compiler::ILoweredGraph &lowered_graph) + : OperandPass{lowered_graph.graph()}, _lowered_graph{lowered_graph} { // DO NOTHING } @@ -42,7 +42,7 @@ public: void callback(const ir::OperandIndex &i, ir::Operand &o) override = 0; protected: - compiler::LoweredGraph &_lowered_graph; + compiler::ILoweredGraph &_lowered_graph; }; } // namespace pass diff --git a/runtime/onert/core/src/compiler/pass/LoweredOperationPass.h b/runtime/onert/core/src/compiler/pass/LoweredOperationPass.h index 5c8569be2..27ca77c91 100644 --- a/runtime/onert/core/src/compiler/pass/LoweredOperationPass.h +++ b/runtime/onert/core/src/compiler/pass/LoweredOperationPass.h @@ -18,7 +18,7 @@ #define __ONERT_IR_PASS_LOWERED_OPERATION_PASS_H__ #include "OperationPass.h" -#include "compiler/LoweredGraph.h" +#include "compiler/ILoweredGraph.h" namespace onert { @@ -30,8 +30,8 @@ namespace pass class 
LoweredOperationPass : public OperationPass { public: - LoweredOperationPass(LoweredGraph &lowered_graph) - : OperationPass{lowered_graph.graph()}, _lowered_graph{lowered_graph} + LoweredOperationPass(ILoweredGraph &lowered_graph) + : OperationPass{lowered_graph.graph()}, _lowered_graph{lowered_graph} { // DO NOTHING } @@ -39,10 +39,10 @@ public: virtual ~LoweredOperationPass() = default; std::string id() override = 0; - void callback(const ir::OperationIndex &i, ir::Operation &o) override = 0; + void callback(const ir::OperationIndex &i, ir::IOperation &o) override = 0; protected: - LoweredGraph &_lowered_graph; + ILoweredGraph &_lowered_graph; }; } // namespace pass diff --git a/runtime/onert/core/src/compiler/pass/OddOutputPass.cc b/runtime/onert/core/src/compiler/pass/OddOutputPass.cc new file mode 100644 index 000000000..e2b3f6111 --- /dev/null +++ b/runtime/onert/core/src/compiler/pass/OddOutputPass.cc @@ -0,0 +1,90 @@ +/* + * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "OddOutputPass.h" + +#include "ir/Graph.h" +#include "ir/operation/Permute.h" +#include "util/logging.h" +#include "util/Utils.h" + +namespace onert +{ +namespace compiler +{ +namespace pass +{ + +void OddOutputPass::run() +{ + auto &outputs = _graph.getOutputs(); + + VERBOSE(OddOutputPass) << "Case 1 : An operand which is a model output and a model input" + << std::endl; + for (const auto &ind : outputs) + { + if (_graph.getInputs().contains(ind)) + { + auto permute_output_ind = insertPermute(ind); + // Update the output to be newly added operand + _graph.getOutputs().replace(ind, permute_output_ind); + } + } + + VERBOSE(OddOutputPass) << "Case 2 : Two or more duplicated outputs" << std::endl; + std::unordered_set<ir::OperandIndex> occurence; + for (auto &&ind : outputs) + { + auto &obj = _graph.operands().at(ind); + if (occurence.count(ind) == 0) + { + occurence.insert(ind); + continue; + } + + // Panic when it is const, it must have been handled earlier in another pass + UNUSED_RELEASE(obj); + assert(!obj.isConstant()); + + auto permute_output_ind = insertPermute(ind); + ind = permute_output_ind; // Replace output index to fix output duplication + } +} + +ir::OperandIndex OddOutputPass::insertPermute(ir::OperandIndex ind) +{ + auto &obj = _graph.operands().at(ind); + auto output_ind = _graph.addOperand(obj.shape(), obj.typeInfo()); + auto &output_obj = _graph.operands().at(output_ind); + + using ir::operation::Permute; + auto permute_obj = std::make_unique<Permute>(ind, output_ind, Permute::Type::COPY); + auto permute_ind = _graph.operations().push(std::move(permute_obj)); + + output_obj.setDef(permute_ind); + obj.insertUse(permute_ind); + + VERBOSE(OddOutputPass) << "Permute Op inserted for a constant output, node index : " + << permute_ind << std::endl; + VERBOSE(OddOutputPass) << " - Input (original) Operand : " << ind << std::endl; + VERBOSE(OddOutputPass) << " - Output(inserted) Operand : " << output_ind << std::endl; + + return output_ind; 
+} + +} // namespace pass +} // namespace compiler +} // namespace onert diff --git a/runtime/onert/core/src/compiler/pass/OddOutputPass.h b/runtime/onert/core/src/compiler/pass/OddOutputPass.h new file mode 100644 index 000000000..2accbac60 --- /dev/null +++ b/runtime/onert/core/src/compiler/pass/OddOutputPass.h @@ -0,0 +1,89 @@ +/* + * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __ONERT_COMPILER_PASS_ODD_OUTPUT_PASS_H__ +#define __ONERT_COMPILER_PASS_ODD_OUTPUT_PASS_H__ + +#include <unordered_set> + +#include "Pass.h" +#include "ir/Index.h" + +namespace onert +{ +namespace compiler +{ +namespace pass +{ + +/** + * @brief Pass to specially handle odd outputs in a subgraph + * + * Runtime Graph IR requires every input or output must have distinct tensor index, this is onert's + * restriction. However we allow duplication of indices in the models(or API). So we should + * transform the graph after model-loading. + * + * This is necessary since our API lets users to set different buffers for each input and output so + * it is unavoidable that we must copy the value at runtime. + * + * Note that this is a mandatory pass for Graph. + * + * Case 1 : An operand which is a model output and a model input + * + * Create an operand and insert a Permute(copy) op between them. And change the output to be the + * newly generated operand. + * + * e.g.) 
+ * + * ``` + * ((#0 Input0 and also Output0)) + * becomes + * ((#0 Input0)) -> [#0 Permute] -> ((#1 Output0)) + * ``` + * + * Case 2 : Two or more duplicated outputs + * + * Do the same with Case 1, but between two outputs of the same tensor index. + * + * e.g.) + * + * ``` + * ((#0 Input0)) -> [#0 Some Operation] -> ((#1 Output0 and also Output1)) + * becomes + * ((#0 Input0)) -> [#0 Some Operation] -> ((#1 Output0)) [#1 Permute] -> ((#2 Output1)) + * ``` + * + */ +class OddOutputPass : public Pass +{ +public: + using Pass::Pass; + +public: + std::string id() final { return "OddOutputPass"; } + +public: + void run() override; + +private: + ir::OperandIndex insertPermute(ir::OperandIndex input); +}; + +} // namespace pass +} // namespace compiler +} // namespace onert + +#endif // __ONERT_COMPILER_PASS_ODD_OUTPUT_PASS_H__ diff --git a/runtime/onert/core/src/compiler/pass/OperandPass.cc b/runtime/onert/core/src/compiler/pass/OperandPass.cc index 50c001c30..db8ebedcd 100644 --- a/runtime/onert/core/src/compiler/pass/OperandPass.cc +++ b/runtime/onert/core/src/compiler/pass/OperandPass.cc @@ -28,7 +28,7 @@ namespace pass void OperandPass::run() { _graph.operands().iterate( - [&](const ir::OperandIndex &index, ir::Operand &object) { callback(index, object); }); + [&](const ir::OperandIndex &index, ir::Operand &object) { callback(index, object); }); } } // namespace pass diff --git a/runtime/onert/core/src/compiler/pass/OperationPass.cc b/runtime/onert/core/src/compiler/pass/OperationPass.cc index d7a55cb22..bd9bcb4a4 100644 --- a/runtime/onert/core/src/compiler/pass/OperationPass.cc +++ b/runtime/onert/core/src/compiler/pass/OperationPass.cc @@ -17,7 +17,7 @@ #include "OperationPass.h" #include "ir/Index.h" -#include "ir/Operation.h" +#include "ir/IOperation.h" #include "ir/Graph.h" namespace onert @@ -30,7 +30,7 @@ namespace pass void OperationPass::run() { _graph.operations().iterate( - [&](const ir::OperationIndex &index, ir::Operation &node) { callback(index, 
node); }); + [&](const ir::OperationIndex &index, ir::IOperation &node) { callback(index, node); }); } } // namespace pass diff --git a/runtime/onert/core/src/compiler/pass/OperationPass.h b/runtime/onert/core/src/compiler/pass/OperationPass.h index ac4d818a2..0a00b11d1 100644 --- a/runtime/onert/core/src/compiler/pass/OperationPass.h +++ b/runtime/onert/core/src/compiler/pass/OperationPass.h @@ -29,7 +29,7 @@ namespace onert { namespace ir { -class Operation; +struct IOperation; } // namespace ir } // namespace onert @@ -62,7 +62,7 @@ public: * @param index is the index of a node in graph * @param node is the node in graph */ - virtual void callback(const ir::OperationIndex &index, ir::Operation &node) = 0; + virtual void callback(const ir::OperationIndex &index, ir::IOperation &node) = 0; /** * @brief Run the pass diff --git a/runtime/onert/core/src/compiler/pass/Pass.h b/runtime/onert/core/src/compiler/pass/Pass.h index 3f356c337..b34695c97 100644 --- a/runtime/onert/core/src/compiler/pass/Pass.h +++ b/runtime/onert/core/src/compiler/pass/Pass.h @@ -17,6 +17,8 @@ #ifndef __ONERT_COMPILER_PASS_PASS_H__ #define __ONERT_COMPILER_PASS_PASS_H__ +#include "IPass.h" + #include <string> namespace onert @@ -24,7 +26,7 @@ namespace onert namespace ir { class Graph; -} // namespace compiler +} // namespace ir } // namespace onert namespace onert @@ -34,7 +36,7 @@ namespace compiler namespace pass { -class Pass +class Pass : public IPass { public: Pass(ir::Graph &graph) : _graph{graph} {} diff --git a/runtime/onert/core/src/compiler/pass/PassRunner.cc b/runtime/onert/core/src/compiler/pass/PassRunner.cc new file mode 100644 index 000000000..cd1b82bb2 --- /dev/null +++ b/runtime/onert/core/src/compiler/pass/PassRunner.cc @@ -0,0 +1,45 @@ +/* + * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "PassRunner.h" + +namespace onert +{ +namespace compiler +{ +namespace pass +{ + +PassRunner &PassRunner::append(std::unique_ptr<IPass> pass) +{ + _passes.emplace_back(std::move(pass)); + return *this; +} + +void PassRunner::run() +{ + for (auto &&pass : _passes) + { + VERBOSE(PassRunner) << "Start running '" << pass->id() << "'" << std::endl; + pass->run(); + VERBOSE(PassRunner) << "Finished running '" << pass->id() << "'" << std::endl; + // TODO Dump graph? + } +} + +} // namespace pass +} // namespace compiler +} // namespace onert diff --git a/runtime/onert/core/src/compiler/pass/PassRunner.h b/runtime/onert/core/src/compiler/pass/PassRunner.h new file mode 100644 index 000000000..03bfbe220 --- /dev/null +++ b/runtime/onert/core/src/compiler/pass/PassRunner.h @@ -0,0 +1,53 @@ +/* + * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef __ONERT_COMPILER_PASS_PASS_RUNNER_H__ +#define __ONERT_COMPILER_PASS_PASS_RUNNER_H__ + +#include <initializer_list> +#include <memory> +#include <vector> + +#include "IPass.h" +#include "util/logging.h" + +namespace onert +{ +namespace compiler +{ +namespace pass +{ + +/** + * @brief Composite passes with logging + */ +class PassRunner +{ +public: + PassRunner() = default; + PassRunner &append(std::unique_ptr<IPass> pass); + + void run(); + +private: + std::vector<std::unique_ptr<IPass>> _passes; +}; + +} // namespace pass +} // namespace compiler +} // namespace onert + +#endif // __ONERT_COMPILER_PASS_PASS_RUNNER_H__ diff --git a/runtime/onert/core/src/compiler/pass/PermutationEliminationPass.cc b/runtime/onert/core/src/compiler/pass/PermutationEliminationPass.cc index f01697034..d9452c7f9 100644 --- a/runtime/onert/core/src/compiler/pass/PermutationEliminationPass.cc +++ b/runtime/onert/core/src/compiler/pass/PermutationEliminationPass.cc @@ -15,8 +15,8 @@ */ #include "PermutationEliminationPass.h" -#include "backend/controlflow/Config.h" +#include "backend/Backend.h" #include "util/logging.h" namespace onert @@ -26,7 +26,7 @@ namespace compiler namespace pass { -void PermutationEliminationPass::callback(const ir::OperationIndex &ind, ir::Operation &node) +void PermutationEliminationPass::callback(const ir::OperationIndex &ind, ir::IOperation &node) { _op_ind = ind; node.accept(*this); @@ -39,8 +39,9 @@ void PermutationEliminationPass::visit(const ir::operation::Permute &node) // Check if two tensors are both portable if not, we can't eliminate the node { - auto in_def_factor = _lowered_graph.getLowerInfo(in_operand)->def_factors().getOnlyElement(); - auto out_def_factor = _lowered_graph.getLowerInfo(out_operand)->def_factors().getOnlyElement(); + auto &operand_li_map = _lowered_graph.lower_info().operand; + auto in_def_factor = operand_li_map.getRawPtr(in_operand)->def_factors().getOnlyElement(); + auto out_def_factor = 
operand_li_map.getRawPtr(out_operand)->def_factors().getOnlyElement(); auto in_config = in_def_factor.backend()->config(); auto out_config = out_def_factor.backend()->config(); @@ -53,59 +54,50 @@ void PermutationEliminationPass::visit(const ir::operation::Permute &node) if (_graph.getOutputs().contains(out_operand)) { + // If the input is a const, we cannot remove it since we cannot put the constant data in the + // output buffer during prepare phase. + auto permute_input = node.getInputs().at(0); + if (_graph.operands().at(permute_input).isConstant()) + return; + // If the input is a model input, we cannot remove it since our API lets users to set different + // buffers for inputs and outputs even though one tensor is both at the same time. + auto permute_output = node.getOutputs().at(0); + if (_graph.getInputs().contains(permute_input) && _graph.getOutputs().contains(permute_output)) + return; + // Likewise, if copying between outputs to outputs, keep it. + if (_graph.getOutputs().contains(permute_input) && _graph.getOutputs().contains(permute_output)) + return; + // Exceptional case : When the output operand is a model output // In this case we keep the output and remove the input auto &out_operand_obj = _graph.operands().at(out_operand); assert(out_operand_obj.getDef() == _op_ind); out_operand_obj.unsetDef(); - _lowered_graph.op_seqs().iterate([&](const ir::OpSequenceIndex &, ir::OpSequence &op_seq) { - if (!op_seq.getOutputs().contains(in_operand)) + _graph.operations().iterate([&](const ir::OperationIndex &op_ind, ir::IOperation &op) { + if (!op.getOutputs().contains(in_operand)) return; - - // Update OpSequence/ir::Operation edges and ir::Operand edges - op_seq.replaceOutputs(in_operand, out_operand); - for (auto op : op_seq.operations()) - { - auto &operation_obj = _graph.operations().at(op); - if (operation_obj.getOutputs().contains(in_operand)) - { - operation_obj.replaceOutputs(in_operand, out_operand); - out_operand_obj.setDef(op); - } - } + // Update 
Operation and Operand edges + op.replaceOutputs(in_operand, out_operand); + out_operand_obj.setDef(op_ind); }); - // Remove Permute operation, enclosing OpSequence and the operand + // Remove Permute operation and the operand { _graph.removeOperand(in_operand); - - auto op_seq_ind = _lowered_graph.op_seqs().getOperation(_op_ind); - // Assumes enclosing OpSequence contatins just this Permute operation - assert(_lowered_graph.op_seqs().at(op_seq_ind).size() == 1); - _lowered_graph.op_seqs().remove(op_seq_ind); _graph.operations().remove(_op_ind); } - _lowered_graph.op_seqs().iterate([&](const ir::OpSequenceIndex &, ir::OpSequence &op_seq) { - if (!op_seq.getInputs().contains(in_operand)) + _graph.operations().iterate([&](const ir::OperationIndex &op_ind, ir::IOperation &op) { + if (!op.getInputs().contains(in_operand)) return; - - op_seq.replaceInputs(in_operand, out_operand); - for (auto op : op_seq.operations()) - { - auto &operation_obj = _graph.operations().at(op); - if (operation_obj.getInputs().contains(in_operand)) - { - operation_obj.replaceInputs(in_operand, out_operand); - out_operand_obj.insertUse(op); - } - } + op.replaceInputs(in_operand, out_operand); + out_operand_obj.insertUse(op_ind); }); VERBOSE(removePermute) << "Permute Op removed, node index : " << _op_ind << std::endl; - VERBOSE(removePermute) << " - Input (removed) ir::Operand : " << in_operand << std::endl; - VERBOSE(removePermute) << " - Output(kept) ir::Operand : " << out_operand << std::endl; + VERBOSE(removePermute) << " - Input (removed) Operand : " << in_operand << std::endl; + VERBOSE(removePermute) << " - Output(kept) Operand : " << out_operand << std::endl; } else { @@ -114,37 +106,23 @@ void PermutationEliminationPass::visit(const ir::operation::Permute &node) auto &in_operand_obj = _graph.operands().at(in_operand); in_operand_obj.removeUse(_op_ind); - // Make OpSequences(that use the output) use the input - _lowered_graph.op_seqs().iterate([&](const ir::OpSequenceIndex &, 
ir::OpSequence &op_seq) { - if (!op_seq.getInputs().contains(out_operand)) + // Make operations(that use the output) use the input + _graph.operations().iterate([&](const ir::OperationIndex &op_ind, ir::IOperation &op) { + if (!op.getInputs().contains(out_operand)) return; - - op_seq.replaceInputs(out_operand, in_operand); - for (auto op : op_seq.operations()) - { - auto &operation_obj = _graph.operations().at(op); - if (operation_obj.getInputs().contains(out_operand)) - { - operation_obj.replaceInputs(out_operand, in_operand); - in_operand_obj.insertUse(op); - } - } + op.replaceInputs(out_operand, in_operand); + in_operand_obj.insertUse(op_ind); }); - // Remove Permute operation, enclosing OpSequence and the operand + // Remove the Permute operation and out_operand { _graph.removeOperand(out_operand); - - auto op_seq_ind = _lowered_graph.op_seqs().getOperation(_op_ind); - // Assumes enclosing OpSequence contatins just this Permute operation - assert(_lowered_graph.op_seqs().at(op_seq_ind).size() == 1); - _lowered_graph.op_seqs().remove(op_seq_ind); _graph.operations().remove(_op_ind); } - VERBOSE(removePermute) << "Permute Op removed, node index : " << _op_ind << std::endl; - VERBOSE(removePermute) << " - Input (kept) ir::Operand : " << in_operand << std::endl; - VERBOSE(removePermute) << " - Output(removed) ir::Operand : " << out_operand << std::endl; + VERBOSE(removePermute) << "Permute Op removed : " << _op_ind << std::endl; + VERBOSE(removePermute) << " - Input (kept) Operand : " << in_operand << std::endl; + VERBOSE(removePermute) << " - Output(removed) Operand : " << out_operand << std::endl; } } diff --git a/runtime/onert/core/src/compiler/pass/PermutationEliminationPass.h b/runtime/onert/core/src/compiler/pass/PermutationEliminationPass.h index 29daf1a82..18ba99804 100644 --- a/runtime/onert/core/src/compiler/pass/PermutationEliminationPass.h +++ b/runtime/onert/core/src/compiler/pass/PermutationEliminationPass.h @@ -35,7 +35,7 @@ namespace pass * are 
compatible and layouts match. * * Permute input tensor is kept and the output is removed for all the cases, except model outputs. - * As all output tensors have to be controlflow backend, so the output is kept. + * As all output tensors have to be builtin backend, so the output is kept. * * @note This is an optimization pass which means that everything should work fine even if this pass * was skipped. @@ -49,7 +49,7 @@ public: std::string id() final { return "PermutationEliminationPass"; } public: - void callback(const ir::OperationIndex &i, ir::Operation &n) final; + void callback(const ir::OperationIndex &i, ir::IOperation &n) final; private: void visit(const ir::operation::Permute &) final; diff --git a/runtime/onert/core/src/compiler/pass/PermutationInsertionPass.cc b/runtime/onert/core/src/compiler/pass/PermutationInsertionPass.cc index c83a72ada..f5ad7e636 100644 --- a/runtime/onert/core/src/compiler/pass/PermutationInsertionPass.cc +++ b/runtime/onert/core/src/compiler/pass/PermutationInsertionPass.cc @@ -9,6 +9,7 @@ * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
@@ -16,18 +17,16 @@ #include "PermutationInsertionPass.h" -#include <cassert> -#include <utility> -#include <unordered_map> +#include "../../backend/builtin/Config.h" -#include "backend/controlflow/Config.h" -#include "ir/Operand.h" -#include "ir/operation/LowerInfo.h" -#include "ir/Graph.h" -#include "backend/IConfig.h" +#include "compiler/OperationLowerInfo.h" +#include "ir/operation/Permute.h" #include "util/logging.h" + +#include <cassert> #include <memory> -#include "ir/operation/Permute.h" +#include <unordered_map> +#include <utility> namespace onert { @@ -38,7 +37,8 @@ namespace pass void PermutationInsertionPass::callback(const ir::OperandIndex &index, ir::Operand &object) { - auto &&operand_li = _lowered_graph.getLowerInfo(index); + auto &operand_li_map = _lowered_graph.lower_info().operand; + auto &&operand_li = operand_li_map.getRawPtr(index); assert(operand_li); // NOTE Later, constants also will have Def @@ -51,16 +51,16 @@ void PermutationInsertionPass::callback(const ir::OperandIndex &index, ir::Opera std::list<ir::OperationIndex> permute_indexes; // Build a map for all necessary type of operands - std::unordered_map<ir::operand::PermuteFactor, ir::OperandIndex> factor_to_index; + std::unordered_map<PermuteFactor, ir::OperandIndex> factor_to_index; { assert(operand_li->def_factors().size() == 1); - for (auto factor : operand_li->def_factors()) + for (auto &&factor : operand_li->def_factors()) { factor_to_index.emplace(factor, index); } auto insert_set = operand_li->use_factors() - operand_li->def_factors(); - for (auto factor : insert_set) + for (auto &&factor : insert_set) { const auto permute_operation_index = insertPermute(index, factor); permute_indexes.push_back(permute_operation_index); @@ -75,33 +75,23 @@ void PermutationInsertionPass::callback(const ir::OperandIndex &index, ir::Opera std::list<ir::OperationIndex> remove_list; auto uses = object.getUses(); - for (auto use : uses) + for (auto &&use : uses) { // If permute operation, ignore it 
if (std::find(permute_indexes.begin(), permute_indexes.end(), use) != permute_indexes.end()) continue; auto &operation = _graph.operations().at(use); - assert(_lowered_graph.op_seqs().containsOperation(use)); - auto op_seq_index = _lowered_graph.op_seqs().getOperation(use); - auto op_seq_li = _lowered_graph.getLowerInfo(op_seq_index); - assert(op_seq_li); - const auto op_seq_layout = op_seq_li->layout(); - const backend::Backend *backend = op_seq_li->backend(); + auto op_li = _lowered_graph.lower_info().operation.getRawPtr(use); + assert(op_li); + const auto op_layout = op_li->layout(); + const backend::Backend *backend = op_li->backend(); assert(backend); - auto use_node_inputs = operation.getInputs(); - assert(use_node_inputs.contains(index)); + assert(operation.getInputs().contains(index)); - auto new_index = factor_to_index.at({backend, op_seq_layout}); + auto new_index = factor_to_index.at({backend, op_layout}); if (index != new_index) { - // Update from op_seq - // Replace the same inputs of an OpSequence at once for the following reasons: - // 1. An OpSequence's inputs are the same inputs of first operation - // 2. An OpSequence may have inputs as the same operand (2 or more). - // 3. The same inputs of OpSequence have the same PermuteFactor. - _lowered_graph.op_seqs().at(op_seq_index).replaceInputs(index, new_index); - // Update from operation // Replace the same inputs of an operation at once for the following reasons: // No. 
2 and 3 above @@ -109,63 +99,69 @@ void PermutationInsertionPass::callback(const ir::OperandIndex &index, ir::Opera // Update from operand remove_list.push_back( - use); // Removal should be done in another loop since we are in the loop + use); // Removal should be done in another loop since we are in the loop _graph.operands().at(new_index).insertUse(use); } } - for (auto &operation : remove_list) + for (const auto &operation_index : remove_list) { - object.removeUse(operation); + object.removeUse(operation_index); } } } ir::OperationIndex PermutationInsertionPass::insertPermute(const ir::OperandIndex &operand_index, - const ir::operand::PermuteFactor &factor) + const PermuteFactor &factor) { - assert(!_graph.isBuildingPhase()); - auto &operand = _graph.operands().at(operand_index); // Generate output operand and permute operation auto out_operand_index = _graph.addOperand(operand.shape(), operand.typeInfo()); - // change model output if operand_index is model output index + // change model output if operand_index is model output index and the out operand is builtin + // backend auto &model_outputs = _graph.getOutputs(); - if (model_outputs.contains(operand_index)) + const backend::Backend *builtin_backend = compiler::BackendManager::get().getBuiltin(); + assert(builtin_backend->config()->id() == onert::backend::builtin::Config::ID); + + if (model_outputs.contains(operand_index) && factor.backend() == builtin_backend) { model_outputs.replace(operand_index, out_operand_index); } + auto &operand_li_map = _lowered_graph.lower_info().operand; + // Find Permute information - auto input_factor = _lowered_graph.getLowerInfo(operand_index)->def_factors().getOnlyElement(); + auto input_factor = operand_li_map.getRawPtr(operand_index)->def_factors().getOnlyElement(); auto input_backend = input_factor.backend(); auto output_backend = factor.backend(); // NOTE Permute may not have specific layout because the layout of input and output may be // different. 
const auto permute_node_layout = ir::Layout::UNKNOWN; // NOTE If one backend supports several layout, the backend must support Permute operation - const backend::Backend *permute_node_backend = compiler::BackendManager::get().getControlflow(); + const backend::Backend *permute_node_backend = compiler::BackendManager::get().getBuiltin(); + assert(permute_node_backend->config()->id() == onert::backend::builtin::Config::ID); + if (input_backend == output_backend) { permute_node_backend = input_backend; } - const ir::operand::PermuteFactor permute_node_factor{permute_node_backend, permute_node_layout}; + const PermuteFactor permute_node_factor{permute_node_backend, permute_node_layout}; // Update LowerInfo of input operand - auto operand_lower_info = _lowered_graph.getLowerInfo(operand_index); + auto operand_lower_info = operand_li_map.getRawPtr(operand_index); operand_lower_info->removeUsePermuteFactor(factor); operand_lower_info->addUsePermuteFactor(permute_node_factor); // Update LowerInfo of output operand - auto out_operand_li = std::make_unique<ir::operand::LowerInfo>(); + auto out_operand_li = std::make_unique<compiler::OperandLowerInfo>(); // The input and output factors of all nodes will be the same except Permute. So Tensor's // allocators allocates memory using only the information of def permutation factor now. 
// TODO Change param to permute_node_factor out_operand_li->addDefPermuteFactor(factor); out_operand_li->addUsePermuteFactor(factor); - _lowered_graph.setLowerInfo(out_operand_index, std::move(out_operand_li)); + operand_li_map.set(out_operand_index, std::move(out_operand_li)); // Insert permute operation to the graph const auto input_layout = input_factor.layout(); @@ -188,20 +184,18 @@ ir::OperationIndex PermutationInsertionPass::insertPermute(const ir::OperandInde auto insert_node = std::make_unique<Permute>(operand_index, out_operand_index, permute_type); auto node_index = _graph.operations().push(std::move(insert_node)); - const auto &node = _graph.operations().at(node_index); VERBOSE_F() << "Permute Op inserted, node index : " << node_index << std::endl; - VERBOSE_F() << " - Input (original) Operand : " << operand_index << std::endl; - VERBOSE_F() << " - Output(inserted) Operand : " << out_operand_index << std::endl; + VERBOSE_F() << " - Input (original) Operand : " << operand_index << "(" + << input_factor.backend()->config()->id() << ")" << std::endl; + VERBOSE_F() << " - Output(inserted) Operand : " << out_operand_index << "(" + << factor.backend()->config()->id() << ")" << std::endl; - // OpSequence + // Operation LowerInfo { - auto op_seq_index = _lowered_graph.op_seqs().emplace(node_index, permute_node_layout); - auto &op_seq = _lowered_graph.op_seqs().at(op_seq_index); - op_seq.setInputs(node.getInputs()); - op_seq.setOutputs(node.getOutputs()); - _lowered_graph.setLowerInfo(op_seq_index, std::make_unique<ir::operation::LowerInfo>( - permute_node_backend, permute_node_layout)); + auto &operation_li_map = _lowered_graph.lower_info().operation; + operation_li_map.set(node_index, std::make_unique<compiler::OperationLowerInfo>( + permute_node_backend, permute_node_layout)); } // Update Use/Def info diff --git a/runtime/onert/core/src/compiler/pass/PermutationInsertionPass.h b/runtime/onert/core/src/compiler/pass/PermutationInsertionPass.h index 
758515385..ee0a1464c 100644 --- a/runtime/onert/core/src/compiler/pass/PermutationInsertionPass.h +++ b/runtime/onert/core/src/compiler/pass/PermutationInsertionPass.h @@ -20,7 +20,7 @@ #include "LoweredOperandPass.h" #include "compiler/BackendManager.h" #include "ir/Operand.h" -#include "ir/operand/PermuteFactor.h" +#include "compiler/PermuteFactor.h" namespace onert { @@ -48,7 +48,7 @@ private: * @return ir::OperationIndex */ ir::OperationIndex insertPermute(const ir::OperandIndex &operand_index, - const ir::operand::PermuteFactor &factor); + const PermuteFactor &factor); }; } // namespace pass diff --git a/runtime/onert/core/src/compiler/pass/PermutationOperationPass.cc b/runtime/onert/core/src/compiler/pass/PermutationOperationPass.cc index c5c95c726..f014d29d3 100644 --- a/runtime/onert/core/src/compiler/pass/PermutationOperationPass.cc +++ b/runtime/onert/core/src/compiler/pass/PermutationOperationPass.cc @@ -30,10 +30,10 @@ namespace pass using namespace ir; -void PermutationOperationPass::callback(const OperationIndex &, Operation &node) +void PermutationOperationPass::callback(const OperationIndex &, IOperation &node) { node.accept(*this); -}; +} // TODO Remove this. 
Expanding ranks of Operand is dangerous void PermutationOperationPass::applyExpandRanks(const Operation &node) @@ -43,9 +43,8 @@ void PermutationOperationPass::applyExpandRanks(const Operation &node) assert(output.getDef().valid()); const auto node_index = output.getDef(); - const auto &op_seq_index = _lowered_graph.op_seqs().getOperation(node_index); - const auto frontend_layout = _lowered_graph.op_seqs().at(op_seq_index).getLayout(); - const auto backend_layout = _lowered_graph.getLowerInfo(op_seq_index)->layout(); + const auto frontend_layout = _graph.layout(); + const auto backend_layout = _lowered_graph.lower_info().operation.getRawPtr(node_index)->layout(); if (frontend_layout == backend_layout) { @@ -84,10 +83,11 @@ void PermutationOperationPass::changeToKeepLayout(const Operation &node) assert(output_obj.getDef().valid()); const auto node_index = output_obj.getDef(); - const auto &op_seq_index = _lowered_graph.op_seqs().getOperation(node_index); - const auto frontend_layout = _lowered_graph.op_seqs().at(op_seq_index).getLayout(); - const auto backend_layout = _lowered_graph.getLowerInfo(op_seq_index)->layout(); + auto &operation_li_map = _lowered_graph.lower_info().operation; + auto &operand_li_map = _lowered_graph.lower_info().operand; + const auto frontend_layout = _graph.layout(); + const auto backend_layout = operation_li_map.getRawPtr(node_index)->layout(); if (frontend_layout == backend_layout) { @@ -97,96 +97,27 @@ void PermutationOperationPass::changeToKeepLayout(const Operation &node) // Permutation changing layout beyond 4-D is not supported yet assert(output_obj.shape().rank() <= 4); - // Divide op_seq based on target operation - { - auto &prev_op_seq = _lowered_graph.op_seqs().at(op_seq_index); - auto &operations = _lowered_graph.graph().operations(); - - // Create new op_seq and move information from existing op_seq to new op_seq if target - // node is the end of op_seq - auto it = prev_op_seq.begin(); - // Find iterator of target node in 
op_seq - while (*(it++) != node_index) - ; - if (it != prev_op_seq.end()) - { - const auto &target_op_idx = *it; - const auto &target_node = operations.at(target_op_idx); - const auto &next_op_seq_index = - _lowered_graph.op_seqs().emplace(target_op_idx, prev_op_seq.getLayout()); - auto &next_op_seq = _lowered_graph.op_seqs().at(next_op_seq_index); - next_op_seq.setInputs(target_node.getInputs()); - next_op_seq.setOutputs(target_node.getOutputs()); - - std::vector<OperationIndex> remove_list; - remove_list.emplace_back(target_op_idx); - while (++it != prev_op_seq.end()) - { - next_op_seq.appendOperation(target_op_idx); - next_op_seq.setOutputs(target_node.getOutputs()); - remove_list.emplace_back(target_op_idx); - } - - prev_op_seq.setOutputs(node.getOutputs()); - for (const auto &index : remove_list) - { - prev_op_seq.remove(index); - } - - const auto op_seq_li = _lowered_graph.getLowerInfo(op_seq_index); - _lowered_graph.setLowerInfo( - next_op_seq_index, - std::make_unique<ir::operation::LowerInfo>(op_seq_li->backend(), op_seq_li->layout())); - } - } - - // Remove target operation from op_seq and insert the target operation to new op_seq + // Change PermuteFactors of operands and the operation of target node { - const auto backend = _lowered_graph.getLowerInfo(op_seq_index)->backend(); + const auto op_li = operation_li_map.getRawPtr(node_index); + const auto backend = op_li->backend(); - // Remove target operation from op_sequence - _lowered_graph.op_seqs().removeFromOpSequence(node_index); + operation_li_map.set(node_index, + std::make_unique<compiler::OperationLowerInfo>(backend, frontend_layout)); - if (!_lowered_graph.op_seqs().exist(op_seq_index)) - { - // Remove lowerinfo for op_seq of target operation if the op_seq does not exist - _lowered_graph.removeLowerInfo(op_seq_index); - } - else - { - // Update op_seq of target operation if the op_seq exists - auto &prev_op_seq = _lowered_graph.op_seqs().at(op_seq_index); - const auto &last_node_idx = 
*(--prev_op_seq.end()); - const auto &last_node = _lowered_graph.graph().operations().at(last_node_idx); - prev_op_seq.setOutputs(last_node.getOutputs()); - } - - // Create new op_seq and set information to the op_seq - auto new_op_seq_index = _lowered_graph.op_seqs().emplace(node_index, frontend_layout); - auto &new_op_seq = _lowered_graph.op_seqs().at(new_op_seq_index); - new_op_seq.setInputs(node.getInputs()); - new_op_seq.setOutputs(node.getOutputs()); - _lowered_graph.setLowerInfo( - new_op_seq_index, std::make_unique<ir::operation::LowerInfo>(backend, frontend_layout)); - } - - // Change PermuteFactors of operands of target node - { - const auto &op_seq_index = _lowered_graph.op_seqs().getOperation(node_index); - const auto op_seq_li = _lowered_graph.getLowerInfo(op_seq_index); - const auto backend = op_seq_li->backend(); - const operand::PermuteFactor removed_factor{backend, backend_layout}; - const operand::PermuteFactor new_factor{backend, frontend_layout}; + const PermuteFactor removed_factor{backend, backend_layout}; + const PermuteFactor new_factor{backend, frontend_layout}; for (const auto &input : node.getInputs() | Remove::DUPLICATED | Remove::UNDEFINED) { + // Check if it can be removed by checking if the operand is used by another operation and + // it uses the same backend and layout bool canRemove = true; for (const auto &use : _graph.operands().at(input).getUses()) { if (use != node_index) { - const auto &use_op_seq_index = _lowered_graph.op_seqs().getOperation(use); - auto use_op_seq_li = _lowered_graph.getLowerInfo(use_op_seq_index); - if (use_op_seq_li->backend() == backend && use_op_seq_li->layout() == backend_layout) + auto use_op_li = operation_li_map.getRawPtr(use); + if (use_op_li->backend() == backend && use_op_li->layout() == backend_layout) { canRemove = false; break; @@ -194,27 +125,27 @@ void PermutationOperationPass::changeToKeepLayout(const Operation &node) } } - auto lower_info = _lowered_graph.getLowerInfo(input); + auto 
input_li = operand_li_map.getRawPtr(input); if (canRemove) { - lower_info->removeUsePermuteFactor(removed_factor); + input_li->removeUsePermuteFactor(removed_factor); } - lower_info->addUsePermuteFactor(new_factor); + input_li->addUsePermuteFactor(new_factor); // Whether if node's input is an input of model or a constant if (!_graph.operands().at(input).getDef().valid() && - (lower_info->def_factors().size() == 1 && - lower_info->def_factors().getOnlyElement() == removed_factor)) + (input_li->def_factors().size() == 1 && + input_li->def_factors().getOnlyElement() == removed_factor)) { assert(_graph.getInputs().contains(input) || _graph.operands().at(input).isConstant()); - lower_info->removeDefPermuteFactor(removed_factor); - lower_info->addDefPermuteFactor(new_factor); + input_li->removeDefPermuteFactor(removed_factor); + input_li->addDefPermuteFactor(new_factor); } } - for (const auto &output : node.getOutputs() | Remove::DUPLICATED) + for (const auto &output : node.getOutputs() | Remove::DUPLICATED | Remove::UNDEFINED) { - auto lower_info = _lowered_graph.getLowerInfo(output); + auto lower_info = operand_li_map.getRawPtr(output); lower_info->removeDefPermuteFactor(removed_factor); lower_info->addDefPermuteFactor(new_factor); @@ -279,6 +210,18 @@ void PermutationOperationPass::visit(const ir::operation::Gather &node) } } +void PermutationOperationPass::visit(const ir::operation::OneHot &node) +{ + const auto &output_ind = node.getOutputs().at(0); + const auto &output_obj = _graph.operands().at(output_ind); + const auto &output_shape = output_obj.shape(); + + if (output_shape.rank() >= 4) + { + changeToKeepLayout(node); + } +} + void PermutationOperationPass::visit(const ir::operation::Pack &node) { const auto &input_ind = node.getInputs().at(ir::operation::Reshape::Input::INPUT); diff --git a/runtime/onert/core/src/compiler/pass/PermutationOperationPass.h b/runtime/onert/core/src/compiler/pass/PermutationOperationPass.h index 2dd76b971..e253a77ad 100644 --- 
a/runtime/onert/core/src/compiler/pass/PermutationOperationPass.h +++ b/runtime/onert/core/src/compiler/pass/PermutationOperationPass.h @@ -36,7 +36,7 @@ public: std::string id() final { return "PermutationOperationPass"; } public: - void callback(const ir::OperationIndex &i, ir::Operation &n) final; + void callback(const ir::OperationIndex &i, ir::IOperation &n) final; public: void visit(const ir::operation::BinaryArithmetic &) final; @@ -44,6 +44,7 @@ public: void visit(const ir::operation::Concat &) final; void visit(const ir::operation::ElementwiseBinary &) final; void visit(const ir::operation::ElementwiseUnary &) final; + void visit(const ir::operation::OneHot &) final; void visit(const ir::operation::Pack &) final; void visit(const ir::operation::PReLU &) final; void visit(const ir::operation::SquaredDifference &) final; diff --git a/runtime/onert/core/src/compiler/pass/UnusedOperandEliminationPass.cc b/runtime/onert/core/src/compiler/pass/UnusedOperandEliminationPass.cc new file mode 100644 index 000000000..162c4e7ef --- /dev/null +++ b/runtime/onert/core/src/compiler/pass/UnusedOperandEliminationPass.cc @@ -0,0 +1,64 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "Pass.h" + +#include "UnusedOperandEliminationPass.h" +#include "ir/Index.h" +#include "util/Set.h" +#include "ir/Graph.h" + +/** + * @file UnusedOperandEliminationPass.cc + * @brief This file contains UnusedOperandEliminationPass class implementation + */ + +namespace onert +{ +namespace compiler +{ +namespace pass +{ + +void UnusedOperandEliminationPass::run() +{ + util::Set<ir::OperandIndex> used; + + _graph.operations().iterate([&](const ir::OperationIndex &, const ir::IOperation &node) { + for (auto &&ind : (node.getInputs() + node.getOutputs()) | ir::Remove::UNDEFINED) + { + used.add(ind); + } + }); + + // Graph's inputs/outputs are always considered as used + for (auto &&ind : (_graph.getInputs() + _graph.getOutputs()) | ir::Remove::UNDEFINED) + { + used.add(ind); + } + + _graph.operands().iterate([&](const ir::OperandIndex &ind, const ir::Operand &) { + if (!used.contains(ind)) + { + VERBOSE() << "Remove unused operand " << ind << std::endl; + _graph.operands().remove(ind); + } + }); +} + +} // namespace pass +} // namespace compiler +} // namespace onert diff --git a/runtime/onert/core/src/compiler/pass/UnusedOperandEliminationPass.h b/runtime/onert/core/src/compiler/pass/UnusedOperandEliminationPass.h new file mode 100644 index 000000000..8078f4246 --- /dev/null +++ b/runtime/onert/core/src/compiler/pass/UnusedOperandEliminationPass.h @@ -0,0 +1,54 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * @file UnusedOperandEliminationPass.h + * @brief This file contains UnusedOperandEliminationPass class + */ + +#ifndef __ONERT_COMPILER_PASS_UNUSED_OPERAND_ELIMINATION_PASS_H__ +#define __ONERT_COMPILER_PASS_UNUSED_OPERAND_ELIMINATION_PASS_H__ + +#include "Pass.h" + +namespace onert +{ +namespace compiler +{ +namespace pass +{ + +/** + * @brief A pass to eliminate unused operands from the graph + * + * Remove operands that are not used by any operations, except Graph inputs/outputs. + * + */ +class UnusedOperandEliminationPass : public Pass +{ +public: + using Pass::Pass; + +public: + std::string id() override { return "UnusedOperandEliminationPass"; } + void run() final; +}; + +} // namespace pass +} // namespace compiler +} // namespace onert + +#endif // __ONERT_COMPILER_PASS_UNUSED_OPERAND_ELIMINATION_PASS_H__ diff --git a/runtime/onert/core/src/compiler/pass/UnusedOperandEliminationPass.test.cc b/runtime/onert/core/src/compiler/pass/UnusedOperandEliminationPass.test.cc new file mode 100644 index 000000000..572b4df24 --- /dev/null +++ b/runtime/onert/core/src/compiler/pass/UnusedOperandEliminationPass.test.cc @@ -0,0 +1,47 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "UnusedOperandEliminationPass.h" + +#include "ir/Graph.h" + +#include <gtest/gtest.h> + +using namespace onert::ir; +using namespace onert::compiler::pass; + +TEST(UnusedOperandEliminationPass, Simple) +{ + Graph graph; + + // Add tensors + Shape shape{1, 2, 2, 1}; + TypeInfo type{DataType::FLOAT32}; + auto in = graph.addOperand(shape, type); + auto out = graph.addOperand(shape, type); + + auto unused = graph.addOperand(shape, type); + + // Set model inputs/outputs + graph.addInput(in); + graph.addOutput(out); + + UnusedOperandEliminationPass{graph}.run(); + + ASSERT_TRUE(graph.operands().exist(in)); + ASSERT_TRUE(graph.operands().exist(out)); + ASSERT_FALSE(graph.operands().exist(unused)); +} diff --git a/runtime/onert/core/src/compiler/train/LoweredTrainableGraph.cc b/runtime/onert/core/src/compiler/train/LoweredTrainableGraph.cc new file mode 100644 index 000000000..8b368c440 --- /dev/null +++ b/runtime/onert/core/src/compiler/train/LoweredTrainableGraph.cc @@ -0,0 +1,286 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "compiler/train/LoweredTrainableGraph.h" + +#include "../ManualScheduler.h" +#include "../pass/ConstantInsertionPass.h" +#include "../pass/ConstantLoweringPass.h" +#include "../pass/PassRunner.h" +#include "../pass/PermutationEliminationPass.h" +#include "../pass/PermutationInsertionPass.h" +#include "../pass/PermutationOperationPass.h" +#include "../../backend/builtin/Config.h" +#include "../../dumper/text/GraphDumper.h" +#include "../../ir/verifier/Verifier.h" +#include "TrainableOperationConverter.h" + +#include "backend/Backend.h" +#include "backend/train/ITrainableBackend.h" +#include "compiler/BackendResolver.h" +#include "util/logging.h" + +#include <cassert> +#include <sstream> + +namespace onert +{ +namespace compiler +{ +namespace train +{ + +LoweredTrainableGraph::LoweredTrainableGraph(ir::train::TrainableGraph &graph, + const CompilerOptions &options) + : _trainable_graph{graph} +{ + lowerGraph(options); +} + +void LoweredTrainableGraph::lowerGraph(const CompilerOptions &options) +{ + // Build backend contexts + auto &backend_manager = BackendManager::get(); + // Create contexts for other backends + for (auto &&backend_str : options.backend_list) + { + backend_manager.loadBackend(backend_str); + auto backend = backend_manager.get(backend_str); + + // TODO As the default value of backend list contains "cpu", "acl_cl" and "acl_neon", and some + // are not available on x64 or some other platforms. So this may be a workaround for x64 and + // we should change it back(throw if backend is not loaded) later. 
+ if (!backend) + { + VERBOSE(LoweredTrainableGraph) << "Cannot load backend - " << backend_str << std::endl; + continue; + } + } + if (backend_manager.num_backends() == 0) + throw std::runtime_error{"No available backends loaded."}; + + // TODO Move "schedule" phase out of here + // TODO Scheduling + std::unique_ptr<BackendResolver> backend_resolver; + auto all_backends = backend_manager.getAll(); + + auto scheduler = ManualScheduler(all_backends, options); + backend_resolver = scheduler.schedule(_trainable_graph.graph()); + + // Check if backends are trainable + _trainable_graph.operations().iterate( + [&](const ir::OperationIndex &op_ind, const ir::IOperation &) { + const auto backend = backend_resolver->getBackend(op_ind); + + // TODO Remove dynamic_cast + if (dynamic_cast<const backend::train::ITrainableBackend *>(backend) == nullptr) + { + throw std::runtime_error(backend->config()->id() + "backend does not support training"); + } + }); + + makeLowerInfo(*backend_resolver); + VERBOSE(LoweredTrainableGraph) << "dump before mandatory passes" << std::endl; + dumper::text::dumpLoweredGraph(*this); + + // Mandatory passes - kind of legalization(?) 
+ compiler::pass::PassRunner{} + .append(std::make_unique<compiler::pass::ConstantInsertionPass>(*this)) + .append(std::make_unique<compiler::pass::ConstantLoweringPass>(*this)) + .append(std::make_unique<compiler::pass::PermutationOperationPass>(*this)) + .append(std::make_unique<compiler::pass::PermutationInsertionPass>(*this)) + .run(); + + // TODO Move converting Permute op into PermutationInsertionPass + auto op_converter = TrainableOperationConverter{_trainable_graph, nullptr}; + _trainable_graph.operations().iterate( + [&](const onert::ir::OperationIndex &index, const onert::ir::IOperation &op) { + if (op.opcode() == ir::OpCode::Permute) + { + auto trainable_op = op_converter(op); + trainable_op->enableBackward(); + auto gen_index = _trainable_graph.replaceOperation(index, std::move(trainable_op)); + UNUSED_RELEASE(gen_index); + assert(gen_index == index); + } + }); + + dumpLowerInfo(); + + // Optimization passes (optional) + compiler::pass::PassRunner{} + .append(std::make_unique<compiler::pass::PermutationEliminationPass>(*this)) + .run(); + + // TODO Update LowerInfo for training + + VERBOSE(LoweredTrainableGraph) << "Dump after all the passes" << std::endl; + for (auto &&operand : _trainable_graph.getInputs()) + VERBOSE(LoweredTrainableGraph) << "Graph Input : " << operand << std::endl; + for (auto &&operand : _trainable_graph.getOutputs()) + VERBOSE(LoweredTrainableGraph) << "Graph Output : " << operand << std::endl; + dumper::text::dumpLoweredGraph(*this); + + // Graph verifications + { + assert(ir::verifier::InputOutputChecker().verify(_trainable_graph.graph())); + assert(ir::verifier::DAGChecker().verify(_trainable_graph.graph())); + assert(ir::verifier::EdgeChecker().verify(_trainable_graph.graph())); + } +} + +void LoweredTrainableGraph::makeLowerInfo(const compiler::BackendResolver &backend_resolver) +{ + _trainable_graph.operands().iterate([&](const ir::OperandIndex &index, const ir::Operand &) { + lower_info().operand.set(index, 
std::make_unique<OperandLowerInfo>()); + }); + + // Set operand lower info using assigned backends to operations + _trainable_graph.operations().iterate( + [&](const ir::OperationIndex &op_ind, const ir::IOperation &op) { + auto backend = backend_resolver.getBackend(op_ind); + if (!backend) + { + throw std::runtime_error{"Fail to find backend for " + op.name() + " operation"}; + } + + auto frontend_layout = _trainable_graph.layout(); + + // The layout of each backend should be set at another place + // TODO Change setting layout of each backend at another place + auto backend_layout = backend->config()->supportLayout(op, frontend_layout); + + for (auto &&ind : op.getInputs() | ir::Remove::UNDEFINED) + { + auto &operand_li = lower_info().operand.at(ind); + operand_li.addUsePermuteFactor(PermuteFactor{backend, backend_layout}); + } + for (auto &&ind : op.getOutputs() | ir::Remove::UNDEFINED) + { + auto &operand_li = lower_info().operand.at(ind); + operand_li.addDefPermuteFactor(PermuteFactor{backend, backend_layout}); + } + lower_info().operation.set( + op_ind, std::make_unique<compiler::OperationLowerInfo>(backend, backend_layout)); + }); + + // Handle graph inputs and outputs + const auto builtin_backend = BackendManager::get().getBuiltin(); + auto factor = PermuteFactor{builtin_backend, _trainable_graph.layout()}; + for (auto &&index : _trainable_graph.getInputs() | ir::Remove::UNDEFINED) + { + auto &operand_li = lower_info().operand.at(index); + assert(operand_li.def_factors().empty()); + operand_li.addDefPermuteFactor(factor); + } + for (auto &&index : _trainable_graph.getOutputs() | ir::Remove::UNDEFINED) + { + auto &operand_li = lower_info().operand.at(index); + operand_li.addUsePermuteFactor(factor); + } + + // Handle variable tensors + _trainable_graph.operands().iterate([&](const ir::OperandIndex &index, ir::Operand &operand) { + // Some inputs of an operation could be non-constant, but not existed in graph inputs/outputs + // and not undefined operand - 
these are variable tensors. For example, + // UnidirectionalSequenceLSTM has such inputs. + if (operand.info().isVariable()) + { + // The variable operand with buffer is not supported yet + assert(operand.data() == nullptr); + assert(operand.getUses().size() == 1 && !operand.getDef().valid()); + auto operand_li = lower_info().operand.at(index); + assert(operand_li.def_factors().empty()); + operand_li.addDefPermuteFactor(operand_li.use_factors().getOnlyElement()); + } + }); +} + +void LoweredTrainableGraph::dumpLowerInfo() +{ + if (::onert::util::logging::ctx.enabled() == false) + return; + + std::map<uint32_t, std::string> dumps; + + _trainable_graph.operands().iterate([&](const ir::OperandIndex &index, ir::Operand &object) { + const auto operand_lower_info = lower_info().operand.getRawPtr(index); + assert(operand_lower_info); + if (!operand_lower_info->def_factors().empty() || !operand_lower_info->use_factors().empty()) + { + auto shape_to_string = [](const ir::Shape &shape) { + std::stringstream sstream; + sstream << "{ "; + for (auto i = 0; i < shape.rank(); ++i) + sstream << (shape.dim(i)) << " "; + sstream << "}"; + return sstream.str(); + }; + + auto factors_to_string = [](const PermuteFactorSet &factors) { + std::string str; + for (auto &&factor : factors) + { + str += factor.backend()->config()->id(); + str += "(" + to_string(factor.layout()) + ")"; + str += " "; + } + return "{ " + str + "}"; + }; + + auto operation_index_set_to_string = [](const ir::OperationIndexSet &operations) { + std::stringstream sstream; + sstream << "{ "; + for (auto &&op : operations) + sstream << op << " "; + sstream << "}"; + return sstream.str(); + }; + + auto data_to_str = [](const ir::Data *data) { + return (data ? 
(std::to_string(data->size()) + " bytes") : "N/A"); + }; + + std::string shape_str = shape_to_string(object.shape()); + std::string def_op = operation_index_set_to_string({object.getDef()}); + std::string use_ops = operation_index_set_to_string(object.getUses()); + std::string def_factors = factors_to_string(operand_lower_info->def_factors()); + std::string use_factors = factors_to_string(operand_lower_info->use_factors()); + std::stringstream sstream; + sstream << "Operand " << index << " Info" << std::endl; + sstream << " - Shape : " << shape_str << std::endl; + sstream << " - Def/Uses : Def " << def_op << " Uses " << use_ops << std::endl; + sstream << " - Data : " << data_to_str(object.data()) << std::endl; + sstream << " - LowerInfo : Def " << def_factors << " Uses " << use_factors << std::endl; + dumps.emplace(index.value(), sstream.str()); + } + }); + + for (const auto &e : dumps) + { + if (!e.second.empty()) + { + std::istringstream iss(e.second); + std::string line; + while (std::getline(iss, line)) + VERBOSE(Lower) << line << std::endl; + } + } +} + +} // namespace train +} // namespace compiler +} // namespace onert diff --git a/runtime/onert/core/src/compiler/train/StaticBackwardShapeInferer.cc b/runtime/onert/core/src/compiler/train/StaticBackwardShapeInferer.cc new file mode 100644 index 000000000..eae8cdeef --- /dev/null +++ b/runtime/onert/core/src/compiler/train/StaticBackwardShapeInferer.cc @@ -0,0 +1,151 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "StaticBackwardShapeInferer.h" +#include "util/ShapeInference.h" +#include "util/logging.h" + +#include <misc/polymorphic_downcast.h> + +#include <sstream> +#include <stdexcept> + +namespace onert +{ +namespace compiler +{ +namespace train +{ + +void StaticBackwardShapeInferer::infer() +{ + // It is not determined to iterate in reverse order. + auto sorted_ops = _lowered_subg->graph().topolSortOperations(); + for (auto it = sorted_ops.rbegin(); it != sorted_ops.rend(); ++it) + { + const auto op_idx = *it; + const auto &op = _lowered_subg->trainable_graph().operation(op_idx); + if (checkDynamicInput(op)) + { + std::stringstream msg; + msg << "StaticBackwardShapeInferer does not support dynamic shape yet, "; + msg << op.name() << "(op index: " << op_idx << ") has dynamic shape."; + throw std::runtime_error(msg.str()); + } + + checkOutput(op); + + op.accept(*this); + } +} + +void StaticBackwardShapeInferer::dump() +{ + // TODO dump +} + +bool StaticBackwardShapeInferer::checkDynamicInput(const ir::IOperation &op) +{ + const auto &operands = _lowered_subg->graph().operands(); + for (const auto &input_idx : op.getInputs() | ir::Remove::UNDEFINED | ir::Remove::DUPLICATED) + { + if (operands.at(input_idx).info().isDynamic()) + { + return true; + } + } + + return false; +} + +void StaticBackwardShapeInferer::checkOutput(const ir::IOperation &op) +{ + const auto &bwd_operands = _lowered_subg->trainable_graph().backward_operands(); + for (const auto &output_idx : op.getOutputs() | ir::Remove::UNDEFINED | ir::Remove::DUPLICATED) + { + if (!bwd_operands.exist(output_idx)) + { + std::stringstream msg; + msg << "StaticBackwardShapeInferer : Invalid output, "; + msg << op.name() << "'s back propagation output(index: " << output_idx << ") does not exist."; + throw std::runtime_error(msg.str()); + } + } +} + +void StaticBackwardShapeInferer::setShape(const 
ir::OperandIndex &index, const ir::Shape &shape) +{ + auto &tgraph = _lowered_subg->trainable_graph(); + + if (tgraph.backward_operands().exist(index)) + tgraph.changeBackwardShape(index, shape); + else + { + // NOTE This code assumes the types are always the same, but I'm not sure. + const auto &type = tgraph.operands().at(index).typeInfo(); + const auto new_index = + tgraph.addBackwardOperand(index, std::make_unique<ir::Operand>(shape, type)); + assert(new_index == index); + UNUSED_RELEASE(new_index); + } +} + +void StaticBackwardShapeInferer::visit(const ir::train::operation::Conv2D &) +{ + // NYI +} + +void StaticBackwardShapeInferer::visit(const ir::train::operation::ElementwiseActivation &) +{ + // NYI +} + +void StaticBackwardShapeInferer::visit(const ir::train::operation::Loss &) +{ + // NYI +} + +void StaticBackwardShapeInferer::visit(const ir::train::operation::Permute &op) +{ + const auto &bwd_operands = _lowered_subg->trainable_graph().backward_operands(); + + const auto &output_idx = op.getOutputs().at(0); + const auto &output = bwd_operands.at(output_idx); + + // re-sizing shape of back propagatation input + const auto &input_idx = op.getInputs().at(0); + const auto &new_shape = output.info().shape(); + setShape(input_idx, new_shape); +} + +void StaticBackwardShapeInferer::visit(const ir::train::operation::Pool2D &) +{ + // NYI +} + +void StaticBackwardShapeInferer::visit(const ir::train::operation::Reshape &) +{ + // NYI +} + +void StaticBackwardShapeInferer::visit(const ir::train::operation::Softmax &) +{ + // NYI +} + +} // namespace train +} // namespace compiler +} // namespace onert diff --git a/runtime/onert/core/src/compiler/train/StaticBackwardShapeInferer.h b/runtime/onert/core/src/compiler/train/StaticBackwardShapeInferer.h new file mode 100644 index 000000000..2ad9bca5e --- /dev/null +++ b/runtime/onert/core/src/compiler/train/StaticBackwardShapeInferer.h @@ -0,0 +1,80 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. 
All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __ONERT_COMPILER_TRAIN_STATIC_BACKWARD_SHAPE_INFERER_H__ +#define __ONERT_COMPILER_TRAIN_STATIC_BACKWARD_SHAPE_INFERER_H__ + +#include "ir/train/TrainableOperationVisitor.h" + +#include "compiler/train/LoweredTrainableGraph.h" +#include "ir/Index.h" + +#include <memory> +#include <unordered_map> + +namespace onert +{ +namespace compiler +{ +namespace train +{ + +/** + * @brief Class to infer shape before running kernels. It does the following: + * - re-calculate and set output shape at compile time (before running kernels) + * - if calculation cannot be done at compile time, mark the outputs to be dynamic, meaning + * shapes of outputs will be calculated during running kernels + */ +class StaticBackwardShapeInferer : public ir::train::TrainableOperationVisitor +{ +public: + StaticBackwardShapeInferer(compiler::train::LoweredTrainableGraph *lowered_subg) + : _lowered_subg{lowered_subg} + { + } + + /** + * @brief Infer shape of operands belonging to ops and set the output shape. + * If output shape cannot be known without running op, mark it so that it can be allocated + * when running kernel. 
+ */ + void infer(void); + + void dump(); + +private: + bool checkDynamicInput(const ir::IOperation &op); + void checkOutput(const ir::IOperation &op); + void setShape(const ir::OperandIndex &index, const ir::Shape &shape); + +private: + void visit(const ir::train::operation::Conv2D &op) override; + void visit(const ir::train::operation::ElementwiseActivation &op) override; + void visit(const ir::train::operation::Loss &op) override; + void visit(const ir::train::operation::Permute &op) override; + void visit(const ir::train::operation::Pool2D &op) override; + void visit(const ir::train::operation::Reshape &op) override; + void visit(const ir::train::operation::Softmax &op) override; + +private: + compiler::train::LoweredTrainableGraph *_lowered_subg; +}; + +} // namespace train +} // namespace compiler +} // namespace onert + +#endif // __ONERT_COMPILER_TRAIN_STATIC_BACKWARD_SHAPE_INFERER_H__ diff --git a/runtime/onert/core/src/compiler/train/TensorRegistries.h b/runtime/onert/core/src/compiler/train/TensorRegistries.h new file mode 100644 index 000000000..8886c9bd4 --- /dev/null +++ b/runtime/onert/core/src/compiler/train/TensorRegistries.h @@ -0,0 +1,114 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef __ONERT_COMPILER_TRAIN_TENSOR_REGISTRIES_H__ +#define __ONERT_COMPILER_TRAIN_TENSOR_REGISTRIES_H__ + +#include "../../backend/builtin/Config.h" +#include "../../backend/builtin/train/TensorRegistry.h" + +#include <backend/train/ITensorRegistry.h> +#include <backend/train/TrainableBackendContext.h> + +#include <memory> +#include <unordered_set> + +namespace onert +{ +namespace compiler +{ +namespace train +{ + +class TensorRegistries +{ +public: + TensorRegistries() = default; + + TensorRegistries(const backend::train::TrainableBackendContexts &backend_contexts, + bool include_builtin) + { + for (const auto &e : backend_contexts) + { + auto tensor_reg = e.second->tensor_registry(); + if (e.first->config()->id() == backend::builtin::Config::ID) + { + _builtin_tensor_reg = + std::dynamic_pointer_cast<backend::builtin::train::TensorRegistry>(tensor_reg); + if (include_builtin) + _tensor_regs.insert(tensor_reg); + } + else + { + _tensor_regs.insert(tensor_reg); + } + } + } + + std::unordered_set<std::shared_ptr<backend::train::ITensorRegistry>>::const_iterator begin() const + { + return _tensor_regs.cbegin(); + } + std::unordered_set<std::shared_ptr<backend::train::ITensorRegistry>>::const_iterator end() const + { + return _tensor_regs.cend(); + } + + std::shared_ptr<backend::builtin::train::TensorRegistry> getBuiltinTensorRegistry() const + { + return _builtin_tensor_reg; + } + + backend::ITensor *getITensor(ir::OperandIndex index) const + { + for (const auto &tensor_reg : _tensor_regs) + { + auto tensor = tensor_reg->getITensor(index); + if (tensor) + return tensor; + } + return nullptr; + } + + backend::ITensor *getBackPropITensor(ir::OperandIndex index) const + { + for (const auto &tensor_reg : _tensor_regs) + { + auto tensor = tensor_reg->getBackPropITensor(index); + if (tensor) + return tensor; + } + return nullptr; + } + + void iterateTrainableTensors( + const std::function<void(const ir::OperandIndex &, const backend::train::ITrainableTensor *)> 
+ &fn) const + { + for (const auto &tensor_reg : _tensor_regs) + tensor_reg->iterateTrainableTensors(fn); + } + +private: + std::unordered_set<std::shared_ptr<backend::train::ITensorRegistry>> _tensor_regs; + std::shared_ptr<backend::builtin::train::TensorRegistry> _builtin_tensor_reg; +}; + +} // namespace train +} // namespace compiler +} // namespace onert + +#endif // __ONERT_COMPILER_TRAIN_TENSOR_REGISTRIES_H__ diff --git a/runtime/onert/core/src/compiler/train/TrainableOperationConverter.cc b/runtime/onert/core/src/compiler/train/TrainableOperationConverter.cc new file mode 100644 index 000000000..80ed05aa5 --- /dev/null +++ b/runtime/onert/core/src/compiler/train/TrainableOperationConverter.cc @@ -0,0 +1,106 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "TrainableOperationConverter.h" + +#include "ir/train/Operations.Include.h" +#include "util/Utils.h" + +namespace onert +{ +namespace compiler +{ +namespace train +{ + +TrainableOperationConverter::TrainableOperationConverter( + ir::train::TrainableGraph &tgraph, const ir::train::TrainingInfo *training_info) + : UntrainableOperationConverter{tgraph}, _training_info{training_info} +{ + // Avoid unused-private-field error + UNUSED_RELEASE(_training_info); +} + +void TrainableOperationConverter::visit(const ir::operation::BinaryArithmetic &node) +{ + _return_op = std::make_unique<ir::train::operation::BinaryArithmetic>(node); +} + +void TrainableOperationConverter::visit(const ir::operation::Conv2D &node) +{ + _return_op = std::make_unique<ir::train::operation::Conv2D>(node); +} + +void TrainableOperationConverter::visit(const ir::operation::DepthwiseConv2D &node) +{ + _return_op = std::make_unique<ir::train::operation::DepthwiseConv2D>(node); +} + +void TrainableOperationConverter::visit(const ir::operation::ElementwiseActivation &node) +{ + if (node.param().op_type == ir::operation::ElementwiseActivation::Type::RELU) + { + _return_op = std::make_unique<ir::train::operation::ElementwiseActivation>(node); + } + else + { + UntrainableOperationConverter::visit(node); + } +} + +void TrainableOperationConverter::visit(const ir::operation::FullyConnected &node) +{ + _return_op = std::make_unique<ir::train::operation::FullyConnected>(node); +} + +void TrainableOperationConverter::visit(const ir::operation::Loss &node) +{ + _return_op = std::make_unique<ir::train::operation::Loss>(node, _training_info->lossInfo()); +} + +void TrainableOperationConverter::visit(const ir::operation::Pad &node) +{ + _return_op = std::make_unique<ir::train::operation::Pad>(node); +} + +void TrainableOperationConverter::visit(const ir::operation::Permute &node) +{ + _return_op = std::make_unique<ir::train::operation::Permute>(node); +} + +void 
TrainableOperationConverter::visit(const ir::operation::Pool2D &node) +{ + _return_op = std::make_unique<ir::train::operation::Pool2D>(node); +} + +void TrainableOperationConverter::visit(const ir::operation::Reduce &node) +{ + _return_op = std::make_unique<ir::train::operation::Reduce>(node); +} + +void TrainableOperationConverter::visit(const ir::operation::Reshape &node) +{ + _return_op = std::make_unique<ir::train::operation::Reshape>(node); +} + +void TrainableOperationConverter::visit(const ir::operation::Softmax &node) +{ + _return_op = std::make_unique<ir::train::operation::Softmax>(node); +} + +} // namespace train +} // namespace compiler +} // namespace onert diff --git a/runtime/onert/core/src/compiler/train/TrainableOperationConverter.h b/runtime/onert/core/src/compiler/train/TrainableOperationConverter.h new file mode 100644 index 000000000..59f92f93e --- /dev/null +++ b/runtime/onert/core/src/compiler/train/TrainableOperationConverter.h @@ -0,0 +1,61 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef __ONERT_COMPILER_TRAIN_TRAINABLE_OPERATION_CONVERTER_H__ +#define __ONERT_COMPILER_TRAIN_TRAINABLE_OPERATION_CONVERTER_H__ + +#include "UntrainableOperationConverter.h" + +#include "ir/train/TrainingInfo.h" + +namespace onert +{ +namespace compiler +{ +namespace train +{ + +class TrainableOperationConverter : public UntrainableOperationConverter +{ +public: + TrainableOperationConverter(ir::train::TrainableGraph &trainable_graph, + const ir::train::TrainingInfo *training_info); + + using UntrainableOperationConverter::operator(); + +private: + void visit(const ir::operation::BinaryArithmetic &) override; + void visit(const ir::operation::Conv2D &) override; + void visit(const ir::operation::DepthwiseConv2D &) override; + void visit(const ir::operation::ElementwiseActivation &) override; + void visit(const ir::operation::FullyConnected &) override; + void visit(const ir::operation::Loss &node) override; + void visit(const ir::operation::Pad &node) override; + void visit(const ir::operation::Permute &node) override; + void visit(const ir::operation::Pool2D &node) override; + void visit(const ir::operation::Reduce &node) override; + void visit(const ir::operation::Reshape &) override; + void visit(const ir::operation::Softmax &) override; + +private: + const ir::train::TrainingInfo *_training_info; +}; + +} // namespace train +} // namespace compiler +} // namespace onert + +#endif // __ONERT_COMPILER_TRAIN_TRAINABLE_OPERATION_CONVERTER_H__ diff --git a/runtime/onert/core/src/compiler/train/TrainingCompiler.cc b/runtime/onert/core/src/compiler/train/TrainingCompiler.cc new file mode 100644 index 000000000..ab0de8df9 --- /dev/null +++ b/runtime/onert/core/src/compiler/train/TrainingCompiler.cc @@ -0,0 +1,310 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "TrainingCompiler.h" + +#include "StaticBackwardShapeInferer.h" +#include "TrainableOperationConverter.h" +#include "pass/LossInsertionPass.h" +#include "../CompilerHelpers.h" +#include "../ExecutorFactory.h" +#include "../pass/ConstantOutputPass.h" +#include "../pass/OddOutputPass.h" +#include "../pass/PassRunner.h" +#include "../pass/UnusedOperandEliminationPass.h" +#include "../ShapeValidator.h" +#include "../../dumper/dot/DotDumper.h" +#include "../../exec/train/TrainableExecutors.h" +#include "../../ir/OperationDumper.h" +#include "../../ir/verifier/Verifier.h" + +#include <compiler/StaticShapeInferer.h> +#include <compiler/train/LoweredTrainableGraph.h> +#include <ir/train/TrainableGraph.h> + +#include <misc/polymorphic_downcast.h> +#include <misc/string_helpers.h> + +namespace onert +{ +namespace compiler +{ +namespace train +{ + +TrainingCompiler::TrainingCompiler(const std::shared_ptr<ir::NNPkg> &nnpkg, CompilerOptions *copts, + const ir::train::TrainingInfo &training_info) + : _model{nnpkg->primary_model()}, _options{copts}, _training_info{training_info} +{ + if (nnpkg->model_count() > 1) + throw std::runtime_error("TrainingCompiler does not support multiple models yet"); + + if (nnpkg->primary_model()->subgraphs_count() > 1) + throw std::runtime_error("TrainingCompiler does not support multiple subgraphs yet"); +} + +std::shared_ptr<CompilerArtifact> TrainingCompiler::compile(void) +{ + /*************************************************** + * Prepare compilation phase + 
***************************************************/ + if (!_options) + throw std::runtime_error{"Empty compile option"}; + + // Mode check + // TODO handle option for each model + if (_options->he_profiling_mode) + { + if (!_options->he_scheduler) + throw std::runtime_error("Heterogeneous scheduler must be enabled during profiling."); + + if (_options->executor != "Dataflow") + throw std::runtime_error("Profiling mode works only with 'Dataflow' executor"); + } + + _options->forceInternalOptions(); + _options->verboseOptions(); + + auto custom_kernel_builder = _model->getKernelBuilder(); + + _model->iterate([&](const ir::SubgraphIndex &, ir::IGraph &graph) { + auto &subg = nnfw::misc::polymorphic_downcast<ir::Graph &>(graph); + // Mandatory passes + compiler::pass::PassRunner{} + .append(std::make_unique<compiler::pass::ConstantOutputPass>(subg)) + .append(std::make_unique<compiler::pass::OddOutputPass>(subg)) + .run(); + + // Optimizations + compiler::pass::PassRunner{} + .append(std::make_unique<compiler::pass::UnusedOperandEliminationPass>(subg)) + .run(); + }); + + std::unordered_map<ir::SubgraphIndex, std::shared_ptr<ir::train::TrainableGraph>> + trainable_subgraphs; + + if (_model->hasOnly<ir::Graph>()) + { + // Create trainable subgraphs by copy and converting inference model + _model->iterate([&](const ir::SubgraphIndex &subg_index, const ir::IGraph &graph) { + const auto &subg = nnfw::misc::polymorphic_downcast<const ir::Graph &>(graph); + // Create TrainableGraph by copying Graph + auto trainable_subg = std::make_shared<ir::train::TrainableGraph>(subg); + + // Convert operations to trainable operations + auto converter = TrainableOperationConverter{*trainable_subg, &_training_info}; + ir::OperationIndex min_trainable_op_idx; + subg.operations().iterate( + [&](const onert::ir::OperationIndex &op_index, const onert::ir::IOperation &op) { + auto trainable_op = converter(op); + if (_training_info.getTrainableOps().find(op_index) != + 
std::end(_training_info.getTrainableOps())) + { + trainable_op->enableWeightsUpdate(); + if (op_index.value() < min_trainable_op_idx.value()) + { + min_trainable_op_idx = op_index; + } + } + auto gen_index = trainable_subg->replaceOperation(op_index, std::move(trainable_op)); + UNUSED_RELEASE(gen_index); + assert(gen_index == op_index); + }); + + for (ir::OperationIndex idx{min_trainable_op_idx}; + idx.value() < trainable_subg->operations().size(); idx++) + { + trainable_subg->enableBackward(idx); + } + + trainable_subgraphs[subg_index] = std::move(trainable_subg); + }); + } + else + { + // TODO Support models that have TrainableGraphs + throw std::runtime_error("TrainingCompiler: Invalid model"); + } + + // operation + _model.reset(); + + // TODO Handle dump level for each model + auto dump_level = static_cast<dumper::dot::DotDumper::Level>(_options->graph_dump_level); + onert::dumper::dot::DotDumper dot_dumper(dump_level); + + for (const auto &pair : trainable_subgraphs) + { + const auto &subg_index = pair.first; + const auto &subg = pair.second; + dot_dumper.dump(*subg, nnfw::misc::str("before_loss_insertion-", subg_index.value())); + } + + // Apply pass for trainable subgraphs + for (auto &&pair : trainable_subgraphs) + { + auto trainable_subg = pair.second; + auto subg_index = pair.first; + + compiler::pass::PassRunner{} + .append(std::make_unique<train::pass::LossInsertionPass>(*trainable_subg, &_training_info, + subg_index)) + .run(); + } + + for (const auto &pair : trainable_subgraphs) + { + const auto &subg_index = pair.first; + const auto &subg = pair.second; + dot_dumper.dump(*subg, nnfw::misc::str("after_loss_insertion-", subg_index.value())); + } + + // Change input shape according to batch_size + for (auto &&pair : trainable_subgraphs) + { + auto trainable_subg = pair.second; + + for (const auto &ind : trainable_subg->getInputs()) + { + auto &input = trainable_subg->operands().at(ind); + auto new_shape = input.info().shape(); + // TODO Consider batch 
size index + if (new_shape.dim(0) != 1) + throw std::runtime_error("the first dim is not 1. It is not supported yet."); + new_shape.dim(0) = _training_info.batchSize(); + input.info().shape(new_shape); + } + } + + /*************************************************** + * Backend independent analysis & optimization phase + ***************************************************/ + // Tracing context + auto tracing_ctx = std::make_unique<util::TracingCtx>(); + + // Lower: Assign backend + std::unordered_map<ir::SubgraphIndex, std::unique_ptr<compiler::train::LoweredTrainableGraph>> + lowered_subgs; + { + for (auto &&pair : trainable_subgraphs) + { + auto &subg_index = pair.first; + auto trainable_subg = pair.second; + + // Lower: Assign backend + lowered_subgs[subg_index] = + std::make_unique<compiler::train::LoweredTrainableGraph>(*trainable_subg, *_options); + // Set tracing_ctx for copied graph + tracing_ctx->setSubgraphIndex(&(lowered_subgs[subg_index]->graph()), subg_index.value()); + } + } + + for (const auto &pair : lowered_subgs) + { + const auto &subg_index = pair.first; + const auto &lowered_subg = pair.second; + dot_dumper.dump(*lowered_subg, nnfw::misc::str("after_lower_subg-", subg_index.value())); + } + + // Set operands' info for back propagation as default tensor info + for (const auto &pair : lowered_subgs) + { + auto lowered_subg = pair.second.get(); + auto &tgraph = lowered_subg->trainable_graph(); + tgraph.operands().iterate([&](const ir::OperandIndex &index, const ir::Operand &obj) { + if (!obj.isConstant()) + { + auto bwd_operand = std::make_unique<ir::Operand>(obj); + const auto gen_index = tgraph.addBackwardOperand(index, std::move(bwd_operand)); + assert(gen_index == index); + UNUSED_RELEASE(gen_index); + } + }); + } + + // Shape inference. + { + // Run the StaticShapeInfer of primary subg. 
All child StaticShapeInferers are called + // recursively + std::unordered_map<ir::SubgraphIndex, std::unique_ptr<StaticShapeInferer>> inferers = + createStaticShapeInferers(lowered_subgs); + + const auto primary_subg_idx = ir::SubgraphIndex{0}; + inferers.at(primary_subg_idx)->infer(); + + for (const auto &pair_inferer : inferers) + { + const auto inferer = pair_inferer.second.get(); + inferer->dump(); + } + + // NOTE StaticBackwardShapeInferer is allocated for each subgraph, + // so it does not support models that have controlflow operations yet. + for (auto &&pair : lowered_subgs) + { + auto &lowered_subg = pair.second; + auto inferer = std::make_unique<StaticBackwardShapeInferer>(lowered_subg.get()); + inferer->infer(); + inferer->dump(); + } + } + + // Shape validation + for (const auto &pair : lowered_subgs) + { + auto &lowered_subg = pair.second; + compiler::ShapeValidator{lowered_subg->graph()}(); + } + + // TODO Validate shapes of the tensors for back propagation + + /************************************************************* + * Backend independent analysis & optimization phase finished + *************************************************************/ + auto executors = std::make_shared<exec::train::TrainableExecutors>(); + for (auto &&pair : lowered_subgs) + { + auto const model_index = ir::ModelIndex{0}; + auto const subg_index = pair.first; + auto &lowered_subg = pair.second; + auto const indexed_ranks = lowered_subg->indexed_ranks(); + + ir::OperationDumper dumper("Executor generation of Subgraph " + + std::to_string(subg_index.value())); + lowered_subg->graph().operations().iterate( + [&](const ir::OperationIndex &, const ir::IOperation &op) { op.accept(dumper); }); + + ExecutorFactoryArgs args; + args.tracing_ctx = tracing_ctx.get(); + args.options = _options; + args.model_index = model_index; + args.custom_kernel_builder = custom_kernel_builder; + auto executor = std::unique_ptr<exec::IExecutor>{ + 
ExecutorFactory::get().create(std::move(lowered_subg), executors, args, _training_info)}; + executor->setIndexedRanks(indexed_ranks); + executors->emplace(model_index, subg_index, std::move(executor)); + } + + /******************************** + * Code generation phase finished + ********************************/ + return std::make_shared<CompilerArtifact>(executors, std::move(tracing_ctx)); +} + +} // namespace train +} // namespace compiler +} // namespace onert diff --git a/runtime/onert/core/src/compiler/train/TrainingCompiler.h b/runtime/onert/core/src/compiler/train/TrainingCompiler.h new file mode 100644 index 000000000..ab62c0f34 --- /dev/null +++ b/runtime/onert/core/src/compiler/train/TrainingCompiler.h @@ -0,0 +1,81 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +/** + * @file TrainingCompiler.h + * @brief This file contains TrainingCompiler class to define and run compilation phase + */ + +#ifndef __ONERT_COMPILER_TRAIN_TRAINING_COMPILER_H_ +#define __ONERT_COMPILER_TRAIN_TRAINING_COMPILER_H_ + +#include "compiler/CompilerOptions.h" +#include "compiler/ICompiler.h" +#include "ir/NNPkg.h" +#include "ir/train/TrainingInfo.h" + +namespace onert +{ +namespace compiler +{ +namespace train +{ + +/** + * @brief Class to compile NN package + */ +class TrainingCompiler : public ICompiler +{ +public: + /** + * @brief Construct a new TrainingCompiler object for an nnpkg + * @param[in] nnpkg nnpkg to compile + * @param[in] copts compiler options + * @param[in] training_info training information + */ + explicit TrainingCompiler(const std::shared_ptr<ir::NNPkg> &nnpkg, CompilerOptions *copts, + const ir::train::TrainingInfo &training_info); + + /** + * @brief Construct a TrainingCompiler object + * + */ + TrainingCompiler(void) = delete; + + /** + * @brief Destroy the TrainingCompiler object + */ + ~TrainingCompiler() = default; + +public: + /** + * @brief Do compilation with the options + * + * @return std::shared_ptr<CompilerArtifact> Executors as a result of compilation + */ + std::shared_ptr<CompilerArtifact> compile(void); + +private: + std::shared_ptr<ir::Model> _model; + CompilerOptions *_options; + const ir::train::TrainingInfo _training_info; +}; + +} // namespace train +} // namespace compiler +} // namespace onert + +#endif // __ONERT_COMPILER_TRAIN_TRAINING_COMPILER_H_ diff --git a/runtime/onert/core/src/compiler/train/UntrainableOperationConverter.cc b/runtime/onert/core/src/compiler/train/UntrainableOperationConverter.cc new file mode 100644 index 000000000..22f7604b5 --- /dev/null +++ b/runtime/onert/core/src/compiler/train/UntrainableOperationConverter.cc @@ -0,0 +1,53 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. 
All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "UntrainableOperationConverter.h" + +#include "ir/train/operation/UntrainableOperation.h" + +namespace onert +{ +namespace compiler +{ +namespace train +{ + +UntrainableOperationConverter::UntrainableOperationConverter(ir::train::TrainableGraph &tgraph) + : _tgraph{tgraph}, _return_op{nullptr} +{ +} + +std::unique_ptr<ir::train::ITrainableOperation> +UntrainableOperationConverter::operator()(const ir::IOperation &op) +{ + op.accept(*this); + + return std::move(_return_op); +} + +#define OP(InternalName) \ + void UntrainableOperationConverter::visit(const ir::operation::InternalName &node) \ + { \ + _return_op = \ + std::make_unique<ir::train::operation::UntrainableOperation<ir::operation::InternalName>>( \ + node); \ + } +#include "ir/Operations.lst" +#undef OP + +} // namespace train +} // namespace compiler +} // namespace onert diff --git a/runtime/onert/core/src/compiler/train/UntrainableOperationConverter.h b/runtime/onert/core/src/compiler/train/UntrainableOperationConverter.h new file mode 100644 index 000000000..e960b3831 --- /dev/null +++ b/runtime/onert/core/src/compiler/train/UntrainableOperationConverter.h @@ -0,0 +1,52 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __ONERT_COMPILER_TRAIN_UNTRAINABLE_OPERATION_CONVERTER_H__ +#define __ONERT_COMPILER_TRAIN_UNTRAINABLE_OPERATION_CONVERTER_H__ + +#include "ir/Operations.Include.h" +#include "ir/OperationVisitor.h" +#include "ir/train/TrainableGraph.h" + +#include <memory> + +namespace onert +{ +namespace compiler +{ +namespace train +{ + +class UntrainableOperationConverter : public ir::OperationVisitor +{ +public: + UntrainableOperationConverter(ir::train::TrainableGraph &tgraph); + std::unique_ptr<ir::train::ITrainableOperation> operator()(const ir::IOperation &op); + +#define OP(InternalName) void visit(const ir::operation::InternalName &node); +#include "ir/Operations.lst" +#undef OP + +protected: + ir::train::TrainableGraph &_tgraph; + std::unique_ptr<ir::train::ITrainableOperation> _return_op; +}; + +} // namespace train +} // namespace compiler +} // namespace onert + +#endif // __ONERT_COMPILER_TRAIN_UNTRAINABLE_OPERATION_CONVERTER_H__ diff --git a/runtime/onert/core/src/compiler/train/pass/LossInsertionPass.cc b/runtime/onert/core/src/compiler/train/pass/LossInsertionPass.cc new file mode 100644 index 000000000..ea1f21e30 --- /dev/null +++ b/runtime/onert/core/src/compiler/train/pass/LossInsertionPass.cc @@ -0,0 +1,82 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "LossInsertionPass.h" + +#include "ir/train/TrainableGraph.h" +#include "ir/train/TrainingInfo.h" +#include "ir/train/operation/Loss.h" + +namespace onert +{ +namespace compiler +{ +namespace train +{ +namespace pass +{ + +void LossInsertionPass::run() +{ + const auto &loss_info = _training_info->lossInfo(); + + if (_trainable_graph.getOutputs().size() != 1) + { + throw std::runtime_error("LossInsertionPass: Not supported multiple outputs"); + } + + // TODO Consider SparseCategoricalCrossentropy y_true shape + // SparseCategoricalCrossentropy loss has a different y_true shape than y_pred. 
+ + // TODO Implement Loop [0, getOutputs().size()) + // index: a loop index + const auto index = 0; + const auto &y_pred_index = _trainable_graph.getOutputs().at(index); + const auto &y_pred = _trainable_graph.operands().at(y_pred_index); + auto y_true_index = _trainable_graph.addOperand(y_pred.shape(), y_pred.typeInfo()); + ir::OperandIndexSequence inputs{y_pred_index, y_true_index}; + + ir::Shape output_shape; + if (loss_info.reduction_type == ir::train::LossReductionType::Sum || + loss_info.reduction_type == ir::train::LossReductionType::SumOverBatchSize) + { + output_shape = ir::Shape{1}; + } + else + { + throw std::runtime_error("LossInsertionPass: Not supported reduction type"); + } + + const ir::TypeInfo float_op(ir::DataType::FLOAT32); + auto output_index = _trainable_graph.addOperand(output_shape, float_op); + ir::OperandIndexSequence outputs{output_index}; + + auto loss_op = std::make_unique<ir::operation::Loss>(inputs, outputs); + auto trainable_loss_op = std::make_unique<ir::train::operation::Loss>(*loss_op, loss_info); + trainable_loss_op->enableBackward(); + + _trainable_graph.addOperation(std::move(trainable_loss_op)); + + _trainable_graph.addInput(y_true_index); + + // TODO Add loss as many as output size + _trainable_graph.addLoss(output_index, ir::IOIndex{index}); +} + +} // namespace pass +} // namespace train +} // namespace compiler +} // namespace onert diff --git a/runtime/onert/core/src/compiler/train/pass/LossInsertionPass.h b/runtime/onert/core/src/compiler/train/pass/LossInsertionPass.h new file mode 100644 index 000000000..1a313fb11 --- /dev/null +++ b/runtime/onert/core/src/compiler/train/pass/LossInsertionPass.h @@ -0,0 +1,56 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __ONERT_COMPILER_TRAIN_PASS_LOSS_INSERTION_PASS_H__ +#define __ONERT_COMPILER_TRAIN_PASS_LOSS_INSERTION_PASS_H__ + +#include "Pass.h" + +#include "ir/Index.h" + +namespace onert +{ +namespace compiler +{ +namespace train +{ +namespace pass +{ + +class LossInsertionPass : public Pass +{ +public: + LossInsertionPass(ir::train::TrainableGraph &trainable_graph, + const ir::train::TrainingInfo *training_info, + const ir::SubgraphIndex &subg_index) + : Pass{trainable_graph, training_info}, _subg_index{subg_index} + { + } + +public: + std::string id() final { return "LossInsertionPass"; } + void run() final; + +private: + ir::SubgraphIndex _subg_index; +}; + +} // namespace pass +} // namespace train +} // namespace compiler +} // namespace onert + +#endif // __ONERT_COMPILER_TRAIN_PASS_LOSS_INSERTION_PASS_H__ diff --git a/runtime/onert/core/src/compiler/train/pass/Pass.h b/runtime/onert/core/src/compiler/train/pass/Pass.h new file mode 100644 index 000000000..0e835e19e --- /dev/null +++ b/runtime/onert/core/src/compiler/train/pass/Pass.h @@ -0,0 +1,62 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __ONERT_COMPILER_TRAIN_PASS_PASS_H__ +#define __ONERT_COMPILER_TRAIN_PASS_PASS_H__ + +#include "../../pass/IPass.h" + +namespace onert +{ +namespace ir +{ +namespace train +{ +class TrainableGraph; +class TrainingInfo; +} // namespace train +} // namespace ir +} // namespace onert + +namespace onert +{ +namespace compiler +{ +namespace train +{ +namespace pass +{ + +class Pass : public compiler::pass::IPass +{ +public: + Pass(ir::train::TrainableGraph &trainable_graph, const ir::train::TrainingInfo *training_info) + : _trainable_graph{trainable_graph}, _training_info{training_info} + { + } + virtual ~Pass() = default; + +protected: + ir::train::TrainableGraph &_trainable_graph; + const ir::train::TrainingInfo *_training_info; +}; + +} // namespace pass +} // namespace train +} // namespace compiler +} // namespace onert + +#endif // __ONERT_COMPILER_TRAIN_PASS_PASS_H__ diff --git a/runtime/onert/core/src/dumper/dot/DotBuilder.cc b/runtime/onert/core/src/dumper/dot/DotBuilder.cc index 38a69696e..9257434fa 100644 --- a/runtime/onert/core/src/dumper/dot/DotBuilder.cc +++ b/runtime/onert/core/src/dumper/dot/DotBuilder.cc @@ -29,31 +29,12 @@ DotBuilder::DotBuilder() {} void DotBuilder::update(const Node &node_info) { add(node_info); - for (auto edge : node_info.out_edges()) + for (auto &&edge : node_info.out_edges()) { addEdge(node_info, *edge); } } -void DotBuilder::addOpSequence(const DotSubgraphInfo &subgraph_info) -{ - _dot << "subgraph cluster_" << subgraph_info.index().value() << " {\n"; - _dot << " label=\"" << 
subgraph_info.label() << "\";\n"; - _dot << " style=filled;\n"; - _dot << " color=lightgrey;\n"; - _dot << " "; - for (auto op : subgraph_info.operations()) - { - _dot << "operation" << op.value() << "; "; - } - for (auto op : subgraph_info.operands()) - { - _dot << "operand" << op.value() << "; "; - } - _dot << "\n"; - _dot << "}\n"; -} - void DotBuilder::writeDot(std::ostream &os) { os << "digraph D {\n" @@ -66,7 +47,7 @@ void DotBuilder::add(const Node &node) _dot << node.id(); std::stringstream ss; _dot << "["; - for (auto attr : node.attributes()) + for (auto &&attr : node.attributes()) { _dot << attr.first << "=\"" << attr.second << "\" "; } diff --git a/runtime/onert/core/src/dumper/dot/DotBuilder.h b/runtime/onert/core/src/dumper/dot/DotBuilder.h index 681cbbf5d..30f32f8f9 100644 --- a/runtime/onert/core/src/dumper/dot/DotBuilder.h +++ b/runtime/onert/core/src/dumper/dot/DotBuilder.h @@ -25,7 +25,6 @@ #include "OperationNode.h" #include "OperandNode.h" -#include "DotSubgraphInfo.h" using Operation = onert::ir::Operation; using Object = onert::ir::Operand; @@ -44,7 +43,6 @@ public: public: void update(const Node &dotinfo); - void addOpSequence(const DotSubgraphInfo &subgraph_info); void writeDot(std::ostream &os); diff --git a/runtime/onert/core/src/dumper/dot/DotDumper.cc b/runtime/onert/core/src/dumper/dot/DotDumper.cc index 118057f09..98524d8d1 100644 --- a/runtime/onert/core/src/dumper/dot/DotDumper.cc +++ b/runtime/onert/core/src/dumper/dot/DotDumper.cc @@ -19,8 +19,7 @@ #include "DotDumper.h" #include "DotBuilder.h" -#include "DotSubgraphInfo.h" -#include "ir/OpSequence.h" +#include "ir/OperandIndexMap.h" #include "ir/OperationIndexMap.h" #include "backend/Backend.h" #include "backend/IConfig.h" @@ -33,151 +32,153 @@ namespace dumper namespace dot { -void DotDumper::dump(const std::string &tag) +namespace { - if (_level == Level::OFF) - { - return; - } - - onert::dumper::dot::DotBuilder dot_builder; - - auto &operations = _graph.operations(); - auto 
&operands = _graph.operands(); - - ir::OperationIndexMap<std::unique_ptr<Operation>> operation_nodes; - std::unordered_map<ir::OperandIndex, std::unique_ptr<Operand>> operand_nodes; - - auto backend_to_fillcolor = [](const backend::Backend *backend) { - static const auto map = []() { - std::unordered_map<const backend::Backend *, std::string> ret; - uint32_t index = 1; // Start from 1 to avoid 0(red) which is too dark :( - for (const auto backend : compiler::BackendManager::get().getAll()) - { - ret.emplace(backend, Node::BG_COLORS[index]); - index = (index + 1) % (sizeof(Node::BG_COLORS) / sizeof(Node::BG_COLORS[0])); - } - return ret; - }(); - - auto itr = map.find(backend); - if (itr == map.end()) - { - return Node::DEFAULT_FILLCOLOR; - } - else +std::string backend_to_fillcolor(const backend::Backend *backend) +{ + static const auto map = []() { + std::unordered_map<const backend::Backend *, std::string> ret; + uint32_t index = 1; // Start from 1 to avoid 0(red) which is too dark :( + for (const auto backend : compiler::BackendManager::get().getAll()) { - return itr->second; + ret.emplace(backend, Node::BG_COLORS[index]); + index = (index + 1) % (sizeof(Node::BG_COLORS) / sizeof(Node::BG_COLORS[0])); } - }; + return ret; + }(); + auto itr = map.find(backend); + if (itr == map.end()) + { + return Node::DEFAULT_FILLCOLOR; + } + else + { + return itr->second; + } +} - util::Set<ir::OperandIndex> shown_operand_set; +std::unordered_map<ir::OperandIndex, std::unique_ptr<Operand>> +generate_dot_operands(const ir::Graph &graph, const DotDumper::Level level) +{ + std::unordered_map<ir::OperandIndex, std::unique_ptr<Operand>> dot_operands; + const auto &operands = graph.operands(); operands.iterate([&](const ir::OperandIndex &index, const ir::Operand &object) { - bool showing_cond = false; - if (_level == Level::ALL) - { - showing_cond = true; - } - else - { - showing_cond = !object.isConstant(); - } - if (object.isConstant() || _graph.getInputs().contains(index)) - { - 
showing_cond = showing_cond && (object.getUses().size() > 0); - } + bool showing_cond = + level == DotDumper::Level::ALL + ? true + : !object.isConstant() || (graph.getInputs() + graph.getOutputs()).contains(index); if (showing_cond) { - shown_operand_set.add(index); - auto type = [&]() { using onert::dumper::dot::Operand; - if (_graph.getInputs().contains(index)) + if (graph.getInputs().contains(index)) return Operand::Type::MODEL_INPUT; - if (_graph.getOutputs().contains(index)) + if (graph.getOutputs().contains(index)) return Operand::Type::MODEL_OUTPUT; return Operand::Type::INTERNAL; }(); auto node = std::make_unique<Operand>(index, type); + std::string label = std::to_string(index.value()); + std::string fillcolor = ""; + node->setAttribute("label", label); + node->setAttribute("fillcolor", fillcolor); - { - // Display LowerInfo attributes - std::string label = std::to_string(index.value()); - std::string fillcolor = ""; - if (_lowered_graph) - { - auto lower_info = _lowered_graph->getLowerInfo(index); - const auto &def_factors = lower_info->def_factors(); - if (def_factors.size() > 0) - { - label += "\\n["; - label += def_factors.getOnlyElement().backend()->config()->id(); - label += "]"; - - fillcolor = backend_to_fillcolor(lower_info->def_factors().getOnlyElement().backend()); - } - } - node->setAttribute("label", label); - node->setAttribute("fillcolor", fillcolor); - } - - operand_nodes.emplace(index, std::move(node)); + dot_operands.emplace(index, std::move(node)); } }); - operations.iterate([&](const ir::OperationIndex &index, const ir::Operation &op) { + return dot_operands; +} + +ir::OperationIndexMap<std::unique_ptr<Operation>> +generate_dot_operations(const ir::Graph &graph, + const ir::OperandIndexMap<std::unique_ptr<Operand>> &dot_operands) +{ + ir::OperationIndexMap<std::unique_ptr<Operation>> dot_operations; + const auto &operations = graph.operations(); + operations.iterate([&](const ir::OperationIndex &index, const ir::IOperation &op) { auto 
node = std::make_unique<Operation>(index, op); - for (auto input : op.getInputs()) + for (auto &&input : op.getInputs()) { using onert::dumper::dot::Operand; // Constant input and dump level is ALL_BUT_CONSTANTS - if (operand_nodes.find(input) == operand_nodes.end()) + if (dot_operands.find(input) == dot_operands.end()) continue; - auto &input_node = operand_nodes.at(input); + auto &input_node = dot_operands.at(input); input_node->addOutEdge(node.get()); } - for (auto output : op.getOutputs()) + for (auto &&output : op.getOutputs() | ir::Remove::UNDEFINED) { using onert::dumper::dot::Operand; - auto &output_node = operand_nodes.at(output); + auto &output_node = dot_operands.at(output); node->addOutEdge(output_node.get()); } - operation_nodes.emplace(index, std::move(node)); + dot_operations.emplace(index, std::move(node)); }); - if (_lowered_graph) - { - const auto &op_seqs = _lowered_graph->op_seqs(); - op_seqs.iterate([&](const ir::OpSequenceIndex &index, const ir::OpSequence &op_seq) { - const auto lower_info = _lowered_graph->getLowerInfo(index); + return dot_operations; +} + +void update_lower_info(const compiler::ILoweredGraph &lowered_graph, + ir::OperandIndexMap<std::unique_ptr<Operand>> *dot_operands) +{ + const auto &operands = lowered_graph.graph().operands(); + operands.iterate([&](const ir::OperandIndex &index, const ir::Operand &) { + auto itr = dot_operands->find(index); + if (itr != dot_operands->end()) + { + auto &node = itr->second; + // Display LowerInfo attributes + std::string label = node->getAttribute("label"); + std::string fillcolor = node->getAttribute("fillcolor"); + auto lower_info = lowered_graph.lower_info().operand.getRawPtr(index); + const auto &def_factors = lower_info->def_factors(); + if (def_factors.size() > 0) + { + label += "\\n["; + label += def_factors.getOnlyElement().backend()->config()->id(); + label += "]"; + fillcolor = backend_to_fillcolor(lower_info->def_factors().getOnlyElement().backend()); + } + 
node->setAttribute("label", label); + node->setAttribute("fillcolor", fillcolor); + } + }); +} + +void update_lower_info(const compiler::ILoweredGraph &lowered_graph, + ir::OperationIndexMap<std::unique_ptr<Operation>> *dot_operations) +{ + const auto &operations = lowered_graph.graph().operations(); + operations.iterate([&](const ir::OperationIndex &index, const ir::IOperation &) { + const auto lower_info = lowered_graph.lower_info().operation.getRawPtr(index); + if (lower_info) + { auto fillcolor = backend_to_fillcolor(lower_info->backend()); - std::string label = - std::to_string(index.value()) + " [" + lower_info->backend()->config()->id() + "]"; - DotSubgraphInfo subgraph_info{index, op_seq, shown_operand_set, _graph.operations()}; - subgraph_info.label(label); - subgraph_info.fillcolor(fillcolor); - dot_builder.addOpSequence(subgraph_info); - - // Set fillcolor of all operations in the op_seq - for (const auto &op_idx : op_seq.operations()) + std::string backend_label = "[" + lower_info->backend()->config()->id() + "]"; + auto itr = dot_operations->find(index); + if (itr != dot_operations->end()) { - auto found = operation_nodes.find(op_idx); - if (found != operation_nodes.end()) - { - auto &&op = found->second; - op->setAttribute("fillcolor", fillcolor); - } + auto &node = itr->second; + node->setAttribute("label", node->getAttribute("label") + "\n" + backend_label); + node->setAttribute("fillcolor", fillcolor); } - }); - } + } + }); +} +void dump_to_file(const ir::OperandIndexMap<std::unique_ptr<Operand>> &operand_nodes, + const ir::OperationIndexMap<std::unique_ptr<Operation>> &operation_nodes, + const std::string &tag) +{ + onert::dumper::dot::DotBuilder dot_builder; for (const auto &e : operation_nodes) dot_builder.update(*e.second); for (const auto &e : operand_nodes) @@ -198,6 +199,39 @@ void DotDumper::dump(const std::string &tag) fb.close(); } } +} // namespace + +void DotDumper::dump(const ir::Graph &graph, const std::string &tag) +{ + if (_level == 
Level::OFF) + { + return; + } + + const auto dot_operands = generate_dot_operands(graph, _level); + const auto dot_operations = generate_dot_operations(graph, dot_operands); + dump_to_file(dot_operands, dot_operations, tag); +} + +// TODO Support tensors for training +void DotDumper::dump(const compiler::ILoweredGraph &lowered_graph, const std::string &tag) +{ + if (_level == Level::OFF) + { + return; + } + + auto dot_operands = generate_dot_operands(lowered_graph.graph(), _level); + auto dot_operations = generate_dot_operations(lowered_graph.graph(), dot_operands); + update_lower_info(lowered_graph, &dot_operands); + update_lower_info(lowered_graph, &dot_operations); + dump_to_file(dot_operands, dot_operations, tag); +} + +void DotDumper::dump(const ir::train::TrainableGraph &graph, const std::string &tag) +{ + dump(graph.graph(), tag); +} } // namespace dot } // namespace dumper diff --git a/runtime/onert/core/src/dumper/dot/DotDumper.h b/runtime/onert/core/src/dumper/dot/DotDumper.h index fdbca1642..59f4b3bda 100644 --- a/runtime/onert/core/src/dumper/dot/DotDumper.h +++ b/runtime/onert/core/src/dumper/dot/DotDumper.h @@ -15,7 +15,8 @@ */ #include "ir/Graph.h" -#include "compiler/LoweredGraph.h" +#include "ir/train/TrainableGraph.h" +#include "compiler/ILoweredGraph.h" #ifndef __ONERT_DUMPER_DOT_DOT_DUMPER_H__ #define __ONERT_DUMPER_DOT_DOT_DUMPER_H__ @@ -38,27 +39,37 @@ public: }; public: - DotDumper(const ir::Graph &graph, Level level) - : _lowered_graph{nullptr}, _graph(graph), _level{level} - { - } - DotDumper(const compiler::LoweredGraph *lowered_graph, Level level) - : _lowered_graph{lowered_graph}, _graph(_lowered_graph->graph()), _level{level} - { - } + DotDumper(Level level) : _level{level} {} public: /** - * @brief Dump to dot file as tag name if "GRAPH_DOT_DUMP" is set + * @brief Dump graph information to dot file as tag name if "GRAPH_DOT_DUMP" is set * + * @param[in] graph The graph that would be used to get operations and operands * @param[in] tag 
The name of dot file that would be created * @return N/A */ - void dump(const std::string &tag); + void dump(const ir::Graph &graph, const std::string &tag); + + /** + * @brief Dump lowered graph information to dot file as tag name if "GRAPH_DOT_DUMP" is set + * + * @param[in] graph The graph that would be used to get operations and operands + * @param[in] tag The name of dot file that would be created + * @return N/A + */ + void dump(const compiler::ILoweredGraph &lowered_graph, const std::string &tag); + + /** + * @brief Dump graph information to dot file as tag name if "GRAPH_DOT_DUMP" is set + * + * @param[in] graph TrainableGraph to be dumped + * @param[in] tag The name of dot file to be dumped + * @return N/A + */ + void dump(const ir::train::TrainableGraph &graph, const std::string &tag); private: - const compiler::LoweredGraph *_lowered_graph; - const ir::Graph &_graph; Level _level; }; diff --git a/runtime/onert/core/src/dumper/dot/DotSubgraphInfo.cc b/runtime/onert/core/src/dumper/dot/DotSubgraphInfo.cc deleted file mode 100644 index 52e9c758d..000000000 --- a/runtime/onert/core/src/dumper/dot/DotSubgraphInfo.cc +++ /dev/null @@ -1,58 +0,0 @@ -/* - * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "DotSubgraphInfo.h" - -#include <sstream> - -namespace onert -{ -namespace dumper -{ -namespace dot -{ - -DotSubgraphInfo::DotSubgraphInfo(const ir::OpSequenceIndex &index, const ir::OpSequence &op_seq, - const util::Set<ir::OperandIndex> &shown_operands, - const ir::Operations &operations_ctx) - : _index{index} -{ - for (const auto &op_idx : op_seq.operations()) - { - _operations.insert(op_idx); - const auto &node = operations_ctx.at(op_idx); - for (auto o : node.getInputs()) - { - // Must be a shown operand, not op_seq's inputs - if (shown_operands.contains(o) && !op_seq.getInputs().contains(o)) - { - _operands.insert(o); - } - } - for (auto o : node.getOutputs()) - { - // Must be a shown operand, not op_seq's inputs - if (shown_operands.contains(o) && !op_seq.getOutputs().contains(o)) - { - _operands.insert(o); - } - } - } -} - -} // namespace dot -} // namespace dumper -} // namespace onert diff --git a/runtime/onert/core/src/dumper/dot/DotSubgraphInfo.h b/runtime/onert/core/src/dumper/dot/DotSubgraphInfo.h deleted file mode 100644 index 95ba8953e..000000000 --- a/runtime/onert/core/src/dumper/dot/DotSubgraphInfo.h +++ /dev/null @@ -1,61 +0,0 @@ -/* - * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef __ONERT_CORE_DUMPER_DOT_DOT_SUBGRAPH_INFO_H__ -#define __ONERT_CORE_DUMPER_DOT_DOT_SUBGRAPH_INFO_H__ - -#include <unordered_set> - -#include "ir/Index.h" -#include <ir/Operations.h> -#include "ir/OpSequence.h" -#include "util/Set.h" - -namespace onert -{ -namespace dumper -{ -namespace dot -{ - -class DotSubgraphInfo -{ -public: - DotSubgraphInfo(const ir::OpSequenceIndex &index, const ir::OpSequence &op_seq, - const util::Set<ir::OperandIndex> &shown_operands, - const ir::Operations &operations_ctx); - - ir::OpSequenceIndex index() const { return _index; } - std::string label() const { return _label; } - void label(const std::string &val) { _label = val; } - std::string fillcolor() const { return _fillcolor; } - void fillcolor(const std::string &val) { _fillcolor = val; } - const std::unordered_set<ir::OperationIndex> &operations() const { return _operations; } - const std::unordered_set<ir::OperandIndex> &operands() const { return _operands; } - -private: - ir::OpSequenceIndex _index; - std::string _label; - std::string _fillcolor; - std::unordered_set<ir::OperationIndex> _operations; - std::unordered_set<ir::OperandIndex> _operands; -}; - -} // namespace dot -} // namespace dumper -} // namespace onert - -#endif // __ONERT_CORE_DUMPER_DOT_DOT_SUBGRAPH_INFO_H__ diff --git a/runtime/onert/core/src/dumper/dot/OperandNode.cc b/runtime/onert/core/src/dumper/dot/OperandNode.cc index 5a6015ca9..cbc73878f 100644 --- a/runtime/onert/core/src/dumper/dot/OperandNode.cc +++ b/runtime/onert/core/src/dumper/dot/OperandNode.cc @@ -18,7 +18,6 @@ #include "OperandNode.h" #include "ir/Graph.h" -#include "ir/operand/LowerInfo.h" namespace onert { @@ -33,10 +32,10 @@ const std::string Operand::OPERAND_SHAPE = "ellipse"; const std::string Operand::BG_COLOR_SCHEME = "set18"; Operand::Operand(const ir::OperandIndex &index, Type type) - : Node{"operand" + std::to_string(index.value())} + : Node{"operand" + std::to_string(index.value())} { { - auto type_to_shape = 
[](Type type) { + auto type_to_shape = [](Type type) -> const std::string & { switch (type) { case Type::MODEL_INPUT: diff --git a/runtime/onert/core/src/dumper/dot/OperandNode.h b/runtime/onert/core/src/dumper/dot/OperandNode.h index 2e7cc5861..f2aea80ad 100644 --- a/runtime/onert/core/src/dumper/dot/OperandNode.h +++ b/runtime/onert/core/src/dumper/dot/OperandNode.h @@ -64,7 +64,6 @@ public: * * @param[in] index Operand index * @param[in] type Operand type - * @param[in] lower_info Operand LowerInfo */ Operand(const ir::OperandIndex &index, Type type); diff --git a/runtime/onert/core/src/dumper/dot/OperationNode.cc b/runtime/onert/core/src/dumper/dot/OperationNode.cc index bee137e7c..2ef08c9c6 100644 --- a/runtime/onert/core/src/dumper/dot/OperationNode.cc +++ b/runtime/onert/core/src/dumper/dot/OperationNode.cc @@ -18,7 +18,6 @@ #include "OperationNode.h" #include "ir/Graph.h" -#include "ir/operation/LowerInfo.h" #include "backend/IConfig.h" #include "backend/Backend.h" @@ -32,8 +31,8 @@ namespace dot const std::string Operation::OPERATION_SHAPE = "rect"; const std::string Operation::BG_COLOR_SCHEME = "pastel18"; -Operation::Operation(const ir::OperationIndex &index, const ir::Operation &node) - : Node{"operation" + std::to_string(index.value())} +Operation::Operation(const ir::OperationIndex &index, const ir::IOperation &node) + : Node{"operation" + std::to_string(index.value())} { setAttribute("label", std::to_string(index.value()) + " : " + node.name()); setAttribute("shape", OPERATION_SHAPE); diff --git a/runtime/onert/core/src/dumper/dot/OperationNode.h b/runtime/onert/core/src/dumper/dot/OperationNode.h index 74a37d3fb..d9292ad0c 100644 --- a/runtime/onert/core/src/dumper/dot/OperationNode.h +++ b/runtime/onert/core/src/dumper/dot/OperationNode.h @@ -25,7 +25,7 @@ #define __ONERT_DUMPER_DOT_DOT_NODE_INFO_H__ #include "Node.h" -#include "ir/Operation.h" +#include "ir/IOperation.h" #include "ir/Index.h" namespace onert @@ -52,7 +52,7 @@ public: * @param[in] 
index operation index * @param[in] node operation object */ - Operation(const ir::OperationIndex &index, const ir::Operation &node); + Operation(const ir::OperationIndex &index, const ir::IOperation &node); }; } // namespace dot diff --git a/runtime/onert/core/src/compiler/ParamChecker.cc b/runtime/onert/core/src/dumper/h5/Dumper.cc index c4f80f087..5e12c2dbb 100644 --- a/runtime/onert/core/src/compiler/ParamChecker.cc +++ b/runtime/onert/core/src/dumper/h5/Dumper.cc @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,20 +14,21 @@ * limitations under the License. */ -#include "ParamChecker.h" +#include "Dumper.h" -#include "ir/Graph.h" +#include <iostream> +#include <sstream> +#include <stdexcept> namespace onert { -namespace compiler +namespace dumper { - -void ParamChecker::operator()() +namespace h5 { - _model->operations().iterate( - [&](const ir::OperationIndex &, const ir::Operation &node) { node.accept(*this); }); -} -} // namespace compiler +Dumper::Dumper(const std::string &filepath) : _file{filepath, H5F_ACC_CREAT | H5F_ACC_RDWR} {} + +} // namespace h5 +} // namespace dumper } // namespace onert diff --git a/runtime/onert/core/src/dumper/h5/Dumper.h b/runtime/onert/core/src/dumper/h5/Dumper.h new file mode 100644 index 000000000..53d0e0332 --- /dev/null +++ b/runtime/onert/core/src/dumper/h5/Dumper.h @@ -0,0 +1,51 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __ONERT_DUMPER_H5_DUMPER_H__ +#define __ONERT_DUMPER_H5_DUMPER_H__ + +#include "exec/MinMaxMap.h" + +#include <H5Cpp.h> +#include <string> + +namespace onert +{ +namespace dumper +{ +namespace h5 +{ + +class Dumper +{ +public: + /** + * @brief Construct dumper + * + * @param[in] path filepath to dump + * @throw H5::FileIException on error during file open/create + */ + Dumper(const std::string &filepath); + +protected: + H5::H5File _file; +}; + +} // namespace h5 +} // namespace dumper +} // namespace onert + +#endif // __ONERT_DUMPER_H5_DUMPER_H__ diff --git a/runtime/onert/core/src/dumper/h5/MinMaxDumper.cc b/runtime/onert/core/src/dumper/h5/MinMaxDumper.cc new file mode 100644 index 000000000..e353ed5cb --- /dev/null +++ b/runtime/onert/core/src/dumper/h5/MinMaxDumper.cc @@ -0,0 +1,87 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "MinMaxDumper.h" + +#include <iostream> +#include <sstream> +#include <stdexcept> + +namespace onert +{ +namespace dumper +{ +namespace h5 +{ + +static const char *h5_value_grpname = "value"; + +/* + * ensure grp_name exists in parent + */ +H5::Group ensureGroup(H5::Group parent, const std::string &child) +{ + H5::Exception::dontPrint(); + try + { + return parent.openGroup(child.c_str()); + } + catch (H5::Exception &e) + { + return parent.createGroup(child.c_str()); + } +} + +MinMaxDumper::MinMaxDumper(const std::string &filepath) : Dumper(filepath) +{ + auto root_grp = _file.openGroup("/"); + ensureGroup(root_grp, h5_value_grpname); +} + +void MinMaxDumper::dump(const exec::IOMinMaxMap &input_minmax, + const exec::OpMinMaxMap &op_minmax) const +{ + auto val_grp = _file.openGroup(h5_value_grpname); + auto num_run = val_grp.getNumObjs(); + auto run_grp = val_grp.createGroup(std::string("run_") + std::to_string(num_run)); + auto model_grp = ensureGroup(run_grp, std::string("model_") + "0"); + hsize_t dims[] = {2}; + H5::DataSpace dspace(1, dims); // rank=1, dim(0)=2, {min, max} + for (auto &&e : input_minmax) + { + // key = {subg_idx, io_idx} = e.first + const auto subg_idx = e.first.first.value(); + const auto io_idx = e.first.second.value(); + auto subg_grp = ensureGroup(model_grp, std::string("subg_") + std::to_string(subg_idx)); + auto input_dset = subg_grp.createDataSet(std::string("input_") + std::to_string(io_idx), + H5::PredType::IEEE_F32BE, dspace); + input_dset.write(e.second.data, H5::PredType::NATIVE_FLOAT); + } + for (auto &&e : op_minmax) + { + // key = {subg_idx, op_idx} = e.first + const auto subg_idx = e.first.first.value(); + const auto op_idx = e.first.second.value(); + auto subg_grp = ensureGroup(model_grp, std::string("subg_") + std::to_string(subg_idx)); + auto op_dset = subg_grp.createDataSet(std::string("op_") + std::to_string(op_idx), + H5::PredType::IEEE_F32BE, dspace); + op_dset.write(e.second.data, 
H5::PredType::NATIVE_FLOAT); + } +} + +} // namespace h5 +} // namespace dumper +} // namespace onert diff --git a/runtime/onert/core/src/dumper/h5/MinMaxDumper.h b/runtime/onert/core/src/dumper/h5/MinMaxDumper.h new file mode 100644 index 000000000..d7e2c1c31 --- /dev/null +++ b/runtime/onert/core/src/dumper/h5/MinMaxDumper.h @@ -0,0 +1,78 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __ONERT_DUMPER_H5_MINMAX_DUMPER_H__ +#define __ONERT_DUMPER_H5_MINMAX_DUMPER_H__ + +#include "exec/MinMaxMap.h" +#include "Dumper.h" + +#include <H5Cpp.h> +#include <string> + +namespace onert +{ +namespace dumper +{ +namespace h5 +{ + +// The hierachy of single model minmax h5 file +// +// GROUP / +// GROUP value +// └── GROUP run_{idx} +// └── GROUP model_{idx} +// └── GROUP subg_{idx} +// ├── DATASET op_{idx} +// │ DATATYPE Float32 +// │ DATASPACE (2) +// │ DATA { min, max } +// └── DATASET input_{idx} +// DATATYPE Float32 +// DATASPACE (2) +// DATA { min, max } +// GROUP name (optional, for debug) +// └── GROUP model_{idx} +// └── GROUP subg_{idx} +// ├── ATTRIBUTE op_{idx} +// │ DATATYPE String +// │ DATA { "op/name"} +// └── ATTRIBUTE input_{idx} +// DATATYPE String +// DATA { "input/name"} +// +class MinMaxDumper : private Dumper +{ +public: + MinMaxDumper(const std::string &filepath); + /** + * @brief Dump input minmax map + * + * @param[in] in_minmax input minmax map + * 
@param[in] op_minmax op minmax map + */ + void dump(const exec::IOMinMaxMap &in_minmax, const exec::OpMinMaxMap &op_minmax) const; + +private: + H5::Group _val_grp; +}; + +} // namespace h5 +} // namespace dumper +} // namespace onert + +#endif // __ONERT_DUMPER_H5_MINMAX_DUMPER_H__ diff --git a/runtime/onert/core/src/dumper/text/GraphDumper.cc b/runtime/onert/core/src/dumper/text/GraphDumper.cc new file mode 100644 index 000000000..c89253bda --- /dev/null +++ b/runtime/onert/core/src/dumper/text/GraphDumper.cc @@ -0,0 +1,108 @@ +/* + * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "GraphDumper.h" + +#include "ir/Graph.h" +#include "compiler/LoweredGraph.h" +#include "compiler/train/LoweredTrainableGraph.h" +#include "util/logging.h" +#include "misc/string_helpers.h" + +namespace onert +{ +namespace dumper +{ +namespace text +{ + +namespace +{ + +std::string formatOperandIndexSequence(const ir::OperandIndexSequence &seq) +{ + std::vector<std::string> strs; + for (auto &&ind : seq) + strs.push_back(dumper::text::formatOperandBrief(ind)); + return nnfw::misc::join(strs.begin(), strs.end(), ","); +} + +} // namespace + +std::string formatOperandBrief(ir::OperandIndex ind) +{ + std::stringstream ss; + ss << ind; + return ss.str(); +} + +std::string formatOperand(const ir::Graph &, ir::OperandIndex ind) +{ + std::stringstream ss; + ss << ind; + // TODO Print shape, type and maybe more + return ss.str(); +} + +std::string formatOperation(const ir::IOperation &op, ir::OperationIndex ind) +{ + std::stringstream ss; + + ss << formatOperandIndexSequence(op.getOutputs()); + ss << " = "; + ss << ind << "_" << op.name() << "("; + ss << formatOperandIndexSequence(op.getInputs()); + ss << ")"; + return ss.str(); +} + +std::string formatOperation(const ir::Graph &graph, ir::OperationIndex ind) +{ + std::stringstream ss; + const auto &op = graph.operations().at(ind); + return formatOperation(op, ind); +} + +void dumpGraph(const ir::Graph &graph) +{ + VERBOSE(GraphDumper) << "{\n"; + auto ops_topol = graph.topolSortOperations(); + for (auto &&op_ind : ops_topol) + { + const auto &op = graph.operations().at(op_ind); + VERBOSE(GraphDumper) << " " << formatOperation(op, op_ind) << "\n"; + } + graph.operands().iterate([&](const ir::OperandIndex &idx, const ir::Operand &oprd) { + VERBOSE(GraphDumper) << " Origin(" << idx << "): " << oprd.originIndex() << std::endl; + }); + VERBOSE(GraphDumper) << "}\n"; +} + +void dumpLoweredGraph(const compiler::LoweredGraph &lgraph) +{ + // TODO Graph dump with backend info + dumpGraph(lgraph.graph()); +} + +void 
dumpLoweredGraph(const compiler::train::LoweredTrainableGraph &lgraph) +{ + // TODO Graph dump with backend info + dumpGraph(lgraph.graph()); +} + +} // namespace text +} // namespace dumper +} // namespace onert diff --git a/runtime/onert/core/src/dumper/text/GraphDumper.h b/runtime/onert/core/src/dumper/text/GraphDumper.h new file mode 100644 index 000000000..3cc13c92e --- /dev/null +++ b/runtime/onert/core/src/dumper/text/GraphDumper.h @@ -0,0 +1,62 @@ +/* + * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef __ONERT_DUMPER_TEXT_GRAPH_DUMPER_H__ +#define __ONERT_DUMPER_TEXT_GRAPH_DUMPER_H__ + +#include <ir/Index.h> + +namespace onert +{ +namespace ir +{ +class Graph; +struct IOperation; +} // namespace ir +} // namespace onert + +namespace onert +{ +namespace compiler +{ +class LoweredGraph; + +namespace train +{ +class LoweredTrainableGraph; +} // namespace train +} // namespace compiler +} // namespace onert + +namespace onert +{ +namespace dumper +{ +namespace text +{ + +std::string formatOperandBrief(ir::OperandIndex ind); +std::string formatOperand(const ir::Graph &, ir::OperandIndex ind); +std::string formatOperation(const ir::Graph &graph, ir::OperationIndex ind); +void dumpGraph(const ir::Graph &graph); +void dumpLoweredGraph(const compiler::LoweredGraph &lgraph); +void dumpLoweredGraph(const compiler::train::LoweredTrainableGraph &lgraph); + +} // namespace text +} // namespace dumper +} // namespace onert + +#endif // __ONERT_DUMPER_TEXT_GRAPH_DUMPER_H__ diff --git a/runtime/onert/core/src/exec/DataflowExecutor.cc b/runtime/onert/core/src/exec/DataflowExecutor.cc index a69ae9cdb..50984cefc 100644 --- a/runtime/onert/core/src/exec/DataflowExecutor.cc +++ b/runtime/onert/core/src/exec/DataflowExecutor.cc @@ -54,14 +54,13 @@ void DataflowExecutor::emplaceToReadyJobs(const uint32_t &id) { auto &job = _waiting_jobs[id]; assert(job != nullptr); - auto &op_seq = _lowered_graph->op_seqs().at(_job_to_op_seq[job->index()]); - auto rank = calculateRank(op_seq.operations()); + auto rank = calculateRank({_job_to_op[job->index()]}); _ready_jobs.emplace(rank, std::move(job)); } void DataflowExecutor::notify(uint32_t finished_job_id) { - for (auto id : _output_info[finished_job_id]) + for (auto &&id : _output_info[finished_job_id]) { assert(_input_info[id] > 0); auto count = --_input_info[id]; @@ -77,57 +76,54 @@ bool DataflowExecutor::noWaitingJobs() [](const std::unique_ptr<Job> &job) { return job == nullptr; }); } -DataflowExecutor::DataflowExecutor( - 
std::unique_ptr<compiler::LoweredGraph> lowered_graph, - const std::vector<std::shared_ptr<backend::ITensor>> &input_tensors, - const std::vector<std::shared_ptr<backend::ITensor>> &output_tensors, - const compiler::TensorRegistries &tensor_regs, backend::TensorManagerSet &&tensor_mgrs, - compiler::CodeMap &&code_map) - : ExecutorBase{std::move(lowered_graph), input_tensors, output_tensors, tensor_regs, - std::move(tensor_mgrs)}, - _code_map{std::move(code_map)} +DataflowExecutor::DataflowExecutor(std::unique_ptr<compiler::LoweredGraph> lowered_graph, + backend::BackendContexts &&backend_contexts, + const compiler::TensorRegistries &tensor_regs, + compiler::CodeMap &&code_map, + const util::TracingCtx *tracing_ctx) + : ExecutorBase{std::move(lowered_graph), std::move(backend_contexts), tensor_regs, tracing_ctx}, + _code_map{std::move(code_map)} { VERBOSE(DataflowExecutor) << "Constructing Dataflow Executor" << std::endl; - const auto &op_seqs = _lowered_graph->op_seqs(); - // Assign jobs convert OpSequenceIndex to job index(uint32_t) + // Assign jobs convert OperationIndex to job index(uint32_t) uint32_t next_job_index = 0; - std::unordered_map<ir::OpSequenceIndex, uint32_t> op_seq_to_job; - op_seqs.iterate([&](const ir::OpSequenceIndex &op_seq_index, const ir::OpSequence &) { - VERBOSE(DataflowExecutor) << "Create a job #" << next_job_index << " with OpSequenceIndex " - << op_seq_index.value() << std::endl; + std::unordered_map<ir::OperationIndex, uint32_t> op_to_job; + const auto &operations = _lowered_graph->graph().operations(); + operations.iterate([&](const ir::OperationIndex &op_ind, const ir::IOperation &) { + VERBOSE(DataflowExecutor) << "Create a job " << next_job_index << " with Operation " << op_ind + << std::endl; _finished_jobs.emplace_back( - std::make_unique<Job>(next_job_index, _code_map.at(op_seq_index).fn_seq.get())); - op_seq_to_job[op_seq_index] = next_job_index++; + std::make_unique<Job>(next_job_index, _code_map.at(op_ind).fn_seq.get())); + 
op_to_job[op_ind] = next_job_index++; }); _waiting_jobs.resize(next_job_index); _output_info.resize(next_job_index); _initial_input_info.resize(next_job_index, 0); - op_seqs.iterate([&](const ir::OpSequenceIndex &op_seq_index, const ir::OpSequence &op_seq) { - auto job_index = op_seq_to_job[op_seq_index]; - for (auto output : op_seq.getOutputs()) + operations.iterate([&](const ir::OperationIndex &op_ind, const ir::IOperation &op) { + auto job_index = op_to_job[op_ind]; + for (auto &&output : op.getOutputs()) { // Update output and input info - op_seqs.iterate( - [&](const ir::OpSequenceIndex &op_seq_cur_index, const ir::OpSequence &op_seq_cur) { - if (op_seq_cur.getInputs().contains(output)) - { - auto dep_index = op_seq_to_job[op_seq_cur_index]; - ++_initial_input_info[dep_index]; - _output_info[job_index].push_back(dep_index); - } - }); + operations.iterate([&](const ir::OperationIndex &op_cur_ind, const ir::IOperation &op_cur) { + if (op_cur.getInputs().contains(output)) + { + auto dep_index = op_to_job[op_cur_ind]; + ++_initial_input_info[dep_index]; + _output_info[job_index].push_back(dep_index); + } + }); } }); - for (const auto &s : op_seq_to_job) - _job_to_op_seq.emplace(s.second, s.first); + for (const auto &s : op_to_job) + _job_to_op.emplace(s.second, s.first); _input_info = _initial_input_info; } -void DataflowExecutor::executeImpl() +void DataflowExecutor::executeImpl(const ExecutionObservee &subject) { assert(noWaitingJobs()); @@ -145,35 +141,38 @@ void DataflowExecutor::executeImpl() } assert(!_ready_jobs.empty()); // Cannot begin if there is no initial jobs - _subject.notifyModelBegin(this); + auto profiling_subg_index = _tracing_ctx->getSubgraphIndex(&_graph); + + subject.notifySubgraphBegin(profiling_subg_index); while (!_ready_jobs.empty()) { auto job = std::move((_ready_jobs.begin())->second); _ready_jobs.erase(_ready_jobs.begin()); auto job_index = job->index(); - VERBOSE(DataflowExecutor) << "Run job #" << job_index << std::endl; + 
VERBOSE(DataflowExecutor) << "Run job " << job_index << std::endl; + + auto op_ind = _job_to_op[job_index]; + const backend::Backend *backend = _lowered_graph->lower_info().operation.at(op_ind).backend(); - auto op_seq_index = _job_to_op_seq[job_index]; - auto op_seq = &_lowered_graph->op_seqs().at(op_seq_index); - const backend::Backend *backend = - _lowered_graph->getLowerInfo()->op_seq.at(op_seq_index)->backend(); + subject.notifyJobBegin(this, profiling_subg_index, op_ind, backend); - _subject.notifyJobBegin(this, op_seq, backend); + job->fn_seq()->initRunning(); // check if FunctionSequence needs to handle dynamic tensor - bool handle_dynamic_tensor = op_seq->has_dynamic_tensor() || dynamic_input_exists; + bool handle_dynamic_tensor = + _lowered_graph->getHasDynamicTensor(op_ind) || dynamic_input_exists; job->fn_seq()->enableDynamicShapeInferer(handle_dynamic_tensor); job->run(); - _subject.notifyJobEnd(this, op_seq, backend); + subject.notifyJobEnd(this, profiling_subg_index, op_ind, backend); notify(job_index); _finished_jobs[job_index] = std::move(job); } assert(noWaitingJobs()); - _subject.notifyModelEnd(this); + subject.notifySubgraphEnd(profiling_subg_index); // Reset input info for the next execution _input_info = _initial_input_info; diff --git a/runtime/onert/core/src/exec/DataflowExecutor.h b/runtime/onert/core/src/exec/DataflowExecutor.h index 8d60e3e4b..750dc244f 100644 --- a/runtime/onert/core/src/exec/DataflowExecutor.h +++ b/runtime/onert/core/src/exec/DataflowExecutor.h @@ -17,17 +17,17 @@ #ifndef __ONERT_EXEC_DATAFLOW_EXECUTOR_H__ #define __ONERT_EXEC_DATAFLOW_EXECUTOR_H__ -#include <list> -#include <map> -#include <unordered_map> - -#include "exec/FunctionSequence.h" +#include "ExecutorBase.h" #include "Job.h" + +#include "compiler/CodeMap.h" #include "ir/OperandIndexSequence.h" -#include "ir/Index.h" +#include "util/TracingCtx.h" + +#include <list> +#include <map> #include <memory> -#include "exec/ExecutorBase.h" -#include 
"compiler/CodeMap.h" +#include <unordered_map> namespace onert { @@ -47,15 +47,14 @@ public: * * @param lowered_graph LoweredGraph object * @param tensor_builders Tensor builders that are currently used - * @param code_map OpSequence and its code map + * @param code_map @c ir::Operation and its code map */ DataflowExecutor(std::unique_ptr<compiler::LoweredGraph> lowered_graph, - const std::vector<std::shared_ptr<backend::ITensor>> &input_tensors, - const std::vector<std::shared_ptr<backend::ITensor>> &output_tensors, - const compiler::TensorRegistries &tensor_regs, - backend::TensorManagerSet &&tensor_mgrs, compiler::CodeMap &&code_map); + backend::BackendContexts &&backend_contexts, + const compiler::TensorRegistries &tensor_regs, compiler::CodeMap &&code_map, + const util::TracingCtx *tracing_ctx); - void executeImpl() override; + void executeImpl(const ExecutionObservee &subject) override; protected: int64_t calculateRank(const std::vector<ir::OperationIndex> &operations); @@ -88,7 +87,7 @@ protected: std::multimap<int64_t, std::unique_ptr<Job>, std::greater<int64_t>> _ready_jobs; /// @brief Which job runs which op and function. - std::unordered_map<uint32_t, ir::OpSequenceIndex> _job_to_op_seq; + std::unordered_map<uint32_t, ir::OperationIndex> _job_to_op; }; } // namespace exec diff --git a/runtime/onert/core/src/exec/DynamicShapeInference.cc b/runtime/onert/core/src/exec/DynamicShapeInferer.cc index 70bddfce4..691a11933 100644 --- a/runtime/onert/core/src/exec/DynamicShapeInference.cc +++ b/runtime/onert/core/src/exec/DynamicShapeInferer.cc @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -#include "exec/DynamicShapeInference.h" +#include "exec/DynamicShapeInferer.h" #include "util/ShapeInference.h" #include <assert.h> @@ -23,14 +23,6 @@ namespace onert namespace exec { -inline backend::IDynamicTensorManager * -dynamicTensorManagerOf(const std::shared_ptr<backend::ITensor> &tensor) -{ - if (!tensor->dynamic_tensor_manager()) - throw std::runtime_error{"Dynamic Tensor Manager is not available for this tensor."}; - return tensor->dynamic_tensor_manager(); -} - void DynamicShapeInferer::handleBinaryArithmeticOp(const ir::Operation &op, const ir::OperandIndex lhs_idx, const ir::OperandIndex rhs_idx) @@ -56,15 +48,15 @@ void DynamicShapeInferer::handleBinaryArithmeticOp(const ir::Operation &op, So, only when all inputs are static, we can skip dynamic shape inference. */ - if ((!lhs->is_dynamic()) && (!rhs->is_dynamic())) - return; - auto output_idx = op.getOutputs().at(0); auto output = _tensor_registry->getITensor(output_idx); + if ((currently_static(lhs) && currently_static(rhs)) && previously_static(output)) + return; + ir::Shape new_shape = shape_inference::inferEltwiseShape(lhs_shape, rhs_shape); - dynamicTensorManagerOf(output)->applyShape(output_idx, new_shape); + output->applyShape(new_shape); assert(output->buffer() != nullptr); } @@ -96,30 +88,32 @@ void DynamicShapeInferer::handleSimpleUnaryOp(const ir::Operation &op, auto output_ind = op.getOutputs().at(0); auto output = _tensor_registry->getITensor(output_ind); - dynamicTensorManagerOf(output)->applyShape(output_ind, output_shape); + output->applyShape(output_shape); assert(output->buffer() != nullptr); } -void DynamicShapeInferer::visit(const ir::operation::ArgMax &op) +void DynamicShapeInferer::visit(const ir::operation::ArgMinMax &op) { - const auto input_idx{op.getInputs().at(ir::operation::ArgMax::Input::INPUT)}; - const auto &input = _tensor_registry->getITensor(input_idx); - auto input_shape = input->getShape(); + const auto 
input_idx{op.getInputs().at(ir::operation::ArgMinMax::Input::INPUT)}; + const auto input = _tensor_registry->getITensor(input_idx); - if (!input->is_dynamic()) - return; - - const auto rank = input_shape.rank(); - const auto axis = ((op.param().axis < 0) ? rank + op.param().axis : op.param().axis); - - assert(0 <= axis && axis < rank); + const auto axis_idx{op.getInputs().at(ir::operation::ArgMinMax::Input::AXIS)}; + const auto axis = _tensor_registry->getITensor(axis_idx); auto output_ind = op.getOutputs().at(0); auto output = _tensor_registry->getITensor(output_ind); - ir::Shape new_shape = shape_inference::inferArgMaxShape(input_shape, axis, rank); + if (!input->is_dynamic() && !output->is_dynamic()) + return; - dynamicTensorManagerOf(output)->applyShape(output_ind, new_shape); + auto input_shape = input->getShape(); + auto axis_value = *reinterpret_cast<const int32_t *>(axis->buffer()); + const auto rank = input_shape.rank(); + axis_value = axis_value < 0 ? axis_value + rank : axis_value; + + ir::Shape new_shape = shape_inference::inferArgMinMaxShape(input_shape, axis_value, rank); + + output->applyShape(new_shape); assert(output->buffer() != nullptr); } @@ -141,7 +135,68 @@ void DynamicShapeInferer::visit(const ir::operation::BatchMatMul &op) // TODO auto new_shape = shape_inference::inferBatchMatMulShape(lhs_shape, rhs_shape, op.param()); - dynamicTensorManagerOf(output)->applyShape(output_index, new_shape); + output->applyShape(new_shape); +} + +void DynamicShapeInferer::visit(const ir::operation::BCQFullyConnected &op) +{ + const auto input_idx{op.getInputs().at(ir::operation::BCQFullyConnected::Input::INPUT)}; + const auto &input = _tensor_registry->getITensor(input_idx); + + const auto cluster_idx{ + op.getInputs().at(ir::operation::BCQFullyConnected::Input::WEIGHTS_CLUSTERS)}; + const auto &cluster = _tensor_registry->getITensor(cluster_idx); + assert(cluster->is_constant()); + + if (!input->is_dynamic()) + return; + + auto input_shape = 
input->getShape(); + auto cluster_shape = cluster->getShape(); + + auto cluster_buf = reinterpret_cast<const int32_t *>(cluster->buffer()); + assert(cluster_buf); + + ir::Shape new_shape = + shape_inference::inferBCQFullyConnectedShape(input_shape, cluster_shape, cluster_buf); + + auto output_ind = op.getOutputs().at(0); + auto output = _tensor_registry->getITensor(output_ind); + + output->applyShape(new_shape); + assert(output->buffer() != nullptr); +} + +void DynamicShapeInferer::visit(const ir::operation::BCQGather &op) +{ + const auto indices_idx{op.getInputs().at(ir::operation::BCQGather::Input::INDICES)}; + const auto &indices = _tensor_registry->getITensor(indices_idx); + + const auto input_binary_idx{op.getInputs().at(ir::operation::BCQGather::Input::INPUT_BINARY)}; + const auto &input_binary = _tensor_registry->getITensor(input_binary_idx); + + const auto cluster_idx{op.getInputs().at(ir::operation::BCQGather::Input::INPUT_CLUSTERS)}; + const auto &cluster = _tensor_registry->getITensor(cluster_idx); + assert(cluster->is_constant()); + + if (!indices->is_dynamic()) + return; + + auto indices_shape = indices->getShape(); + auto cluster_shape = cluster->getShape(); + auto rank = input_binary->getShape().rank(); + + auto cluster_buf = reinterpret_cast<const int32_t *>(cluster->buffer()); + assert(cluster_buf); + + ir::Shape new_shape = shape_inference::inferBCQGatherShape(indices_shape, cluster_shape, + cluster_buf, rank, op.param()); + + auto output_ind = op.getOutputs().at(0); + auto output = _tensor_registry->getITensor(output_ind); + + output->applyShape(new_shape); + assert(output->buffer() != nullptr); } void DynamicShapeInferer::visit(const ir::operation::BinaryArithmetic &op) @@ -167,10 +222,10 @@ void DynamicShapeInferer::visit(const ir::operation::BroadcastTo &op) assert(shape); // It shouldn't be 0. 
auto output_shape = shape_inference::inferBroadcastToShape( - shape->getShape(), reinterpret_cast<const int32_t *>(shape->buffer())); + shape->getShape(), reinterpret_cast<const int32_t *>(shape->buffer())); // set output shape and output buffer - dynamicTensorManagerOf(output)->applyShape(output_ind, output_shape); + output->applyShape(output_shape); assert(output->buffer() != nullptr); } @@ -198,7 +253,7 @@ void DynamicShapeInferer::visit(const ir::operation::Concat &op) So, only when all inputs are static, we can skip dynamic shape inference. */ bool all_static = true; - for (auto input_ind : op.getInputs()) + for (auto &&input_ind : op.getInputs()) { auto input = _tensor_registry->getITensor(input_ind); if (input->is_dynamic()) @@ -215,15 +270,17 @@ void DynamicShapeInferer::visit(const ir::operation::Concat &op) { auto isConcatible = [](const backend::ITensor *input1, const backend::ITensor *input2, int32_t axis) { - if (input1->num_dimensions() != input2->num_dimensions()) + auto shape1 = input1->getShape(); + auto shape2 = input2->getShape(); + if (shape1.rank() != shape2.rank()) return false; - for (size_t i = 0; i < input1->num_dimensions(); i++) + for (int i = 0; i < shape1.rank(); i++) { - auto positive_axis = (axis >= 0) ? axis : axis + input1->num_dimensions(); + auto positive_axis = (axis >= 0) ? 
axis : axis + input1->getShape().rank(); if (i != positive_axis) - if (input1->dimension(i) != input2->dimension(i)) + if (shape1.dim(i) != shape2.dim(i)) return false; } @@ -233,17 +290,17 @@ void DynamicShapeInferer::visit(const ir::operation::Concat &op) auto first_input_ind = op.getInputs().at(0); auto first_input = _tensor_registry->getITensor(first_input_ind); - for (auto input_ind : op.getInputs()) + for (auto &&input_ind : op.getInputs()) { auto input = _tensor_registry->getITensor(input_ind); - if (input != first_input && !isConcatible(first_input.get(), input.get(), op.param().axis)) + if (input != first_input && !isConcatible(first_input, input, op.param().axis)) throw std::runtime_error("input shapes does not matched for concat"); } } // getting output shape onert::shape_inference::Shapes in_shapes; - for (auto input_ind : op.getInputs()) + for (auto &&input_ind : op.getInputs()) { auto input = _tensor_registry->getITensor(input_ind); ir::Shape shape = input->getShape(); @@ -255,7 +312,7 @@ void DynamicShapeInferer::visit(const ir::operation::Concat &op) auto output = _tensor_registry->getITensor(output_ind); auto output_shape = shape_inference::inferConcatShape(in_shapes, op.param()); - dynamicTensorManagerOf(output)->applyShape(output_ind, output_shape); + output->applyShape(output_shape); } void DynamicShapeInferer::visit(const ir::operation::Conv2D &op) @@ -278,7 +335,7 @@ void DynamicShapeInferer::visit(const ir::operation::Conv2D &op) ir::Shape output_shape = shape_inference::inferConv2DShape(input_shape, ker_shape, op.param()); - dynamicTensorManagerOf(output)->applyShape(output_ind, output_shape); + output->applyShape(output_shape); assert(output->buffer() != nullptr); } @@ -333,12 +390,18 @@ void DynamicShapeInferer::visit(const ir::operation::ExpandDims &op) auto axis_ind = op.getInputs().at(ir::operation::ExpandDims::AXIS); auto axis = _tensor_registry->getITensor(axis_ind); - auto axis_buf = reinterpret_cast<const int32_t *>(axis->buffer()); 
- assert(axis_buf); + auto axis_type = axis->data_type(); + assert(axis_type == ir::DataType::INT32 || axis_type == ir::DataType::INT64); + + assert(axis->buffer()); + int32_t axis_value = + (axis_type == ir::DataType::INT32) + ? reinterpret_cast<const int32_t *>(axis->buffer())[0] + : static_cast<int32_t>(reinterpret_cast<const int64_t *>(axis->buffer())[0]); - auto output_shape = shape_inference::inferExpandDimsShape(input_shape, axis_buf[0]); + auto output_shape = shape_inference::inferExpandDimsShape(input_shape, axis_value); - dynamicTensorManagerOf(output)->applyShape(output_ind, output_shape); + output->applyShape(output_shape); assert(output->buffer() != nullptr); } @@ -347,21 +410,26 @@ void DynamicShapeInferer::visit(const ir::operation::Fill &op) // check if output is not dynamic auto output_ind = op.getOutputs().at(0); auto output = _tensor_registry->getITensor(output_ind); - auto input_ind = op.getInputs().at(ir::operation::Fill::Input::INPUT); - auto input = _tensor_registry->getITensor(input_ind); - ir::Shape input_shape = input->getShape(); + auto shape_ind = op.getInputs().at(ir::operation::Fill::Input::SHAPE); + auto shape = _tensor_registry->getITensor(shape_ind); - if ((!input->is_dynamic()) && (!output->is_dynamic())) + if ((!shape->is_dynamic()) && (!output->is_dynamic())) return; - assert(input.get()->data_type() == ir::DataType::INT32); + const auto dims_type = shape->data_type(); + assert(dims_type == ir::DataType::INT32 || dims_type == ir::DataType::INT64); - auto input_buf = reinterpret_cast<const int32_t *>(input->buffer()); - assert(input_buf); + auto dims_buf = shape->buffer(); + assert(dims_buf); - auto output_shape = shape_inference::inferFillShape(input_shape, input_buf); + const auto &dims_shape = shape->getShape(); + const auto &output_shape = ((dims_type == ir::DataType::INT32) + ? 
shape_inference::inferFillShape<int32_t>( + dims_shape, reinterpret_cast<const int32_t *>(dims_buf)) + : shape_inference::inferFillShape<int64_t>( + dims_shape, reinterpret_cast<const int64_t *>(dims_buf))); - dynamicTensorManagerOf(output)->applyShape(output_ind, output_shape); + output->applyShape(output_shape); assert(output->buffer() != nullptr); } @@ -384,7 +452,7 @@ void DynamicShapeInferer::visit(const ir::operation::FullyConnected &op) auto output_ind = op.getOutputs().at(0); auto output = _tensor_registry->getITensor(output_ind); - dynamicTensorManagerOf(output)->applyShape(output_ind, new_shape); + output->applyShape(new_shape); assert(output->buffer() != nullptr); } @@ -416,7 +484,7 @@ void DynamicShapeInferer::visit(const ir::operation::Gather &op) auto output_ind = op.getOutputs().at(0); auto output = _tensor_registry->getITensor(output_ind); - dynamicTensorManagerOf(output)->applyShape(output_ind, new_shape); + output->applyShape(new_shape); assert(output->buffer() != nullptr); } @@ -425,11 +493,122 @@ void DynamicShapeInferer::visit(const ir::operation::L2Normalization &op) handleSimpleUnaryOp(op, op.getInputs().at(ir::operation::L2Normalization::INPUT)); } +void DynamicShapeInferer::visit(const ir::operation::LSTM &op) +{ + const auto output_index{op.getOutputs().at(ir::operation::LSTM::Output::OUTPUT)}; + auto output = _tensor_registry->getITensor(output_index); + + const auto output_state_out_index{ + op.getOutputs().at(ir::operation::LSTM::Output::OUTPUT_STATE_OUT)}; + + const auto cell_state_out_index{op.getOutputs().at(ir::operation::LSTM::Output::CELL_STATE_OUT)}; + + const auto scratch_buffer_index{op.getOutputs().at(ir::operation::LSTM::Output::SCRATCH_BUFFER)}; + + if (!output->is_dynamic() && + !(_tensor_registry->getITensor(output_state_out_index) != nullptr && + _tensor_registry->getITensor(output_state_out_index)->is_dynamic()) && + !(_tensor_registry->getITensor(cell_state_out_index) != nullptr && + 
_tensor_registry->getITensor(cell_state_out_index)->is_dynamic()) && + !(_tensor_registry->getITensor(scratch_buffer_index) != nullptr && + _tensor_registry->getITensor(cell_state_out_index)->is_dynamic())) + return; + + const auto input_index{op.getInputs().at(ir::operation::LSTM::Input::INPUT)}; + const auto input = _tensor_registry->getITensor(input_index); + const auto input_shape = input->getShape(); + + const auto input_to_output_weights_index{ + op.getInputs().at(ir::operation::LSTM::Input::INPUT_TO_OUTPUT_WEIGHTS)}; + const auto input_to_output_weights = _tensor_registry->getITensor(input_to_output_weights_index); + const auto input_to_output_weights_shape = input_to_output_weights->getShape(); + + const auto recurrent_to_output_weights_index{ + op.getInputs().at(ir::operation::LSTM::Input::RECURRENT_TO_OUTPUT_WEIGHTS)}; + const auto recurrent_to_output_weights = + _tensor_registry->getITensor(recurrent_to_output_weights_index); + const auto recurrent_to_output_weights_shape = recurrent_to_output_weights->getShape(); + + // re-sizing outputs + const int n_batch = + (input_shape.rank() == 3 && op.param().time_major) ? 
input_shape.dim(1) : input_shape.dim(0); + const int n_cell = input_to_output_weights_shape.dim(0); + const int n_output = recurrent_to_output_weights_shape.dim(1); + if (input_shape.rank() == 3) + { + if (op.param().time_major) + output->applyShape(ir::Shape{input_shape.dim(0), n_batch, n_output}); + else + output->applyShape(ir::Shape{n_batch, input_shape.dim(1), n_output}); + } + else + { + assert(input_shape.rank() == 2); + output->applyShape(ir::Shape{n_batch, n_output}); + } + assert(output->buffer() != nullptr); + + auto output_state_out = _tensor_registry->getITensor(output_state_out_index); + if (output_state_out != nullptr) + { + output_state_out->applyShape(ir::Shape{n_batch, n_output}); + assert(output_state_out->buffer() != nullptr); + } + + auto cell_state_out = _tensor_registry->getITensor(cell_state_out_index); + if (cell_state_out != nullptr) + { + cell_state_out->applyShape(ir::Shape{n_batch, n_cell}); + assert(cell_state_out->buffer() != nullptr); + } + + auto scratch_buffer = _tensor_registry->getITensor(scratch_buffer_index); + if (scratch_buffer != nullptr) + { + const auto input_to_input_weights_index{ + op.getInputs().at(ir::operation::LSTM::Input::INPUT_TO_INPUT_WEIGHTS)}; + const auto recurrent_to_input_weights_index{ + op.getInputs().at(ir::operation::LSTM::Input::RECURRENT_TO_INPUT_WEIGHTS)}; + + const auto input_to_input_weights_shape = + _tensor_registry->getITensor(input_to_input_weights_index)->getShape(); + bool has_input_to_input_weights = + input_to_input_weights_shape.dim(0) != 0 && input_to_input_weights_shape.dim(1) != 0; + + const auto recurrent_to_input_weights_shape = + _tensor_registry->getITensor(recurrent_to_input_weights_index)->getShape(); + bool has_recurrent_to_input_weights = + recurrent_to_input_weights_shape.dim(0) != 0 && recurrent_to_input_weights_shape.dim(1) != 0; + + // NOTE The cell_to_input_weights do not exist in non-peephole although regular LSTM(non-CIFG). 
+ // true: no CIFG + // false: CIFG + bool has_cifg_param = has_input_to_input_weights && has_recurrent_to_input_weights; + if (has_cifg_param) + { + scratch_buffer->applyShape(ir::Shape{n_batch, n_cell * 4}); + } + else + { + scratch_buffer->applyShape(ir::Shape{n_batch, n_cell * 3}); + } + assert(scratch_buffer->buffer() != nullptr); + } +} + void DynamicShapeInferer::visit(const ir::operation::MatrixBandPart &op) { handleSimpleUnaryOp(op, op.getInputs().at(ir::operation::MatrixBandPart::INPUT)); } +void DynamicShapeInferer::visit(const ir::operation::DetectionPostProcess & /* op */) +{ + // NOTE DetectionPostProcess's undefined outputs' shape are decided on compile time + // by static shape inferer. + // DetectionPostProcess's outputs' shape are independent with input shape + // and decided by parameter value. +} + void DynamicShapeInferer::visit(const ir::operation::OneHot &op) { auto output_ind = op.getOutputs().at(0); @@ -452,7 +631,7 @@ void DynamicShapeInferer::visit(const ir::operation::OneHot &op) const auto axis_val = op.param().axis; ir::Shape new_shape = shape_inference::inferOnehotShape(indices_shape, *depth_buf, axis_val); - dynamicTensorManagerOf(output)->applyShape(output_ind, new_shape); + output->applyShape(new_shape); assert(output->buffer() != nullptr); } @@ -488,7 +667,7 @@ void DynamicShapeInferer::visit(const ir::operation::Pack &op) ir::Shape new_shape = shape_inference::inferPackShape(input_shape, axis, rank, num); - dynamicTensorManagerOf(output)->applyShape(output_ind, new_shape); + output->applyShape(new_shape); assert(output->buffer() != nullptr); } @@ -512,10 +691,10 @@ void DynamicShapeInferer::visit(const ir::operation::Pad &op) assert(pad_buf); auto output_shape = - shape_inference::inferPadShape(input->getShape(), pad_buf, pad->getShape().num_elements()); + shape_inference::inferPadShape(input->getShape(), pad_buf, pad->getShape().num_elements()); // change output shape and reallocate output tensor memory - 
dynamicTensorManagerOf(output)->applyShape(output_ind, output_shape); + output->applyShape(output_shape); assert(output->buffer() != nullptr); } @@ -526,6 +705,26 @@ void DynamicShapeInferer::visit(const ir::operation::Permute & /* op */) // on-the-fly, as it must support inter-backend inference/allocation. } +void DynamicShapeInferer::visit(const ir::operation::Pool2D &op) +{ + // check if input is not dynamic + auto input_ind = op.getInputs().at(ir::operation::Pool2D::INPUT); + auto input = _tensor_registry->getITensor(input_ind); + + if (!input->is_dynamic()) + return; + + ir::Shape input_shape = input->getShape(); + + auto output_ind = op.getOutputs().at(0); + auto output = _tensor_registry->getITensor(output_ind); + + ir::Shape output_shape = shape_inference::inferPoolShape(input_shape, op.param()); + + output->applyShape(output_shape); + assert(output->buffer() != nullptr); +} + void DynamicShapeInferer::visit(const ir::operation::Pow &op) { handleBinaryArithmeticOp(op, op.getInputs().at(ir::operation::Pow::Input::LHS), @@ -556,18 +755,18 @@ void DynamicShapeInferer::visit(const ir::operation::Range &op) if (output->data_type() == ir::DataType::FLOAT32) { new_shape = - shape_inference::inferRangeShape<float>(*reinterpret_cast<float *>(start_tensor->buffer()), - *reinterpret_cast<float *>(limit_tensor->buffer()), - *reinterpret_cast<float *>(delta_tensor->buffer())); + shape_inference::inferRangeShape<float>(*reinterpret_cast<float *>(start_tensor->buffer()), + *reinterpret_cast<float *>(limit_tensor->buffer()), + *reinterpret_cast<float *>(delta_tensor->buffer())); } else if (output->data_type() == ir::DataType::INT32) { new_shape = shape_inference::inferRangeShape<int32_t>( - *reinterpret_cast<int32_t *>(start_tensor->buffer()), - *reinterpret_cast<int32_t *>(limit_tensor->buffer()), - *reinterpret_cast<int32_t *>(delta_tensor->buffer())); + *reinterpret_cast<int32_t *>(start_tensor->buffer()), + *reinterpret_cast<int32_t *>(limit_tensor->buffer()), + 
*reinterpret_cast<int32_t *>(delta_tensor->buffer())); } - dynamicTensorManagerOf(output)->applyShape(output_ind, new_shape); + output->applyShape(new_shape); assert(output->buffer() != nullptr); } @@ -611,7 +810,7 @@ void DynamicShapeInferer::visit(const ir::operation::Reduce &op) ir::Shape new_shape = shape_inference::inferReduceShape(input_shape, axes_vec, keep_dims); - dynamicTensorManagerOf(output)->applyShape(output_ind, new_shape); + output->applyShape(new_shape); assert(output->buffer() != nullptr); } @@ -658,14 +857,14 @@ void DynamicShapeInferer::visit(const ir::operation::Reshape &op) int32_t *new_shape_buf = reinterpret_cast<int32_t *>(new_shape->buffer()); assert(new_shape_buf); - auto output_shape = shape_inference::inferReshapeShape( - new_shape_buf, new_shape->getShape().num_elements(), input->getShape().num_elements()); + auto output_shape = shape_inference::inferReshapeShape(input->getShape(), new_shape_buf, + new_shape->getShape().num_elements()); // if shape is changed, change output shape and reallocate output tensor memory if (output_shape != output->getShape() || output->buffer() == nullptr) { // change on output shape - dynamicTensorManagerOf(output)->applyShape(output_ind, output_shape); + output->applyShape(output_shape); } assert(output->buffer() != nullptr); } @@ -674,14 +873,14 @@ void DynamicShapeInferer::visit(const ir::operation::Reshape &op) { // Let's check the new_shape option auto shape = op.param().new_shape; - auto output_shape = shape_inference::inferReshapeShape(shape.data(), shape.size(), - input->getShape().num_elements()); + auto output_shape = + shape_inference::inferReshapeShape(input->getShape(), shape.data(), shape.size()); // if shape is changed, change output shape and reallocate output tensor memory if (output_shape != output->getShape() || output->buffer() == nullptr) { // change on output shape - dynamicTensorManagerOf(output)->applyShape(output_ind, output_shape); + output->applyShape(output_shape); } 
assert(output->buffer() != nullptr); } @@ -705,14 +904,35 @@ void DynamicShapeInferer::visit(const ir::operation::ResizeBilinear &op) return; // getting output shape from input shape and Params - auto output_shape = shape_inference::inferResizeBilinearShape( - input->getShape(), op.param().height_out, op.param().width_out); + int32_t height_out, width_out; + if (op.getInputs().size() == 2) + { + auto size_ind = op.getInputs().at(ir::operation::ResizeBilinear::Input::SIZE); + auto size = _tensor_registry->getITensor(size_ind); + if (size->data_type() == ir::DataType::INT32) + { + auto size_buf = reinterpret_cast<const int32_t *>(size->buffer()); + height_out = size_buf[0]; + width_out = size_buf[1]; + } + else + { + throw std::runtime_error("DynamicShapeInferer ResizeBilinear : Unsupported data type"); + } + } + else + { + height_out = op.param().height_out; + width_out = op.param().width_out; + } + auto output_shape = + shape_inference::inferResizeBilinearShape(input->getShape(), height_out, width_out); // if shape is changed, change output shape and reallocate output tensor memory if (output_shape != output->getShape() || output->buffer() == nullptr) { // change on output shape - dynamicTensorManagerOf(output)->applyShape(output_ind, output_shape); + output->applyShape(output_shape); } assert(output->buffer() != nullptr); } @@ -744,12 +964,12 @@ void DynamicShapeInferer::visit(const ir::operation::Select &op) // Select output shpae ir::Shape new_shape = - shape_inference::inferSelectShape(input_cond_shape, input_true_shape, input_false_shape); + shape_inference::inferSelectShape(input_cond_shape, input_true_shape, input_false_shape); auto output_ind = op.getOutputs().at(0); auto output = _tensor_registry->getITensor(output_ind); - dynamicTensorManagerOf(output)->applyShape(output_ind, new_shape); + output->applyShape(new_shape); assert(output->buffer() != nullptr); } @@ -768,7 +988,7 @@ void DynamicShapeInferer::visit(const ir::operation::Shape &op) ir::Shape 
output_shape; output_shape.append(input_shape.rank()); - dynamicTensorManagerOf(output)->applyShape(output_ind, output_shape); + output->applyShape(output_shape); assert(output->buffer() != nullptr); } @@ -794,7 +1014,7 @@ void DynamicShapeInferer::visit(const ir::operation::Slice &op) ir::Shape new_shape = shape_inference::inferSliceShape(input_shape, begins_buf, sizes_buf); - dynamicTensorManagerOf(output)->applyShape(output_index, new_shape); + output->applyShape(new_shape); assert(output->buffer() != nullptr); } @@ -829,9 +1049,9 @@ void DynamicShapeInferer::visit(const ir::operation::SpaceToBatchND &op) auto padding_data = reinterpret_cast<int32_t *>(padding->buffer()); ir::Shape new_shape = shape_inference::inferSpaceToBatchNDShape( - input_shape, block_shape_shape, padding_shape, block_shape_data, padding_data); + input_shape, block_shape_shape, padding_shape, block_shape_data, padding_data); - dynamicTensorManagerOf(output)->applyShape(output_idx, new_shape); + output->applyShape(new_shape); assert(output->buffer() != nullptr); } @@ -840,27 +1060,37 @@ void DynamicShapeInferer::visit(const ir::operation::Split &op) const auto input_idx{op.getInputs().at(ir::operation::Split::Input::INPUT)}; const auto &input = _tensor_registry->getITensor(input_idx); - if (!input->is_dynamic()) + // Return if all tensors are not dynamic + bool has_dynamic = false; + for (const auto &output_idx : op.getOutputs()) + { + auto output = _tensor_registry->getITensor(output_idx); + has_dynamic |= output->is_dynamic(); + } + if (!input->is_dynamic() && !has_dynamic) { return; } auto input_shape = input->getShape(); - const auto axis = op.param().axis; + const auto axis_idx{op.getInputs().at(ir::operation::Split::Input::AXIS)}; + const auto &axis = _tensor_registry->getITensor(axis_idx); + + auto axis_value = *reinterpret_cast<const int32_t *>(axis->buffer()); const auto num_splits = op.param().num_splits; const auto rank = input_shape.rank(); - auto axis_resolved = axis < 0 ? 
axis + rank : axis; + axis_value = axis_value < 0 ? axis_value + rank : axis_value; - assert(0 <= axis_resolved && axis_resolved < rank); + assert(0 <= axis_value && axis_value < rank); - ir::Shape new_shape = shape_inference::inferSplitShape(input_shape, axis_resolved, num_splits); + ir::Shape new_shape = shape_inference::inferSplitShape(input_shape, axis_value, num_splits); for (int out_tensor_idx = 0; out_tensor_idx < num_splits; out_tensor_idx++) { auto output_ind = op.getOutputs().at(out_tensor_idx); auto output = _tensor_registry->getITensor(output_ind); - dynamicTensorManagerOf(output)->applyShape(output_ind, new_shape); + output->applyShape(new_shape); assert(output->buffer() != nullptr); } } @@ -889,7 +1119,7 @@ void DynamicShapeInferer::visit(const ir::operation::Squeeze &op) auto output_ind = op.getOutputs().at(0); auto output = _tensor_registry->getITensor(output_ind); - dynamicTensorManagerOf(output)->applyShape(output_ind, new_shape); + output->applyShape(new_shape); assert(output->buffer() != nullptr); } @@ -920,17 +1150,16 @@ void DynamicShapeInferer::visit(const ir::operation::StridedSlice &op) const auto rank = input_shape.rank(); auto op_params = shape_inference::buildStridedSliceParams( - reinterpret_cast<uint32_t *>(starts->buffer()), reinterpret_cast<uint32_t *>(ends->buffer()), - reinterpret_cast<uint32_t *>(strides->buffer()), begin_mask, end_mask, shrink_axis_mask, - rank); + reinterpret_cast<uint32_t *>(starts->buffer()), reinterpret_cast<uint32_t *>(ends->buffer()), + reinterpret_cast<uint32_t *>(strides->buffer()), begin_mask, end_mask, shrink_axis_mask, rank); auto output_index = op.getOutputs().at(0); auto output = _tensor_registry->getITensor(output_index); ir::Shape output_shape = - onert::shape_inference::inferStridedSliceShape(input_shape, op_params, rank); + onert::shape_inference::inferStridedSliceShape(input_shape, op_params, rank); - dynamicTensorManagerOf(output)->applyShape(output_index, output_shape); + 
output->applyShape(output_shape); assert(output->buffer() != nullptr); } @@ -952,10 +1181,12 @@ void DynamicShapeInferer::visit(const ir::operation::Tile &op) auto multiplier_buffer = reinterpret_cast<const int32_t *>(multiplier->buffer()); assert(multiplier_buffer); - auto output_shape = shape_inference::inferTileShape(input_shape, multiplier_buffer); + auto mult_shape = multiplier->getShape(); + auto output_shape = shape_inference::inferTileShape( + input_shape, multiplier_buffer, mult_shape.rank() == 0 ? 1 : mult_shape.dim(0)); // set output shape and output buffer - dynamicTensorManagerOf(output)->applyShape(output_ind, output_shape); + output->applyShape(output_shape); assert(output->buffer() != nullptr); } @@ -967,17 +1198,49 @@ void DynamicShapeInferer::visit(const ir::operation::Transpose &op) // from op, access the buffer of second input to read new shape auto input_ind = op.getInputs().at(ir::operation::Transpose::Input::INPUT); - auto input_tensor = _tensor_registry->getITensor(input_ind); - auto input_shape = input_tensor->getShape(); + auto input = _tensor_registry->getITensor(input_ind); + auto input_shape = input->getShape(); + + /* + Here, the state after compilation (static shape inference) could be one of the following: + + input perms output execution-time shape inf required + ------------------------------------ -------------------------------- + case 1) static const static X + case 2) static non-const dynamic O + case 3) dynamic const dynamic O + case 4) dynamic non-const dynamic O - if (!input_tensor->is_dynamic()) + So, only when both input1 and ouput are static, we can skip dynamic shape inference. 
+ */ + if ((!input->is_dynamic()) && (!output->is_dynamic())) return; - const auto perm{op.param().perm}; - // set output shape, based on input and params - ir::Shape new_shape = shape_inference::inferTransposeShape(input_shape, perm); + auto perm_ind = op.getInputs().at(ir::operation::Transpose::Input::PERMUTATION); + auto perm = _tensor_registry->getITensor(perm_ind); + + ir::Shape new_shape; + // TODO Change perm->dimension(0) == 0 to perm->num_elements() == 0 + if (perm->getShape().dim(0) == 0) // This condition means that perm is (n-1...0) + { + // Call by (n-1...0) + new_shape = shape_inference::inferTransposeShape(input_shape, nullptr, 0); + } + else + { + // Check rank + if (static_cast<size_t>(input->getShape().rank()) != perm->getShape().num_elements()) + { + throw std::runtime_error("DynamicShapeInferer failed, bad rank size: " + + std::to_string(perm->getShape().num_elements())); + } - dynamicTensorManagerOf(output)->applyShape(output_ind, new_shape); + // set output shape, based on input and params + const auto perm_buffer = reinterpret_cast<const int32_t *>(perm->buffer()); + new_shape = + shape_inference::inferTransposeShape(input_shape, perm_buffer, perm->getShape().dim(0)); + } + output->applyShape(new_shape); assert(output->buffer() != nullptr); } @@ -1005,7 +1268,7 @@ void DynamicShapeInferer::visit(const ir::operation::Unpack &op) auto output_ind = op.getOutputs().at(out_tensor_idx); auto output = _tensor_registry->getITensor(output_ind); - dynamicTensorManagerOf(output)->applyShape(output_ind, new_shape); + output->applyShape(new_shape); assert(output->buffer() != nullptr); } diff --git a/runtime/onert/core/src/exec/EdgeTensor.cc b/runtime/onert/core/src/exec/EdgeTensor.cc new file mode 100644 index 000000000..569a2f697 --- /dev/null +++ b/runtime/onert/core/src/exec/EdgeTensor.cc @@ -0,0 +1,55 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. 
All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "EdgeTensor.h" + +namespace onert +{ +namespace exec +{ + +bool EdgeTensor::applyShape(const ir::Shape &new_shape) +{ + bool previously_dynamic = is_dynamic(); + if (!previously_dynamic || _buffer == nullptr) + { + // Always set shape - when buffer with same size was already allocated, shape could differ + setShape(new_shape); + set_dynamic(); + const auto total_size = get_info().total_size(); + _buffer = std::make_unique<uint8_t[]>(total_size); + } + else + { + auto previous_size = total_size(); + auto new_size = new_shape.num_elements() * ir::sizeOfDataType(data_type()); + if (previous_size != new_size) + { + setShape(new_shape); + set_dynamic(); + const auto total_size = get_info().total_size(); + _buffer = std::make_unique<uint8_t[]>(total_size); + } + else + { // when buffer with same size was already allocated, shape could differ + setShape(new_shape); + } + } + return true; +} + +} // namespace exec +} // namespace onert diff --git a/runtime/onert/core/src/exec/EdgeTensor.h b/runtime/onert/core/src/exec/EdgeTensor.h new file mode 100644 index 000000000..8df79c389 --- /dev/null +++ b/runtime/onert/core/src/exec/EdgeTensor.h @@ -0,0 +1,72 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __ONERT_EXEC_EDGE_TENSOR_H__ +#define __ONERT_EXEC_EDGE_TENSOR_H__ + +#include "backend/IPortableTensor.h" + +#include <memory> + +namespace onert +{ +namespace exec +{ + +class EdgeTensor : public backend::IPortableTensor +{ +public: + EdgeTensor(const ir::OperandInfo &info, ir::Layout layout) + : IPortableTensor(info), _layout{layout}, _buffer{nullptr}, _ref_count{0} + { + } + ~EdgeTensor() = default; + + uint8_t *buffer() const override { return _buffer.get(); } + ir::Layout layout() const override { return _layout; } + void set_dynamic() override { _info.setDynamic(); } + bool applyShape(const ir::Shape &new_shape) override; + void setShape(const ir::Shape &new_shape) override { _info.shape(new_shape); } + + void allocate_buffer() + { + const auto total_size = _info.total_size(); + _buffer = std::make_unique<uint8_t[]>(total_size); + _ref_count = 1; + } + + void increase_ref() { _ref_count++; } + + void decrease_ref() + { + assert(_ref_count > 0); + _ref_count--; + if (_ref_count == 0) + { + _buffer.reset(); + } + } + +private: + ir::Layout _layout; + std::unique_ptr<uint8_t[]> _buffer; + int32_t _ref_count; +}; + +} // namespace exec +} // namespace onert + +#endif // __ONERT_EXEC_EDGE_TENSOR_H__ diff --git a/runtime/onert/core/src/exec/ExecTime.cc b/runtime/onert/core/src/exec/ExecTime.cc index 6bf2744a9..4b82655b9 100644 --- a/runtime/onert/core/src/exec/ExecTime.cc +++ b/runtime/onert/core/src/exec/ExecTime.cc @@ -14,12 +14,10 @@ * limitations under the License. 
*/ -#include "exec/ExecTime.h" +#include "ExecTime.h" -#include <fstream> -#include <cassert> -#include <limits> #include <algorithm> +#include <cassert> namespace onert { diff --git a/runtime/onert/core/src/exec/ExecTime.h b/runtime/onert/core/src/exec/ExecTime.h index 846d0930b..95f460053 100644 --- a/runtime/onert/core/src/exec/ExecTime.h +++ b/runtime/onert/core/src/exec/ExecTime.h @@ -34,7 +34,7 @@ class ExecTime { public: explicit ExecTime(const std::vector<const backend::Backend *> &backends) - : _json(backends, _measurements) + : _json(backends, _measurements) { } @@ -94,7 +94,7 @@ public: /** * @brief Update metrics file with new data. */ - void uploadOperationsExecTime() const { _json.uploadOperationsExecTime(); } + void storeOperationsExecTime() const { _json.storeOperationsExecTime(); } static const int64_t NOT_FOUND = -1; private: diff --git a/runtime/onert/core/src/exec/ExecTime.test.cc b/runtime/onert/core/src/exec/ExecTime.test.cc new file mode 100644 index 000000000..939184e4e --- /dev/null +++ b/runtime/onert/core/src/exec/ExecTime.test.cc @@ -0,0 +1,106 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "ExecTime.h" + +#include "backend/IConfig.h" +#include "backend/Backend.h" + +#include <gtest/gtest.h> + +#include <string> + +namespace +{ +using namespace onert; +using namespace exec; +using namespace backend; + +struct MockConfig : public IConfig +{ + std::string id() override { return "b1"; } + bool initialize() override { return true; }; + bool supportPermutation() override { return false; } + ir::Layout supportLayout(const ir::IOperation &, ir::Layout) override + { + return ir::Layout::UNKNOWN; + } + bool supportDynamicTensor() override { return false; } + bool supportFP16() override { return false; } +}; + +struct MockBackend : public ::onert::backend::Backend +{ + std::shared_ptr<onert::backend::IConfig> config() const override + { + return std::make_shared<MockConfig>(); + } + std::unique_ptr<onert::backend::BackendContext> newContext(ContextData &&) const override + { + return nullptr; + } +}; + +TEST(ExecTime, roundtrip_ok) +{ + const auto *b = new MockBackend(); + std::vector<const Backend *> bs = {b}; + { + ExecTime et(bs); + et.updateOperationExecTime(b, "op1", true, 100, 100); + et.updateOperationExecTime(b, "op1", true, 200, 200); + et.updateOperationExecTime(b, "op1", false, 100, 888); + et.storeOperationsExecTime(); + } + { + ExecTime et(bs); + auto time = et.getOperationExecTime(b, "op1", true, 100); + ASSERT_EQ(time, 100); + // Check interpolation + time = et.getOperationExecTime(b, "op1", true, 150); + ASSERT_EQ(time, 150); + time = et.getOperationExecTime(b, "op1", false, 100); + ASSERT_EQ(time, 888); + et.storeOperationsExecTime(); + } + // clean up + EXPECT_EQ(remove("exec_time.json"), 0); +} + +TEST(ExecTime, structure) +{ + + const auto *b = new MockBackend(); + std::vector<const Backend *> bs = {b}; + { + ExecTime et(bs); + et.updateOperationExecTime(b, "op1", true, 100, 100); + et.updateOperationExecTime(b, "op1", true, 200, 200); + et.storeOperationsExecTime(); + } + { + ExecTime et(bs); + auto time = 
et.getOperationExecTime(b, "op1", true, 100); + ASSERT_EQ(time, 100); + // Check interpolation + time = et.getOperationExecTime(b, "op1", true, 200); + ASSERT_EQ(time, 200); + et.storeOperationsExecTime(); + } + // clean up + EXPECT_EQ(remove("exec_time.json"), 0); +} +} // unnamed namespace diff --git a/runtime/onert/core/src/exec/Execution.cc b/runtime/onert/core/src/exec/Execution.cc index 7feb3ab68..895a82ff8 100644 --- a/runtime/onert/core/src/exec/Execution.cc +++ b/runtime/onert/core/src/exec/Execution.cc @@ -16,6 +16,8 @@ #include "exec/Execution.h" +#include "ir/DataType.h" +#include "train/TrainableExecutors.h" #include "util/logging.h" namespace onert @@ -23,116 +25,120 @@ namespace onert namespace exec { -Execution::Execution(const std::shared_ptr<ExecutorMap> &executors) : _executors{executors} +Execution::Execution(const std::shared_ptr<IExecutors> &executors) : _executors{executors} { assert(executors != nullptr); - assert(executors->at(ir::SubgraphIndex{0}) != nullptr); - const auto &primary_subg = primary_subgraph(); - _io_desc.inputs.resize(primary_subg.getInputs().size()); - _io_desc.outputs.resize(primary_subg.getOutputs().size()); + assert(executors->entryExecutor() != nullptr); + + // Initialize I/O description + _ctx.desc.inputs.resize(_executors->inputSize()); + for (uint32_t i = 0; i < _executors->inputSize(); ++i) + _ctx.desc.inputs.at(i) = std::make_unique<InputDesc>(_executors->inputInfo(ir::IOIndex(i))); + + _ctx.desc.outputs.resize(_executors->outputSize()); + for (uint32_t i = 0; i < _executors->outputSize(); ++i) + _ctx.desc.outputs.at(i) = std::make_unique<OutputDesc>(_executors->outputInfo(ir::IOIndex(i))); + _ctx.shape_updated = false; + + // Initialize options + ExecutionOptions::fromGlobalConfig(_ctx.options); } void Execution::changeInputShape(const ir::IOIndex &index, const ir::Shape &new_shape) { - // This should be called BEFORE setInput. 
- if (_io_desc.inputs.at(index.value()) != 0) - throw std::runtime_error("Error in calling order"); - // This will be used later to set input tensor dynamic // Note that 'compiled' model will not be updated with new_shape // but new_shape will change model input shape while 'running' the model - _io_desc.dynamic_input_shapes[index] = new_shape; -} - -// TODO Remove default parameter -void Execution::setInput(const ir::IOIndex &index, const void *buffer, size_t length, - ir::Layout layout) -{ - const auto input_index = primary_subgraph().getInputs().at(index); - const auto info = primary_subgraph().operands().at(input_index).info(); - - // TODO handle when (!buffer && length != 0) : setting the input as an optional tensor - - // check if size enough for input is passed - // if input_shape_sig is set, input_shape_sig overrides shape in info - // note: input_shape_sig contains shape passed by nnfw_set_input_tensorinfo() + auto &input_desc = _ctx.desc.inputs.at(index.value()); + if (new_shape != input_desc->info.shape()) { - auto input_shape_sig = _io_desc.dynamic_input_shapes.find(index); - auto size_required = (input_shape_sig != _io_desc.dynamic_input_shapes.end()) - ? 
input_shape_sig->second.num_elements() * - onert::ir::sizeOfDataType(info.typeInfo().type()) - : info.total_size(); + input_desc->info.shape(new_shape); + _ctx.shape_updated = true; - if (length < size_required) - { - throw std::runtime_error{"Too small length"}; - } + VERBOSE(Execution) << "Model input shape will be changed at the start of execute()" + << "(index: " << index << ")" << std::endl; } - - _io_desc.inputs.at(index.value()) = std::make_unique<InputDesc>(info, buffer, length, layout); } // TODO Remove default parameter -void Execution::setInput(const ir::IOIndex &index, const ir::TypeInfo &type, const ir::Shape &shape, - const void *buffer, size_t length, ir::Layout layout) +void Execution::setInput(const ir::IOIndex &index, const void *buffer, size_t length) { - auto info = ir::OperandInfo::createStaticInfo(shape, type); - - if (length < info.total_size()) - { - throw std::runtime_error{"Too small length"}; - } - - _io_desc.inputs.at(index.value()) = std::make_unique<InputDesc>(info, buffer, length, layout); + // Length validation in execute(): datatype can be changed by API call + auto &input_desc = _ctx.desc.inputs.at(index.value()); + input_desc->buffer = buffer; + input_desc->size = length; } -// TODO Remove default parameter -void Execution::setOutput(const ir::IOIndex &index, void *buffer, size_t length, ir::Layout layout) +void Execution::setInput(const ir::IOIndex &index, const ir::Shape &shape, const void *buffer, + size_t length) { - const auto output_index = primary_subgraph().getOutputs().at(index); - const auto info = primary_subgraph().operands().at(output_index).info(); - - if (length < info.total_size()) - { - throw std::runtime_error{"Too small length"}; - } - - _io_desc.outputs.at(index.value()) = std::make_unique<OutputDesc>(info, buffer, length, layout); + changeInputShape(index, shape); + setInput(index, buffer, length); } -// TODO Remove default parameter -void Execution::setOutput(const ir::IOIndex &index, const ir::TypeInfo 
&type, - const ir::Shape &shape, void *buffer, size_t length, ir::Layout layout) +void Execution::setOutput(const ir::IOIndex &index, void *buffer, size_t length) { - auto info = ir::OperandInfo::createStaticInfo(shape, type); + // Length validation in execute() + // - datatype can be changed by API call + // - shape can be changed by dynamic shape inference + auto &output_desc = _ctx.desc.outputs.at(index.value()); + output_desc->buffer = buffer; + output_desc->size = length; +} - if (length < info.total_size()) - { - throw std::runtime_error{"Too small length"}; - } +void Execution::setOutput(const ir::IOIndex &index, const ir::Shape &shape, void *buffer, + size_t length) +{ + auto &output_desc = _ctx.desc.outputs.at(index.value()); + output_desc->info.shape(shape); - _io_desc.outputs.at(index.value()) = std::make_unique<OutputDesc>(info, buffer, length, layout); + setOutput(index, buffer, length); } void Execution::setInputLayout(const ir::IOIndex &index, ir::Layout layout) { - const auto &input_desc = _io_desc.inputs.at(index.value()); - _io_desc.inputs.at(index.value()) = - std::make_unique<InputDesc>(input_desc->info, input_desc->buffer, input_desc->size, layout); + _ctx.desc.inputs.at(index.value())->layout = layout; } void Execution::setOutputLayout(const ir::IOIndex &index, ir::Layout layout) { - const auto &output_desc = _io_desc.outputs.at(index.value()); - _io_desc.outputs.at(index.value()) = std::make_unique<OutputDesc>( - output_desc->info, output_desc->buffer, output_desc->size, layout); + _ctx.desc.outputs.at(index.value())->layout = layout; +} + +void Execution::setInputType(const ir::IOIndex &index, const ir::TypeInfo &typeInfo) +{ + _ctx.desc.inputs.at(index.value())->info.typeInfo(typeInfo); + _ctx.shape_updated = true; +} + +void Execution::setOutputType(const ir::IOIndex &index, const ir::TypeInfo &typeInfo) +{ + _ctx.desc.outputs.at(index.value())->info.typeInfo(typeInfo); + _ctx.shape_updated = true; } void Execution::execute() { 
VERBOSE(Execution) << "Start execution" << std::endl; - primary_executor()->execute(_io_desc); + // Input length validation check + for (const auto &input : _ctx.desc.inputs) + { + if (input->info.total_size() > input->size) + throw std::runtime_error{"Too small input buffer length"}; + } + + // Output length validation check + if (!_ctx.shape_updated) + { + for (const auto &output : _ctx.desc.outputs) + { + if (output->info.total_size() > output->size) + throw std::runtime_error{"Too small output buffer length"}; + } + } + + _executors->execute(_ctx); finished = true; VERBOSE(Execution) << "Execution finished" << std::endl; @@ -155,28 +161,66 @@ void Execution::waitFinish() bool Execution::isFinished(void) const { return finished; } -ir::Shape Execution::getInputShape(ir::IOIndex ind) const +void Execution::train(uint32_t training_step) +{ + auto execs = dynamic_cast<exec::train::TrainableExecutors *>(_executors.get()); + if (!execs) + { + throw std::runtime_error{"Supported only TrainableExecutors"}; + } + + execs->train(_ctx, training_step); + finished = true; +} + +float Execution::getLoss(const ir::IOIndex &ind) { - auto itr = _io_desc.dynamic_input_shapes.find(ind); - if (itr == _io_desc.dynamic_input_shapes.end()) + auto execs = dynamic_cast<exec::train::TrainableExecutors *>(_executors.get()); + if (!execs) { - auto operand_idx = primary_subgraph().getInputs().at(ind.value()); - return primary_subgraph().operands().at(operand_idx).shape(); + throw std::runtime_error{"Supported only TrainableExecutors"}; } - else + + return execs->getLoss(ind); +} + +void Execution::iterateTrainableTensors( + const std::function<void(const ir::OperandIndex &, const backend::train::ITrainableTensor *)> &fn) + const +{ + auto execs = dynamic_cast<exec::train::TrainableExecutors *>(_executors.get()); + if (!execs) { - return itr->second; + throw std::runtime_error{"Supported only TrainableExecutors"}; } + execs->iterateTrainableTensors(fn); } +ir::Shape 
Execution::getInputShape(ir::IOIndex ind) const +{ + return _ctx.desc.inputs.at(ind.value())->info.shape(); +} + +// NNAPI return fail if ANeuralNetworksExecution_getOutputOperandRank or +// ANeuralNetworksExecution_getOutputOperandDimensions is called before execution. +// On the other hand, NNFW API return static shape inference result if nnfw_output_tensorinfo is +// called before execution. +// To handle both case, this method retun static shape inference result and fail will be handled on +// NNAPI frontend. ir::Shape Execution::getOutputShape(ir::IOIndex ind) const { - if (!isFinished()) - throw std::runtime_error("Cannot get output shape before execution is finished"); + return _ctx.desc.outputs.at(ind.value())->info.shape(); +} - const auto &output_desc = _io_desc.outputs.at(ind.value()); +size_t Execution::getInputTotalSize(ir::IOIndex ind) const +{ + // TODO Support dynamic shape + return _ctx.desc.inputs.at(ind.value())->info.total_size(); +} - return output_desc->info.shape(); +size_t Execution::getOutputTotalSize(ir::IOIndex ind) const +{ + return _ctx.desc.outputs.at(ind.value())->info.total_size(); } } // namespace exec diff --git a/runtime/onert/core/src/exec/Execution.test.cc b/runtime/onert/core/src/exec/Execution.test.cc new file mode 100644 index 000000000..15f94445a --- /dev/null +++ b/runtime/onert/core/src/exec/Execution.test.cc @@ -0,0 +1,783 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "exec/Execution.h" + +#include "compiler/Compiler.h" +#include "compiler/CompilerFactory.h" +#include "ir/Graph.h" +#include "ir/operation/BinaryArithmetic.h" +#include "util/TracingCtx.h" + +#include <gtest/gtest.h> +#include <thread> + +namespace +{ + +using namespace onert::ir; + +class CompiledMockUpModel +{ +public: + CompiledMockUpModel() + { + // Model: two elementwise add operation + // model input: lhs, rhs1 + // model output: second add result (result2) + // constant: rhs2 + // result1 <= (lhs + rhs) + // result2 <= (result1 + rhs2) + // lhs, rhs1, rh2, result1, result2 shape: {1, 2, 2, 1} + // activation: none (constant) + graph = std::make_shared<Graph>(); + // 1st add operands (result1 <= lhs + rhs1) + Shape shape{1, 2, 2, 1}; + TypeInfo type{DataType::FLOAT32}; + static float rhs2_data[4] = {3, 1, -1, 5}; + auto operand_lhs = graph->addOperand(shape, type); + auto operand_rhs1 = graph->addOperand(shape, type); + auto operand_result1 = graph->addOperand(shape, type); + auto operand_rhs2 = graph->addOperand(shape, type); + auto operand_result2 = graph->addOperand(shape, type); + graph->operands() + .at(operand_rhs2) + .data(std::make_unique<CachedData>(reinterpret_cast<const uint8_t *>(&rhs2_data), 16)); + // 2nd add operations (result2 <= result1 + rhs2) + operation::BinaryArithmetic::Param param1; + param1.arithmetic_type = operation::BinaryArithmetic::ArithmeticType::ADD; + param1.activation = Activation::NONE; + auto input_set1 = OperandIndexSequence{operand_lhs, operand_rhs1}; + auto output_set1 = OperandIndexSequence{operand_result1}; + graph->addOperation( + std::make_unique<operation::BinaryArithmetic>(input_set1, output_set1, param1)); + operation::BinaryArithmetic::Param param2; + param2.arithmetic_type = operation::BinaryArithmetic::ArithmeticType::ADD; + param2.activation = Activation::NONE; + auto input_set2 = 
OperandIndexSequence{operand_result1, operand_rhs2}; + auto output_set2 = OperandIndexSequence{operand_result2}; + graph->addOperation( + std::make_unique<operation::BinaryArithmetic>(input_set2, output_set2, param2)); + // Identify model inputs and outputs + graph->addInput(operand_lhs); + graph->addInput(operand_rhs1); + graph->addOutput(operand_result2); + graph->verify(); + + // Compile + auto model = std::make_shared<onert::ir::Model>(); + model->push(onert::ir::SubgraphIndex{0}, graph); + coptions = onert::compiler::CompilerOptions::fromGlobalConfig(); + onert::compiler::Compiler compiler{model, coptions.get()}; + artifact = compiler.compile(); + } + +public: + std::shared_ptr<Graph> graph; + std::unique_ptr<onert::compiler::CompilerOptions> coptions; + std::shared_ptr<onert::compiler::CompilerArtifact> artifact; +}; + +class CompiledMockUpMultiModel +{ +public: + CompiledMockUpMultiModel() + { + // Model0: a float elementwise add operation + // Model0 input: lhs0, rhs0 + // Model0 output: add result (result0) + + // Model1: a qasymm8 elementwise add operation + // Model1 input: result0, rhs1 + // Model1 output: add result (result1) + + // Model2: a float elementwise add operation + // Model2 input: result0, result1 + // Model2 output: add result (result2) + + // constant: rhs2 + // result0 <= (lhs0 + rhs0) + // result1 <= (result0 + rhs1) + // result2 <= (result0 + result1) + // lhs0, rhs0, rh1, result0, result1, result2 shape: {1, 2, 2, 1} + // activation: none (constant) + + // Update edge information + edges.pkg_inputs.emplace_back(ModelIndex{0}, SubgraphIndex{0}, IOIndex{0}); + edges.pkg_inputs.emplace_back(ModelIndex{0}, SubgraphIndex{0}, IOIndex{1}); + edges.pkg_outputs.emplace_back(ModelIndex{2}, SubgraphIndex{0}, IOIndex{0}); + // From + const auto result0 = IODesc{ModelIndex{0}, SubgraphIndex{0}, IOIndex{0}}; + const auto result1 = IODesc{ModelIndex{1}, SubgraphIndex{0}, IOIndex{0}}; + // To + const auto lhs1 = IODesc{ModelIndex{1}, 
SubgraphIndex{0}, IOIndex{0}}; + const auto lhs2 = IODesc{ModelIndex{2}, SubgraphIndex{0}, IOIndex{0}}; + const auto rhs2 = IODesc{ModelIndex{2}, SubgraphIndex{0}, IOIndex{1}}; + edges.edges.insert({result0, lhs1}); + edges.edges.insert({result0, lhs2}); + edges.edges.insert({result1, rhs2}); + + for (size_t i = 0; i < 3; ++i) + { + graphs.emplace_back(std::make_shared<Graph>()); + } + Shape shape{1, 2, 2, 1}; + + // Model0's add operands (result1 <= lhs0 + rhs0) + DataType types[3] = {DataType::FLOAT32, DataType::QUANT_UINT8_ASYMM, DataType::FLOAT32}; + auto operand_lhs0 = graphs[0]->addOperand(shape, TypeInfo{types[0]}); + auto operand_rhs0 = graphs[0]->addOperand(shape, TypeInfo{types[0]}); + auto operand_result0 = graphs[0]->addOperand(shape, TypeInfo{types[0]}); + + // Model0's add operation + operation::BinaryArithmetic::Param param0; + param0.arithmetic_type = operation::BinaryArithmetic::ArithmeticType::ADD; + param0.activation = Activation::NONE; + auto input_set0 = OperandIndexSequence{operand_lhs0, operand_rhs0}; + auto output_set0 = OperandIndexSequence{operand_result0}; + graphs[0]->addOperation( + std::make_unique<operation::BinaryArithmetic>(input_set0, output_set0, param0)); + + // Model0's inputs/outputs + graphs[0]->addInput(operand_lhs0); + graphs[0]->addInput(operand_rhs0); + graphs[0]->addOutput(operand_result0); + graphs[0]->verify(); + + // Model1's add operands (result2 <= Model0 result + rhs1) + // static float rhs1_data[4] = {3, 1, -1, 5}; + static uint8_t rhs1_data[4] = {131, 129, 127, 133}; + const float scale = 1; + const int32_t zero_point = 128; + auto operand_lhs1 = graphs[1]->addOperand(shape, TypeInfo{types[1], scale, zero_point}); + auto operand_rhs1 = graphs[1]->addOperand(shape, TypeInfo{types[1], scale, zero_point}); + auto operand_result1 = graphs[1]->addOperand(shape, TypeInfo{types[1], scale, zero_point}); + graphs[1] + ->operands() + .at(operand_rhs1) + .data(std::make_unique<CachedData>(reinterpret_cast<const uint8_t 
*>(&rhs1_data), 4)); + + // Model1's add operation + operation::BinaryArithmetic::Param param1; + param1.arithmetic_type = operation::BinaryArithmetic::ArithmeticType::ADD; + param1.activation = Activation::NONE; + auto input_set1 = OperandIndexSequence{operand_lhs1, operand_rhs1}; + auto output_set1 = OperandIndexSequence{operand_result1}; + graphs[1]->addOperation( + std::make_unique<operation::BinaryArithmetic>(input_set1, output_set1, param1)); + + // Model1's inputs/outputs + graphs[1]->addInput(operand_lhs1); + graphs[1]->addOutput(operand_result1); + graphs[1]->verify(); + + // Model2's additional operands (result3 <= Model0 result + Model1 result) + auto operand_lhs2 = graphs[2]->addOperand(shape, TypeInfo{types[2]}); + auto operand_rhs2 = graphs[2]->addOperand(shape, TypeInfo{types[2]}); + auto operand_result2 = graphs[2]->addOperand(shape, TypeInfo{types[2]}); + + // Model2's add operation + operation::BinaryArithmetic::Param param2; + param2.arithmetic_type = operation::BinaryArithmetic::ArithmeticType::ADD; + param2.activation = Activation::NONE; + auto input_set2 = OperandIndexSequence{operand_lhs2, operand_rhs2}; + auto output_set2 = OperandIndexSequence{operand_result2}; + graphs[2]->addOperation( + std::make_unique<operation::BinaryArithmetic>(input_set2, output_set2, param2)); + + // Model1's inputs/outputs + graphs[2]->addInput(operand_lhs2); + graphs[2]->addInput(operand_rhs2); + graphs[2]->addOutput(operand_result2); + graphs[2]->verify(); + + // Compile + compile(); + } + +public: + void compile() + { + auto nnpkg = std::make_shared<onert::ir::NNPkg>(); + coptions = onert::compiler::CompilerOptions::fromGlobalConfig(); + + for (uint16_t i = 0; i < graphs.size(); ++i) + { + auto model = std::make_shared<onert::ir::Model>(); + model->push(SubgraphIndex{0}, graphs[i]); + + nnpkg->push(onert::ir::ModelIndex{i}, std::move(model)); + } + for (const auto &pkg_input : edges.pkg_inputs) + { + nnpkg->addInput(pkg_input); + } + for (const auto &pkg_output 
: edges.pkg_outputs) + { + nnpkg->addOutput(pkg_output); + } + for (const auto &edge : edges.edges) + { + nnpkg->addEdge(edge.from, edge.to); + } + auto compiler = onert::compiler::CompilerFactory::get().create(nnpkg, coptions.get()); + nnpkg.reset(); + artifact = compiler->compile(); + } + +public: + std::vector<std::shared_ptr<Graph>> graphs; + std::unique_ptr<onert::compiler::CompilerOptions> coptions; + std::shared_ptr<onert::compiler::CompilerArtifact> artifact; + ModelEdges edges; +}; + +class CompiledMockUpQuantModel +{ +public: + CompiledMockUpQuantModel() + { + // Model: two elementwise add operation + // model input: lhs, rhs1 + // model output: second add result (result2) + // constant: rhs2 + // result1 <= (lhs + rhs) + // result2 <= (result1 + rhs2) + // lhs, rhs1, rh2, result1, result2 shape: {1, 2, 2, 1} + // activation: none (constant) + graph = std::make_shared<Graph>(); + // 1st add operands (result1 <= lhs + rhs1) + Shape shape{1, 2, 2, 1}; + TypeInfo type{DataType::QUANT_UINT8_ASYMM, 1.0f, 128}; + static uint8_t rhs2_data[4] = {131, 129, 127, 133}; + auto operand_lhs = graph->addOperand(shape, type); + auto operand_rhs1 = graph->addOperand(shape, type); + auto operand_result1 = graph->addOperand(shape, type); + auto operand_rhs2 = graph->addOperand(shape, type); + auto operand_result2 = graph->addOperand(shape, type); + graph->operands() + .at(operand_rhs2) + .data(std::make_unique<CachedData>(reinterpret_cast<const uint8_t *>(&rhs2_data), 4)); + // 2nd add operations (result2 <= result1 + rhs2) + operation::BinaryArithmetic::Param param1; + param1.arithmetic_type = operation::BinaryArithmetic::ArithmeticType::ADD; + param1.activation = Activation::NONE; + auto input_set1 = OperandIndexSequence{operand_lhs, operand_rhs1}; + auto output_set1 = OperandIndexSequence{operand_result1}; + graph->addOperation( + std::make_unique<operation::BinaryArithmetic>(input_set1, output_set1, param1)); + operation::BinaryArithmetic::Param param2; + 
param2.arithmetic_type = operation::BinaryArithmetic::ArithmeticType::ADD; + param2.activation = Activation::NONE; + auto input_set2 = OperandIndexSequence{operand_result1, operand_rhs2}; + auto output_set2 = OperandIndexSequence{operand_result2}; + graph->addOperation( + std::make_unique<operation::BinaryArithmetic>(input_set2, output_set2, param2)); + // Identify model inputs and outputs + graph->addInput(operand_lhs); + graph->addInput(operand_rhs1); + graph->addOutput(operand_result2); + graph->verify(); + + // Compile + auto model = std::make_shared<onert::ir::Model>(); + model->push(onert::ir::SubgraphIndex{0}, graph); + coptions = onert::compiler::CompilerOptions::fromGlobalConfig(); + onert::compiler::Compiler compiler{model, coptions.get()}; + artifact = compiler.compile(); + } + +public: + std::shared_ptr<Graph> graph; + std::unique_ptr<onert::compiler::CompilerOptions> coptions; + std::shared_ptr<onert::compiler::CompilerArtifact> artifact; +}; + +TEST(ExecInstance, simple) +{ + auto mockup = CompiledMockUpModel(); + auto graph = mockup.graph; + auto executors = mockup.artifact->_executors; + + auto input1 = IOIndex{0}; + auto input2 = IOIndex{1}; + auto output = IOIndex{0}; + + const float input1_buffer[4] = {1, 0, -1, -2}; + const float input2_buffer[4] = {1, -3, 2, -4}; + float output_buffer[4] = {}; + const float output_expected[4] = {5, -2, 0, -1}; + + onert::exec::Execution execution{executors}; + + execution.setInput(input1, reinterpret_cast<const void *>(input1_buffer), 16); + execution.setInput(input2, reinterpret_cast<const void *>(input2_buffer), 16); + execution.setOutput(output, reinterpret_cast<void *>(output_buffer), 16); + execution.execute(); + + for (auto i = 0; i < 4; i++) + { + EXPECT_EQ(output_buffer[i], output_expected[i]); + } +} + +TEST(ExecInstance, neg_small_outputbuffer) +{ + auto mockup = CompiledMockUpModel(); + auto graph = mockup.graph; + auto executors = mockup.artifact->_executors; + + auto input1 = IOIndex{0}; + auto 
input2 = IOIndex{1}; + auto output = IOIndex{0}; + + const float input1_buffer[4] = {1, 0, -1, -2}; + const float input2_buffer[4] = {1, -3, 2, -4}; + float output_buffer[2] = {}; + + onert::exec::Execution execution{executors}; + + execution.setInput(input1, reinterpret_cast<const void *>(input1_buffer), 16); + execution.setInput(input2, reinterpret_cast<const void *>(input2_buffer), 16); + execution.setOutput(output, reinterpret_cast<void *>(output_buffer), 8); + EXPECT_ANY_THROW(execution.execute()); +} + +TEST(ExecInstance, neg_small_inoutsize) +{ + auto mockup = CompiledMockUpModel(); + auto graph = mockup.graph; + auto executors = mockup.artifact->_executors; + + auto input1 = IOIndex{0}; + auto input2 = IOIndex{1}; + auto output = IOIndex{0}; + + const float input1_buffer[2] = {1, 0}; + const float input2_buffer[2] = {1, -3}; + const auto new_shape = onert::ir::Shape({1, 1, 2, 1}); + float output_buffer[2] = {}; + + onert::exec::Execution execution{executors}; + + execution.setInput(input1, new_shape, reinterpret_cast<const void *>(input1_buffer), 8); + execution.setInput(input2, new_shape, reinterpret_cast<const void *>(input2_buffer), 2); + EXPECT_THROW(execution.execute(), std::exception); + + // Not throw exception because input shape is changed and output buffer is enough + execution.setInput(input2, new_shape, reinterpret_cast<const void *>(input2_buffer), 8); + execution.setOutput(output, reinterpret_cast<void *>(output_buffer), 16); + execution.execute(); + + execution.setOutput(output, reinterpret_cast<void *>(output_buffer), 8); + // Throw exception by shape inference because output buffer size is small: + // output shape is {1, 2, 2, 1} + EXPECT_THROW(execution.execute(), std::exception); +} + +TEST(ExecInstance, twoCompile) +{ + auto mockup = CompiledMockUpModel(); + auto graph = mockup.graph; + auto executors1 = mockup.artifact->_executors; + onert::exec::Execution execution1{executors1}; + + auto input1 = IOIndex{0}; + auto input2 = IOIndex{1}; 
+ auto output = IOIndex{0}; + + const float exe1_input1_buffer[4] = {1, 0, -1, -2}; + const float exe1_input2_buffer[4] = {1, -3, 2, -4}; + float exe1_output_buffer[4] = {}; + const float exe1_output_expected[4] = {5, -2, 0, -1}; + + execution1.setInput(input1, reinterpret_cast<const void *>(exe1_input1_buffer), 16); + execution1.setInput(input2, reinterpret_cast<const void *>(exe1_input2_buffer), 16); + execution1.setOutput(output, reinterpret_cast<void *>(exe1_output_buffer), 16); + + // Make new executor: compile again + auto model = std::make_shared<onert::ir::Model>(); + model->push(onert::ir::SubgraphIndex{0}, graph); + auto coptions = onert::compiler::CompilerOptions::fromGlobalConfig(); + onert::compiler::Compiler compiler{model, coptions.get()}; + std::shared_ptr<onert::compiler::CompilerArtifact> artifact = compiler.compile(); + onert::exec::Execution execution2{artifact->_executors}; + + const float exe2_input1_buffer[4] = {2, 1, -2, 0}; + const float exe2_input2_buffer[4] = {-3, 3, 1, 2}; + float exe2_output_buffer[4] = {}; + const float exe2_output_expected[4] = {2, 5, -2, 7}; + + execution2.setInput(input1, reinterpret_cast<const void *>(exe2_input1_buffer), 16); + execution2.setInput(input2, reinterpret_cast<const void *>(exe2_input2_buffer), 16); + execution2.setOutput(output, reinterpret_cast<void *>(exe2_output_buffer), 16); + + execution1.execute(); + execution2.execute(); + + for (auto i = 0; i < 4; i++) + { + EXPECT_EQ(exe1_output_buffer[i], exe1_output_expected[i]); + EXPECT_EQ(exe2_output_buffer[i], exe2_output_expected[i]); + } +} + +// Support two initialized execution instance then ordered execution +TEST(ExecInstance, twoExecution) +{ + auto mockup = CompiledMockUpModel(); + auto executors = mockup.artifact->_executors; + auto input1 = IOIndex{0}; + auto input2 = IOIndex{1}; + auto output1 = IOIndex{0}; + + const float exe1_input1_buffer[4] = {1, 0, -1, -2}; + const float exe1_input2_buffer[4] = {1, -3, 2, -4}; + float 
exe1_output_buffer[4] = {}; + const float exe1_output_expected[4] = {5, -2, 0, -1}; + const float exe2_output_expected[4] = {2, 5, -2, 7}; + + onert::exec::Execution execution1{executors}; + execution1.setInput(input1, reinterpret_cast<const void *>(exe1_input1_buffer), 16); + execution1.setInput(input2, reinterpret_cast<const void *>(exe1_input2_buffer), 16); + execution1.setOutput(output1, reinterpret_cast<void *>(exe1_output_buffer), 16); + + const float exe2_input1_buffer[4] = {2, 1, -2, 0}; + const float exe2_input2_buffer[4] = {-3, 3, 1, 2}; + float exe2_output_buffer[4] = {}; + + // Make new execution + onert::exec::Execution execution2{executors}; + execution2.setInput(input1, reinterpret_cast<const void *>(exe2_input1_buffer), 16); + execution2.setInput(input2, reinterpret_cast<const void *>(exe2_input2_buffer), 16); + execution2.setOutput(output1, reinterpret_cast<void *>(exe2_output_buffer), 16); + + execution1.execute(); + execution2.execute(); + + for (auto i = 0; i < 4; i++) + { + EXPECT_EQ(exe1_output_buffer[i], exe1_output_expected[i]); + EXPECT_EQ(exe2_output_buffer[i], exe2_output_expected[i]); + } +} + +TEST(ExecInstance, quantModel_floatIO) +{ + auto mockup = CompiledMockUpQuantModel(); + auto graph = mockup.graph; + auto executors = mockup.artifact->_executors; + + auto input1 = IOIndex{0}; + auto input2 = IOIndex{1}; + auto output = IOIndex{0}; + + const float input1_buffer[4] = {1, 0, -1, -2}; + const float input2_buffer[4] = {1, -3, 2, -4}; + float output_buffer[4] = {}; + const float output_expected[4] = {5, -2, 0, -1}; + + onert::exec::Execution execution{executors}; + + execution.setInput(input1, reinterpret_cast<const void *>(input1_buffer), 16); + execution.setInput(input2, reinterpret_cast<const void *>(input2_buffer), 16); + execution.setOutput(output, reinterpret_cast<void *>(output_buffer), 16); + execution.setInputType(input1, onert::ir::TypeInfo{onert::ir::DataType::FLOAT32}); + execution.setInputType(input2, 
onert::ir::TypeInfo{onert::ir::DataType::FLOAT32}); + execution.setOutputType(output, onert::ir::TypeInfo{onert::ir::DataType::FLOAT32}); + execution.execute(); + + EXPECT_EQ(output_buffer[0], output_expected[0]); + EXPECT_EQ(output_buffer[1], output_expected[1]); + EXPECT_EQ(output_buffer[2], output_expected[2]); + EXPECT_EQ(output_buffer[3], output_expected[3]); +} + +class Inference +{ +public: + Inference(const float (&input1)[4], const float (&input2)[4], float (&output)[4], + std::shared_ptr<onert::exec::IExecutors> &executors) + : _input1{input1}, _input2{input2}, _output{output}, _executors{executors} + { + // DO NOTHING + } + + void inference(void) + { + auto input1 = IOIndex{0}; + auto input2 = IOIndex{1}; + auto output1 = IOIndex{0}; + + onert::exec::Execution execution{_executors}; + execution.setInput(input1, reinterpret_cast<const void *>(_input1), 16); + execution.setInput(input2, reinterpret_cast<const void *>(_input2), 16); + execution.setOutput(output1, reinterpret_cast<void *>(_output), 16); + + execution.execute(); + } + +private: + const float (&_input1)[4]; + const float (&_input2)[4]; + float (&_output)[4]; + std::shared_ptr<onert::exec::IExecutors> &_executors; +}; + +// Support multi-thread execution +TEST(ExecInstance, twoThreads) +{ + auto mockup = CompiledMockUpModel(); + auto executors = mockup.artifact->_executors; + + const float exe1_input1_buffer[4] = {1, 0, -1, -2}; + const float exe1_input2_buffer[4] = {1, -3, 2, -4}; + float exe1_output_buffer[4] = {}; + const float exe1_output_expected[4] = {5, -2, 0, -1}; + + Inference execution1{exe1_input1_buffer, exe1_input2_buffer, exe1_output_buffer, executors}; + + const float exe2_input1_buffer[4] = {2, 1, -2, 0}; + const float exe2_input2_buffer[4] = {-3, 3, 1, 2}; + float exe2_output_buffer[4] = {}; + const float exe2_output_expected[4] = {2, 5, -2, 7}; + + Inference execution2{exe2_input1_buffer, exe2_input2_buffer, exe2_output_buffer, executors}; + + std::thread 
t1{&Inference::inference, &execution1}; + std::thread t2{&Inference::inference, &execution2}; + + t1.join(); + t2.join(); + + for (auto i = 0; i < 4; i++) + { + EXPECT_EQ(exe1_output_buffer[i], exe1_output_expected[i]); + EXPECT_EQ(exe2_output_buffer[i], exe2_output_expected[i]); + } +} + +// Support asynchronous execution +TEST(ExecInstance, async) +{ + auto mockup = CompiledMockUpModel(); + auto graph = mockup.graph; + auto executors = mockup.artifact->_executors; + + auto input1 = IOIndex{0}; + auto input2 = IOIndex{1}; + auto output = IOIndex{0}; + + const float input1_buffer[4] = {1, 0, -1, -2}; + const float input2_buffer[4] = {1, -3, 2, -4}; + float output_buffer[4] = {}; + const float output_expected[4] = {5, -2, 0, -1}; + + onert::exec::Execution execution{executors}; + + execution.setInput(input1, reinterpret_cast<const void *>(input1_buffer), 16); + execution.setInput(input2, reinterpret_cast<const void *>(input2_buffer), 16); + execution.setOutput(output, reinterpret_cast<void *>(output_buffer), 16); + execution.startExecute(); + execution.waitFinish(); + + for (auto i = 0; i < 4; i++) + { + EXPECT_EQ(output_buffer[i], output_expected[i]); + } +} + +TEST(ExecInstance, multi_model_simple) +{ + auto mockup = CompiledMockUpMultiModel(); + auto executors = mockup.artifact->_executors; + + auto input1 = IOIndex{0}; + auto input2 = IOIndex{1}; + auto output = IOIndex{0}; + + const float input1_buffer[4] = {1, 0, -1, -2}; + const float input2_buffer[4] = {1, -3, 2, -4}; + float output_buffer[4] = {}; + const float output_expected[4] = {7, -5, 1, -7}; + + onert::exec::Execution execution{executors}; + + execution.setInput(input1, reinterpret_cast<const void *>(input1_buffer), 16); + execution.setInput(input2, reinterpret_cast<const void *>(input2_buffer), 16); + execution.setOutput(output, reinterpret_cast<void *>(output_buffer), 16); + execution.execute(); + + for (auto i = 0; i < 4; i++) + { + EXPECT_EQ(output_buffer[i], output_expected[i]); + } +} + 
+TEST(ExecInstance, multi_model_twoCompile) +{ + auto mockup = CompiledMockUpMultiModel(); + auto executors1 = mockup.artifact->_executors; + onert::exec::Execution execution1{executors1}; + + auto input1 = IOIndex{0}; + auto input2 = IOIndex{1}; + auto output = IOIndex{0}; + + const float exe1_input1_buffer[4] = {1, 0, -1, -2}; + const float exe1_input2_buffer[4] = {1, -3, 2, -4}; + float exe1_output_buffer[4] = {}; + const float exe1_output_expected[4] = {7, -5, 1, -7}; + + execution1.setInput(input1, reinterpret_cast<const void *>(exe1_input1_buffer), 16); + execution1.setInput(input2, reinterpret_cast<const void *>(exe1_input2_buffer), 16); + execution1.setOutput(output, reinterpret_cast<void *>(exe1_output_buffer), 16); + + // Make new executor: compile again + mockup.compile(); + onert::exec::Execution execution2{mockup.artifact->_executors}; + + const float exe2_input1_buffer[4] = {2, 1, -2, 0}; + const float exe2_input2_buffer[4] = {-3, 3, 1, 2}; + float exe2_output_buffer[4] = {}; + const float exe2_output_expected[4] = {1, 9, -3, 9}; + + execution2.setInput(input1, reinterpret_cast<const void *>(exe2_input1_buffer), 16); + execution2.setInput(input2, reinterpret_cast<const void *>(exe2_input2_buffer), 16); + execution2.setOutput(output, reinterpret_cast<void *>(exe2_output_buffer), 16); + + execution1.execute(); + execution2.execute(); + + for (auto i = 0; i < 4; i++) + { + EXPECT_EQ(exe1_output_buffer[i], exe1_output_expected[i]); + EXPECT_EQ(exe2_output_buffer[i], exe2_output_expected[i]); + } +} + +// Support two initialized execution instance then ordered execution +TEST(ExecInstance, multi_model_twoExecution) +{ + auto mockup = CompiledMockUpMultiModel(); + auto executors = mockup.artifact->_executors; + auto input1 = IOIndex{0}; + auto input2 = IOIndex{1}; + auto output1 = IOIndex{0}; + + const float exe1_input1_buffer[4] = {1, 0, -1, -2}; + const float exe1_input2_buffer[4] = {1, -3, 2, -4}; + float exe1_output_buffer[4] = {}; + const float 
exe1_output_expected[4] = {7, -5, 1, -7}; + const float exe2_output_expected[4] = {1, 9, -3, 9}; + + onert::exec::Execution execution1{executors}; + execution1.setInput(input1, reinterpret_cast<const void *>(exe1_input1_buffer), 16); + execution1.setInput(input2, reinterpret_cast<const void *>(exe1_input2_buffer), 16); + execution1.setOutput(output1, reinterpret_cast<void *>(exe1_output_buffer), 16); + + const float exe2_input1_buffer[4] = {2, 1, -2, 0}; + const float exe2_input2_buffer[4] = {-3, 3, 1, 2}; + float exe2_output_buffer[4] = {}; + + // Make new execution + onert::exec::Execution execution2{executors}; + execution2.setInput(input1, reinterpret_cast<const void *>(exe2_input1_buffer), 16); + execution2.setInput(input2, reinterpret_cast<const void *>(exe2_input2_buffer), 16); + execution2.setOutput(output1, reinterpret_cast<void *>(exe2_output_buffer), 16); + + execution1.execute(); + execution1.execute(); + execution2.execute(); + execution2.execute(); + + for (auto i = 0; i < 4; i++) + { + EXPECT_EQ(exe1_output_buffer[i], exe1_output_expected[i]); + EXPECT_EQ(exe2_output_buffer[i], exe2_output_expected[i]); + } +} + +// Multi-model is not thread-safe yet + +// Support asynchronous execution +TEST(ExecInstance, multi_model_async) +{ + auto mockup = CompiledMockUpMultiModel(); + auto executors = mockup.artifact->_executors; + + auto input1 = IOIndex{0}; + auto input2 = IOIndex{1}; + auto output = IOIndex{0}; + + const float input1_buffer[4] = {1, 0, -1, -2}; + const float input2_buffer[4] = {1, -3, 2, -4}; + float output_buffer[4] = {}; + const float output_expected[4] = {7, -5, 1, -7}; + + onert::exec::Execution execution{executors}; + + execution.setInput(input1, reinterpret_cast<const void *>(input1_buffer), 16); + execution.setInput(input2, reinterpret_cast<const void *>(input2_buffer), 16); + execution.setOutput(output, reinterpret_cast<void *>(output_buffer), 16); + execution.startExecute(); + execution.waitFinish(); + + for (auto i = 0; i < 4; i++) 
+ { + EXPECT_EQ(output_buffer[i], output_expected[i]); + } +} + +TEST(ExecInstance, multi_model_dequant_input_quant_output) +{ + auto mockup = CompiledMockUpMultiModel(); + auto executors = mockup.artifact->_executors; + + auto input1 = IOIndex{0}; + auto input2 = IOIndex{1}; + auto output = IOIndex{0}; + + const uint8_t input1_buffer[4] = {138, 128, 118, 108}; // {1, 0, -1, -2} + const uint8_t input2_buffer[4] = {138, 98, 148, 88}; // {1, -3, 2, -4} + uint8_t output_buffer[4] = {}; + const uint8_t output_expected[4] = {198, 78, 138, 58}; // {7, -5, 1, -7} + float scale = 0.1; + int32_t zero_point = 128; + + onert::exec::Execution execution{executors}; + + onert::ir::TypeInfo type_info{onert::ir::DataType::QUANT_UINT8_ASYMM, scale, zero_point}; + execution.setInputType(input1, type_info); + execution.setInput(input1, execution.getInputShape(input1), + reinterpret_cast<const void *>(input1_buffer), 4); + execution.setInputType(input2, type_info); + execution.setInput(input2, execution.getInputShape(input2), + reinterpret_cast<const void *>(input2_buffer), 4); + execution.setOutputType(output, type_info); + execution.setOutput(output, execution.getOutputShape(output), + reinterpret_cast<void *>(output_buffer), 4); + execution.execute(); + + for (auto i = 0; i < 4; i++) + { + EXPECT_EQ(output_buffer[i], output_expected[i]); + } +} + +// TODO Add an unittest multi_model_quant_input_dequant_output + +} // namespace diff --git a/runtime/onert/core/src/exec/ExecutionContext.cc b/runtime/onert/core/src/exec/ExecutionContext.cc new file mode 100644 index 000000000..aec10ee5b --- /dev/null +++ b/runtime/onert/core/src/exec/ExecutionContext.cc @@ -0,0 +1,34 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "exec/ExecutionContext.h" + +#include "util/ConfigSource.h" + +namespace onert +{ +namespace exec +{ + +void ExecutionOptions::fromGlobalConfig(ExecutionOptions &options) +{ + options.dump_minmax = util::getConfigBool(util::config::MINMAX_DUMP); + options.trace = util::getConfigBool(util::config::TRACING_MODE); + options.profile = util::getConfigBool(util::config::PROFILING_MODE); +} + +} // namespace exec +} // namespace onert diff --git a/runtime/onert/core/src/exec/ExecutionObservee.cc b/runtime/onert/core/src/exec/ExecutionObservee.cc index ddb1fb6a0..22881b8c8 100644 --- a/runtime/onert/core/src/exec/ExecutionObservee.cc +++ b/runtime/onert/core/src/exec/ExecutionObservee.cc @@ -21,42 +21,72 @@ namespace onert namespace exec { -void ExecutionObservee::add(std::unique_ptr<IExecutionObserver> observer) +ExecutionObservee::ExecutionObservee(const ExecObservers &observers, + const ExecutionOptions &options) { - _observers.emplace_back(std::move(observer)); + // TODO Use execution option + if (options.dump_minmax) + { + auto observer = observers.get(ObserverType::MINMAX_DUMP); + if (!observer) + throw std::runtime_error{"MinMaxRecorder is only supported on LinearExecutor, single model"}; + + _observers.emplace_back(observer); + } + + if (options.trace) + { + auto observer = observers.get(ObserverType::TRACING); + if (!observer) + throw std::runtime_error{"Cannot find TracingObserver"}; + + _observers.emplace_back(observer); + } + + if (options.profile) + { + auto observer = observers.get(ObserverType::PROFILE); + if (!observer) + 
throw std::runtime_error{ + "Profiling is only supported on DataflowExecutor with heterogenous scheduler"}; + + _observers.emplace_back(observer); + } } -void ExecutionObservee::notifyModelBegin(IExecutor *executor) +void ExecutionObservee::notifySubgraphBegin(ir::SubgraphIndex ind) const { - for (auto &o : _observers) + for (auto &&o : _observers) { - o->handleBegin(executor); + o->handleSubgraphBegin(ind); } } -void ExecutionObservee::notifyModelEnd(IExecutor *executor) +void ExecutionObservee::notifySubgraphEnd(ir::SubgraphIndex ind) const { - for (auto &o : _observers) + for (auto &&o : _observers) { - o->handleEnd(executor); + o->handleSubgraphEnd(ind); } } -void ExecutionObservee::notifyJobBegin(IExecutor *executor, const ir::OpSequence *op_seq, - const backend::Backend *backend) +void ExecutionObservee::notifyJobBegin(IExecutor *executor, ir::SubgraphIndex subg_ind, + ir::OperationIndex op_ind, + const backend::Backend *backend) const { - for (auto &o : _observers) + for (auto &&o : _observers) { - o->handleBegin(executor, op_seq, backend); + o->handleJobBegin(executor, subg_ind, op_ind, backend); } } -void ExecutionObservee::notifyJobEnd(IExecutor *executor, const ir::OpSequence *op_seq, - const backend::Backend *backend) +void ExecutionObservee::notifyJobEnd(IExecutor *executor, ir::SubgraphIndex subg_ind, + ir::OperationIndex op_ind, + const backend::Backend *backend) const { - for (auto &o : _observers) + for (auto &&o : _observers) { - o->handleEnd(executor, op_seq, backend); + o->handleJobEnd(executor, subg_ind, op_ind, backend); } } diff --git a/runtime/onert/core/src/exec/ExecutionObservee.h b/runtime/onert/core/src/exec/ExecutionObservee.h index 49d409a3a..e6461c788 100644 --- a/runtime/onert/core/src/exec/ExecutionObservee.h +++ b/runtime/onert/core/src/exec/ExecutionObservee.h @@ -17,9 +17,11 @@ #ifndef __ONERT_EXEC_EXECUTION_OBSERVEE_H__ #define __ONERT_EXEC_EXECUTION_OBSERVEE_H__ -#include <list> +#include "ExecutionObservers.h" + +#include 
"ir/Index.h" -#include "exec/ExecutionObservers.h" +#include <list> namespace onert { @@ -34,20 +36,21 @@ class ExecutionObservee { public: /** - * @brief Register an observer + * @brief Register enabled observers * - * @param observer Observer to be added + * @param observer Observers generated by compiler */ - void add(std::unique_ptr<IExecutionObserver> observer); - void notifyModelBegin(IExecutor *executor); - void notifyModelEnd(IExecutor *executor); - void notifyJobBegin(IExecutor *executor, const ir::OpSequence *op_seq, - const backend::Backend *backend); - void notifyJobEnd(IExecutor *executor, const ir::OpSequence *op_seq, - const backend::Backend *backend); + ExecutionObservee(const ExecObservers &observers, const ExecutionOptions &options); + void notifySubgraphBegin(ir::SubgraphIndex ind) const; + void notifySubgraphEnd(ir::SubgraphIndex ind) const; + void notifyJobBegin(IExecutor *executor, ir::SubgraphIndex subg_ind, ir::OperationIndex op_ind, + const backend::Backend *backend) const; + void notifyJobEnd(IExecutor *executor, ir::SubgraphIndex subg_ind, ir::OperationIndex op_ind, + const backend::Backend *backend) const; + bool isEmpty() const { return _observers.size() == 0; } private: - std::list<std::unique_ptr<IExecutionObserver>> _observers; + std::list<IExecutionObserver *> _observers; }; } // namespace exec diff --git a/runtime/onert/core/src/exec/ExecutionObservers.cc b/runtime/onert/core/src/exec/ExecutionObservers.cc index 060f874de..a58daeabd 100644 --- a/runtime/onert/core/src/exec/ExecutionObservers.cc +++ b/runtime/onert/core/src/exec/ExecutionObservers.cc @@ -14,14 +14,58 @@ * limitations under the License. 
*/ -#include "exec/ExecutionObservers.h" +#include "ExecutionObservers.h" -#include <string> +#include "../util/EventWriter.h" #include "util/logging.h" -#include "exec/IExecutor.h" -#include "misc/polymorphic_downcast.h" -#include "ir/OpSequence.h" + +#include <misc/polymorphic_downcast.h> + +#include <string> +#include <sstream> + +namespace +{ + +void setUserData(const onert::ir::Graph &g, const onert::ir::IOperation *op, + decltype(EventCollector::Event::userData) &data) +{ + // From a tensor of shape [a, b, c], this will return a string "shape(a b c)". + // String like "[1, 2, 3]" looks better but this will be considered as a list in Json + // so text search (e.g., Ctrl-F in Chrome Tracing) could be difficult + auto build_shape_str = [&](onert::ir::OperandIndex operand_idx) { + std::string shape_str; + auto &shape = g.operands().at(operand_idx).info().shape(); + for (int i = 0; i < shape.rank(); i++) + { + if (i == 0) + shape_str = "shape(" + std::to_string(shape.dim(i)); + else + shape_str += " " + std::to_string(shape.dim(i)); + } + shape_str += ")"; + + return shape_str; + }; + + auto &inputs = op->getInputs(); + auto size = inputs.size(); + for (size_t i = 0; i < size; i++) + { + auto operand_idx = inputs.at(i); + if (operand_idx.undefined()) + continue; + + std::string key("input_shape_" + std::to_string(i)); + std::string value = build_shape_str(operand_idx); + data.emplace_back(std::make_pair(key, value)); + } + + // add other userData as needed +} + +} // namespace namespace onert { @@ -29,8 +73,8 @@ namespace onert namespace exec { -void ProfileObserver::handleBegin(onert::exec::IExecutor *, const ir::OpSequence *, - const onert::backend::Backend *backend) +void ProfileObserver::handleJobBegin(onert::exec::IExecutor *, ir::SubgraphIndex, + ir::OperationIndex, const onert::backend::Backend *backend) { _timer = backend->config()->timer(); if (_timer == nullptr) @@ -38,14 +82,14 @@ void ProfileObserver::handleBegin(onert::exec::IExecutor *, const 
ir::OpSequence _timer->handleBegin(); } -void ProfileObserver::handleEnd(IExecutor *exec, const ir::OpSequence *op_seq, - const backend::Backend *backend) +void ProfileObserver::handleJobEnd(IExecutor *exec, ir::SubgraphIndex, + const ir::OperationIndex op_ind, const backend::Backend *backend) { _timer->handleEnd(); const auto timer_res = _timer->getTime(); - // NOTE This assumes there is just one operation in a op_seq - const auto &node = _graph.operations().at(op_seq->operations().at(0)); + // NOTE This assumes there is just one operation in a op + const auto &node = _graph.operations().at(op_ind); auto node_name = node.name(); VERBOSE(ProfileInfo) << "Time for " << node_name << " : " << timer_res << std::endl; @@ -54,7 +98,7 @@ void ProfileObserver::handleEnd(IExecutor *exec, const ir::OpSequence *op_seq, ir::DataType::QUANT_UINT8_ASYMM; uint32_t size = 0; - for (const auto &ind : node.getInputs() + node.getOutputs()) + for (const auto &ind : (node.getInputs() + node.getOutputs()) | ir::Remove::UNDEFINED) { size += exec->graph().operands().at(ind).info().total_size(); } @@ -69,64 +113,66 @@ void ProfileObserver::handleEnd(IExecutor *exec, const ir::OpSequence *op_seq, } }; -ChromeTracingObserver::ChromeTracingObserver(const std::string &filepath, const ir::Graph &graph) - : _ofs{filepath, std::ofstream::out}, _recorder{}, _collector{&_recorder}, _graph{graph} +TracingObserver::TracingObserver(const std::string &workspace_dir, const ir::Graph &graph, + const util::TracingCtx *tracing_ctx) + : _recorder{std::make_unique<EventRecorder>()}, _collector{_recorder.get()}, _graph{graph}, + _workspace_dir{workspace_dir}, _tracing_ctx{tracing_ctx}, _triggered{false} { + // DO NOTHING } -ChromeTracingObserver::~ChromeTracingObserver() +TracingObserver::~TracingObserver() { try { - _recorder.writeToFile(_ofs); + // Write file if this observer is triggered at least once + if (_triggered) + { + auto event_writer = EventWriter::get(_workspace_dir); + 
event_writer->startToUse(); + event_writer->readyToFlush(std::move(_recorder)); + } } catch (const std::exception &e) { - std::cerr << "E: Fail to record event in ChromeTracingObserver: " << e.what() << std::endl; + std::cerr << "E: Fail to record event in TracingObserver: " << e.what() << std::endl; } } -void ChromeTracingObserver::handleBegin(IExecutor *) +void TracingObserver::handleSubgraphBegin(ir::SubgraphIndex subg_ind) { - _collector.onEvent(EventCollector::Event{EventCollector::Edge::BEGIN, "runtime", "Graph"}); -} + _triggered = true; -void ChromeTracingObserver::handleBegin(IExecutor *, const ir::OpSequence *op_seq, - const backend::Backend *backend) -{ - std::string backend_id = backend->config()->id(); - _collector.onEvent(EventCollector::Event{EventCollector::Edge::BEGIN, backend_id, - opSequenceTag(op_seq, _graph.operations())}); + _collector.onEvent( + EventCollector::SubgEvent{_tracing_ctx, EventCollector::Edge::BEGIN, subg_ind.value()}); } -void ChromeTracingObserver::handleEnd(IExecutor *, const ir::OpSequence *op_seq, - const backend::Backend *backend) +void TracingObserver::handleJobBegin(IExecutor *, ir::SubgraphIndex subg_ind, + ir::OperationIndex op_ind, const backend::Backend *backend) { std::string backend_id = backend->config()->id(); - _collector.onEvent(EventCollector::Event{EventCollector::Edge::END, backend_id, - opSequenceTag(op_seq, _graph.operations())}); + const auto &op = _graph.operations().at(op_ind); + auto ev = EventCollector::OpSeqEvent{_tracing_ctx, EventCollector::Edge::BEGIN, + subg_ind.value(), backend_id, + op_ind.value(), op.name()}; + // add shape of inputs + setUserData(_graph, &op, ev.userData); + _collector.onEvent(ev); } -void ChromeTracingObserver::handleEnd(IExecutor *) +void TracingObserver::handleJobEnd(IExecutor *, ir::SubgraphIndex subg_ind, + ir::OperationIndex op_ind, const backend::Backend *backend) { - _collector.onEvent(EventCollector::Event{EventCollector::Edge::END, "runtime", "Graph"}); + std::string 
backend_id = backend->config()->id(); + _collector.onEvent(EventCollector::OpSeqEvent{_tracing_ctx, EventCollector::Edge::END, + subg_ind.value(), backend_id, op_ind.value(), + _graph.operations().at(op_ind).name()}); } -std::string ChromeTracingObserver::opSequenceTag(const ir::OpSequence *op_seq, - const ir::Operations &operations) +void TracingObserver::handleSubgraphEnd(ir::SubgraphIndex subg_ind) { - if (op_seq->size() == 0) - return "Empty OpSequence"; - - const auto &first_op_idx = op_seq->operations().at(0); - const auto &first_op_node = operations.at(first_op_idx); - std::string tag = "$" + std::to_string(first_op_idx.value()); - tag += " " + first_op_node.name(); - if (op_seq->size() > 1) - { - tag += " (+" + std::to_string(op_seq->size() - 1) + ")"; - } - return tag; + _collector.onEvent( + EventCollector::SubgEvent{_tracing_ctx, EventCollector::Edge::END, subg_ind.value()}); } } // namespace exec diff --git a/runtime/onert/core/src/exec/ExecutionObservers.h b/runtime/onert/core/src/exec/ExecutionObservers.h index ac0076ed2..e59d58766 100644 --- a/runtime/onert/core/src/exec/ExecutionObservers.h +++ b/runtime/onert/core/src/exec/ExecutionObservers.h @@ -17,44 +17,82 @@ #ifndef __ONERT_EXEC_OBSREVERS_H__ #define __ONERT_EXEC_OBSREVERS_H__ -#include "exec/IFunction.h" -#include "ir/OpSequence.h" #include "ExecTime.h" -#include "util/ITimer.h" +#include "../util/EventCollector.h" +#include "../util/EventRecorder.h" +#include "../util/EventWriter.h" + #include "exec/IExecutor.h" -#include "util/EventCollector.h" -#include "util/EventRecorder.h" +#include "ir/Index.h" +#include "ir/IOperation.h" +#include "util/ITimer.h" +#include "util/TracingCtx.h" namespace onert { namespace exec { + +enum class ObserverType +{ + PROFILE, + TRACING, + MINMAX_DUMP, +}; + class IExecutionObserver { public: /// @brief Invoked just before model (not individual operation) execution begins - virtual void handleBegin(IExecutor *) { return; } + virtual void 
handleSubgraphBegin(ir::SubgraphIndex) { return; } - virtual void handleBegin(IExecutor *, const ir::OpSequence *, const backend::Backend *) = 0; - virtual void handleEnd(IExecutor *, const ir::OpSequence *, const backend::Backend *) = 0; + virtual void handleJobBegin(IExecutor *, ir::SubgraphIndex, ir::OperationIndex, + const backend::Backend *) = 0; + virtual void handleJobEnd(IExecutor *, ir::SubgraphIndex, ir::OperationIndex, + const backend::Backend *) = 0; /// @brief Invoked just after model (not individual operation) execution ends - virtual void handleEnd(IExecutor *) { return; } + virtual void handleSubgraphEnd(ir::SubgraphIndex) { return; } + + virtual ObserverType type() const = 0; virtual ~IExecutionObserver() = default; }; +class ExecObservers +{ +public: + void add(std::unique_ptr<IExecutionObserver> &&observer) + { + _observers.emplace(observer->type(), std::move(observer)); + } + + IExecutionObserver *get(ObserverType type) const + { + if (_observers.find(type) != _observers.end()) + return _observers.at(type).get(); + + return nullptr; + } + +private: + std::unordered_map<ObserverType, std::unique_ptr<IExecutionObserver>> _observers; +}; + class ProfileObserver : public IExecutionObserver { public: explicit ProfileObserver(std::shared_ptr<ExecTime> et, const ir::Graph &graph) - : _et(std::move(et)), _graph(graph) + : _et(std::move(et)), _graph(graph) { } - void handleBegin(IExecutor *, const ir::OpSequence *, const backend::Backend *) override; - void handleEnd(IExecutor *, const ir::OpSequence *, const backend::Backend *) override; + void handleJobBegin(IExecutor *, ir::SubgraphIndex, ir::OperationIndex, + const backend::Backend *) override; + void handleJobEnd(IExecutor *, ir::SubgraphIndex, ir::OperationIndex, + const backend::Backend *) override; - void handleEnd(IExecutor *) override { _et->uploadOperationsExecTime(); } + void handleSubgraphEnd(ir::SubgraphIndex) override { _et->storeOperationsExecTime(); } + ObserverType type() const override 
{ return ObserverType::PROFILE; } private: std::unique_ptr<util::ITimer> _timer; @@ -62,24 +100,27 @@ private: const ir::Graph &_graph; }; -class ChromeTracingObserver : public IExecutionObserver +class TracingObserver : public IExecutionObserver { public: - ChromeTracingObserver(const std::string &filepath, const ir::Graph &graph); - ~ChromeTracingObserver(); - void handleBegin(IExecutor *) override; - void handleBegin(IExecutor *, const ir::OpSequence *, const backend::Backend *) override; - void handleEnd(IExecutor *, const ir::OpSequence *, const backend::Backend *) override; - void handleEnd(IExecutor *) override; - -private: - static std::string opSequenceTag(const ir::OpSequence *op_seq, const ir::Operations &operations); + TracingObserver(const std::string &workspace_dir, const ir::Graph &graph, + const util::TracingCtx *tracing_ctx); + ~TracingObserver(); + void handleSubgraphBegin(ir::SubgraphIndex) override; + void handleJobBegin(IExecutor *, ir::SubgraphIndex, ir::OperationIndex, + const backend::Backend *) override; + void handleJobEnd(IExecutor *, ir::SubgraphIndex, ir::OperationIndex, + const backend::Backend *) override; + void handleSubgraphEnd(ir::SubgraphIndex) override; + ObserverType type() const override { return ObserverType::TRACING; } private: - std::ofstream _ofs; - EventRecorder _recorder; + std::unique_ptr<EventRecorder> _recorder; EventCollector _collector; const ir::Graph &_graph; + std::string _workspace_dir; + const util::TracingCtx *_tracing_ctx; + bool _triggered; }; } // namespace exec diff --git a/runtime/onert/core/src/exec/ExecutorBase.cc b/runtime/onert/core/src/exec/ExecutorBase.cc index f835a9675..2526e4e6e 100644 --- a/runtime/onert/core/src/exec/ExecutorBase.cc +++ b/runtime/onert/core/src/exec/ExecutorBase.cc @@ -16,10 +16,10 @@ #include "ExecutorBase.h" -#include "backend/ITensor.h" -#include "backend/controlflow/UserTensor.h" -#include "backend/cpu_common/Tensor.h" -#include "util/logging.h" +#include "ShapeConverter.h" 
+ +#include "util/ConfigSource.h" +#include <misc/polymorphic_downcast.h> namespace onert { @@ -27,214 +27,68 @@ namespace exec { ExecutorBase::ExecutorBase(std::unique_ptr<compiler::LoweredGraph> &&lowered_graph, - const std::vector<std::shared_ptr<backend::ITensor>> &input_tensors, - const std::vector<std::shared_ptr<backend::ITensor>> &output_tensors, + backend::BackendContexts &&backend_contexts, const compiler::TensorRegistries &tensor_regs, - backend::TensorManagerSet &&tensor_mgrs) - : _lowered_graph{std::move(lowered_graph)}, _graph{_lowered_graph->graph()}, - _input_tensors{input_tensors}, _output_tensors{output_tensors}, - _tensor_mgrs{std::move(tensor_mgrs)}, _mutex() + const util::TracingCtx *tracing_ctx) + : _lowered_graph{std::move(lowered_graph)}, _backend_contexts{std::move(backend_contexts)}, + _graph{_lowered_graph->graph()}, _mutex(), _tracing_ctx(tracing_ctx) { - // TODO Fix the way of knowing whether it is primary or not - bool primary_executor = !(_input_tensors.empty() && _output_tensors.empty()); - if (!primary_executor) - { - auto build_input_tensor_list = [&](const onert::ir::OperandIndexSequence &ind_seq) { - std::vector<std::shared_ptr<backend::ITensor>> list; - for (auto ind : ind_seq) - { - std::shared_ptr<backend::ITensor> tensor = tensor_regs.getITensor(ind); - assert(tensor != nullptr); - DynAllocInfo dyn_alloc_info{ind}; - _input_to_dyn_alloc_info.emplace(tensor, dyn_alloc_info); - list.push_back(tensor); - } - return list; - }; - auto build_output_tensor_list = [&](const onert::ir::OperandIndexSequence &ind_seq) { - std::vector<std::shared_ptr<backend::ITensor>> list; - for (auto ind : ind_seq) - { - std::shared_ptr<backend::ITensor> tensor = tensor_regs.getITensor(ind); - assert(tensor != nullptr); - DynAllocInfo dyn_alloc_info{ind}; - _output_to_dyn_alloc_info.emplace(tensor, dyn_alloc_info); - list.push_back(tensor); - } - return list; - }; - _input_tensors = build_input_tensor_list(_graph.getInputs()); - _output_tensors = 
build_output_tensor_list(_graph.getOutputs()); - } - else - { - assert(input_tensors.size() == _graph.getInputs().size()); - assert(output_tensors.size() == _graph.getOutputs().size()); - for (uint32_t i = 0; i < input_tensors.size(); i++) - { - auto tensor = input_tensors[i]; - auto ind = _graph.getInputs().at(i); - DynAllocInfo dyn_alloc_info{ind}; - _input_to_dyn_alloc_info.emplace(tensor, dyn_alloc_info); - } - for (uint32_t i = 0; i < output_tensors.size(); i++) + auto build_tensor_list = [&](const auto &ind_seq, auto &tensors) { + assert(tensors.empty()); + for (auto &&ind : ind_seq) { - auto tensor = output_tensors[i]; - auto ind = _graph.getOutputs().at(i); - DynAllocInfo dyn_alloc_info{ind}; - _output_to_dyn_alloc_info.emplace(tensor, dyn_alloc_info); + backend::ITensor *tensor = tensor_regs.getITensor(ind); + assert(tensor != nullptr); + auto io_tensor = nnfw::misc::polymorphic_downcast<backend::builtin::IOTensor *>(tensor); + tensors.push_back(io_tensor); } - } + }; + build_tensor_list(_graph.getInputs(), _input_tensors); + build_tensor_list(_graph.getOutputs(), _output_tensors); } -void ExecutorBase::execute(const std::vector<std::shared_ptr<backend::ITensor>> &src_tensors, - const std::shared_ptr<IPermuteFunction> &pre_fn) +void ExecutorBase::execute(const std::vector<backend::IPortableTensor *> &inputs, + const std::vector<backend::IPortableTensor *> &outputs, + const ExecutionOptions &options) { // For thread-safe, use mutex // TODO: if all used backends on this executor are thread-safe, // do not need to use mutex (otherwise, use mutex) // Deadlock occurs when an Executor is called recursively. 
std::lock_guard<std::mutex> lock(_mutex); + _current_options = options; - assert(src_tensors.size() == _graph.getInputs().size()); - assert(src_tensors.size() == _input_tensors.size()); - for (uint32_t n = 0; n < _graph.getInputs().size(); ++n) + assert(inputs.size() == _graph.getInputs().size()); + assert(inputs.size() == _input_tensors.size()); + for (uint32_t n = 0; n < inputs.size(); ++n) { - // when user changes input shape, the input tensor is dynamic and its memory is not allocated. - // This code find the info to allocate dynamic tensor, and allocate memory based on the source - // tensor's shape set by caller. - const auto src_tensor = src_tensors[n]; + const auto input = inputs[n]; + assert(input->buffer() != nullptr || input->get_info().total_size() == 0); auto input_tensor = _input_tensors[n]; - // If src_tensor or input_tensor is nullptr, pre_fn does not copy the tensors - if (src_tensor != nullptr && input_tensor != nullptr) - { - auto dyn_alloc_info = _input_to_dyn_alloc_info.find(_input_tensors[n]); - const auto orig_input_shape = input_tensor->getShape(); - const auto changed_input_shape = - convertShape(src_tensor->getShape(), src_tensor->layout(), input_tensor->layout()); - if (orig_input_shape != changed_input_shape) - { - if (dyn_alloc_info == _input_to_dyn_alloc_info.end()) - { - // The input_tensor is a dynamic tensor of backend that doesn't support dynamic tensor - throw std::runtime_error("Unknown dim is found at execution time for a backend that " - "does not support dynamic tensor"); - } - else - { - input_tensor->set_dynamic(); - } - } - } + assert(input_tensor != nullptr); + input_tensor->setTensor(input); } - // TODO Move calling permute_fn.run() into executeImpl() - assert(pre_fn); - pre_fn->run(); - - executeImpl(); -} - -void ExecutorBase::execute(const IODescription &desc) -{ - // For thread-safe, use mutex - // TODO: if all used backends on this executor are thread-safe, - // do not need to use mutex (otherwise, use mutex) - 
std::lock_guard<std::mutex> lock(_mutex); - - // Set input(s) - assert(_input_tensors.size() == desc.inputs.size()); - for (uint32_t i = 0; i < _input_tensors.size(); ++i) + assert(outputs.size() == _graph.getOutputs().size()); + assert(outputs.size() == _output_tensors.size()); + for (uint32_t n = 0; n < outputs.size(); ++n) { - // TODO Remove dynamic_cast - auto tensor = std::dynamic_pointer_cast<backend::controlflow::UserTensor>(_input_tensors[i]); - assert(tensor); - auto input_shape = desc.dynamic_input_shapes.find(ir::IOIndex{i}); - if (input_shape != desc.dynamic_input_shapes.end()) - { - tensor->set_dynamic(); - tensor->setShape(input_shape->second); - } - // TODO Better design for ITensor? (we need const_cast as ITensor is writable) - tensor->setBuffer(static_cast<uint8_t *>(const_cast<void *>(desc.inputs[i]->buffer)), - desc.inputs[i]->size); - - handleDynamicInputTensor(ir::IOIndex{i}, desc); + const auto output = outputs[n]; + assert(output->buffer() != nullptr || output->get_info().total_size() == 0); + auto output_tensor = _output_tensors[n]; + assert(output_tensor != nullptr); + output_tensor->setTensor(output); } - assert(_output_tensors.size() == desc.outputs.size()); - for (uint32_t i = 0; i < _output_tensors.size(); ++i) - { - // TODO Remove dynamic_cast - auto tensor = std::dynamic_pointer_cast<backend::controlflow::UserTensor>(_output_tensors[i]); - assert(tensor); - tensor->set_dynamic(); // It can't be resized but shape could change - // TODO Better design for ITensor? 
(we need const_cast as ITensor is writable) - tensor->setBuffer(static_cast<uint8_t *>(const_cast<void *>(desc.outputs[i]->buffer)), - desc.outputs[i]->size); - } - - executeImpl(); - - // Update output(s) desc - for (uint32_t n = 0; n < _graph.getOutputs().size(); ++n) - { - ir::IOIndex output_index{n}; - // Optional output - if (desc.outputs.at(n) == nullptr) - { - continue; - } - auto &output = *desc.outputs.at(n); - - // set shape of outputDesc to tensor shape since tensor can be dynamic - const auto output_tensor_shape = _output_tensors[n]->getShape(); - output.info.shape( - convertShape(output_tensor_shape, _output_tensors[n]->layout(), output.layout)); - } -} + // Create observee + ExecutionObservee subject(_observers, options); -/** - * @brief Changes tensor shape and allocate memory - * if input shape was changed by nnfw_set_input_tensorinfo() - * - * @note Cases are: - * 1) static operand -> nnfw_set_input_tensorinfo() -> execute() -> execute() - * (a) (b) - * - * at (a), operand is static, tensor is static - memory dealloc is not needed - * (DynamicTensorManager cannot dealloc memory allocated by StaticTensorManager) - * at (b), operand is static, tensor is dynamic - memory dealloc is needed - * - * 2) dynamic operand -> nnfw_set_input_tensorinfo() -> execute() -> execute() - * (a) (b) - * - * at (a), operand is dynamic, tensor is dynamic - memory dealloc is not needed - * since it has not been allocated yet - * at (b), operand is dynamic, tensor is dynamic - memory dealloc is needed - */ -void ExecutorBase::handleDynamicInputTensor(ir::IOIndex io_ind, const IODescription &desc) -{ - auto shape_sig_found = desc.dynamic_input_shapes.find(io_ind); - if (shape_sig_found != desc.dynamic_input_shapes.end()) - { - auto dyn_alloc_info = _input_to_dyn_alloc_info.find(_input_tensors[io_ind.value()]); - if (dyn_alloc_info == _input_to_dyn_alloc_info.end()) - throw std::runtime_error("Unknown dim is found at execution time for a backend that " - "does not support 
dynamic tensor"); - - auto changed_input_shape = shape_sig_found->second; - auto operand_ind = dyn_alloc_info->second.ind; - - auto dyn_tensor_manager = _input_tensors[io_ind.value()]->dynamic_tensor_manager(); - assert(dyn_tensor_manager); - dyn_tensor_manager->applyShape(operand_ind, changed_input_shape); - } + executeImpl(subject); } bool ExecutorBase::hasDynamicInput() { - for (auto &tensor : _input_tensors) + for (auto &&tensor : _input_tensors) { if (tensor->is_dynamic()) return true; diff --git a/runtime/onert/core/src/exec/ExecutorBase.h b/runtime/onert/core/src/exec/ExecutorBase.h index a13be7dbf..2ae63ddd4 100644 --- a/runtime/onert/core/src/exec/ExecutorBase.h +++ b/runtime/onert/core/src/exec/ExecutorBase.h @@ -17,25 +17,20 @@ #ifndef __ONERT_EXEC_EXECUTOR_BASE_H__ #define __ONERT_EXEC_EXECUTOR_BASE_H__ -#include <mutex> +#include "ExecutionObservee.h" +#include "../backend/builtin/IOTensor.h" +#include "../compiler/TensorRegistries.h" -#include "IPermuteFunction.h" -#include "Source.h" -#include "exec/ExecutionObservers.h" -#include "Sink.h" -#include "ShapeConverter.h" -#include "exec/IExecutor.h" #include "compiler/LoweredGraph.h" -#include "ir/LowerInfoMap.h" -#include "backend/IConfig.h" -#include "backend/Backend.h" -#include "exec/ExecTime.h" -#include "exec/IFunction.h" -#include "backend/IDynamicTensorManager.h" -#include "backend/ITensorManager.h" -#include "exec/ExecutionObservee.h" -#include "compiler/TensorRegistries.h" -#include <list> +#include "exec/IExecutor.h" +#include "exec/ExecutionContext.h" +#include "ir/Graph.h" +#include "ir/OperationIndexMap.h" +#include "util/TracingCtx.h" + +#include <memory> +#include <mutex> +#include <vector> namespace onert { @@ -51,47 +46,51 @@ public: * @param tensor_builders Tensor builders that are currently used */ ExecutorBase(std::unique_ptr<compiler::LoweredGraph> &&lowered_graph, - const std::vector<std::shared_ptr<backend::ITensor>> &input_tensors, - const 
std::vector<std::shared_ptr<backend::ITensor>> &output_tensors, - const compiler::TensorRegistries &tensor_regs, - backend::TensorManagerSet &&tensor_mgrs); + backend::BackendContexts &&backend_contexts, + const compiler::TensorRegistries &tensor_regs, const util::TracingCtx *tracing_ctx); virtual ~ExecutorBase() = default; - const ir::Graph &graph() final { return _graph; } + const ir::Graph &graph() const final { return _graph; } - /** - * @brief Execute without IODescription - * - * @param src_tensor Tensor list that will be copied to input tensors of this - * @param pre_fn The permutation function that copy from src_tensor to input tensors of this - */ - void execute(const std::vector<std::shared_ptr<backend::ITensor>> &src_tensors, - const std::shared_ptr<IPermuteFunction> &pre_fn); + void execute(const std::vector<backend::IPortableTensor *> &inputs, + const std::vector<backend::IPortableTensor *> &outputs, + const ExecutionOptions &options) override; - void execute(const IODescription &desc) final; + uint32_t inputSize() const override { return _input_tensors.size(); } - // Used only in Dataflow and Parallel Executors - void setIndexedRanks(std::shared_ptr<ir::OperationIndexMap<int64_t>> ranks) final + uint32_t outputSize() const override { return _output_tensors.size(); } + + const ir::OperandInfo &inputInfo(uint32_t index) const override { - _indexed_ranks = std::move(ranks); - }; + return _input_tensors[index]->get_info(); + } - virtual void executeImpl(void) = 0; + const ir::OperandInfo &outputInfo(uint32_t index) const override + { + return _output_tensors[index]->get_info(); + } - void addObserver(std::unique_ptr<IExecutionObserver> ref) { _subject.add(std::move(ref)); }; + ir::Layout inputLayout(uint32_t index) const override { return _input_tensors[index]->layout(); } - const std::vector<std::shared_ptr<backend::ITensor>> &getInputTensors() const + ir::Layout outputLayout(uint32_t index) const override { - return _input_tensors; + return 
_output_tensors[index]->layout(); } - const std::vector<std::shared_ptr<backend::ITensor>> &getOutputTensors() const + // Used only in Dataflow and Parallel Executors + void setIndexedRanks(std::shared_ptr<ir::OperationIndexMap<int64_t>> ranks) final { - return _output_tensors; - } + _indexed_ranks = std::move(ranks); + }; - const DynAllocInfoMap &getInputsDynamicAllocInfo() const { return _input_to_dyn_alloc_info; } + virtual void executeImpl(const ExecutionObservee &subject) = 0; + + void addObserver(std::unique_ptr<IExecutionObserver> ref) { _observers.add(std::move(ref)); }; + + backend::BackendContexts &getBackendContexts() { return _backend_contexts; } + + const ExecutionOptions ¤tOptions() const override { return _current_options; } protected: /** @@ -100,19 +99,23 @@ protected: bool hasDynamicInput(); protected: - ExecutionObservee _subject; + ExecObservers _observers; std::shared_ptr<ir::OperationIndexMap<int64_t>> _indexed_ranks; std::unique_ptr<compiler::LoweredGraph> _lowered_graph; + backend::BackendContexts _backend_contexts; const ir::Graph &_graph; - std::vector<std::shared_ptr<backend::ITensor>> _input_tensors; - std::vector<std::shared_ptr<backend::ITensor>> _output_tensors; - DynAllocInfoMap _input_to_dyn_alloc_info; - DynAllocInfoMap _output_to_dyn_alloc_info; - backend::TensorManagerSet _tensor_mgrs; + std::vector<backend::builtin::IOTensor *> _input_tensors; + std::vector<backend::builtin::IOTensor *> _output_tensors; std::mutex _mutex; - -private: - void handleDynamicInputTensor(ir::IOIndex input_index, const IODescription &desc); + const util::TracingCtx *_tracing_ctx; + /** + * It is set by execute() method only in thread-safe environment. + * It is used for non-primary executor call on builtin backend + * and accessed by entryExecutor's currentOptions() method. 
+ * + * TODO: Find better way to pass config to non-primary executor + */ + ExecutionOptions _current_options; }; } // namespace exec diff --git a/runtime/onert/core/src/exec/FunctionSequence.cc b/runtime/onert/core/src/exec/FunctionSequence.cc index fb31f7582..578123a54 100644 --- a/runtime/onert/core/src/exec/FunctionSequence.cc +++ b/runtime/onert/core/src/exec/FunctionSequence.cc @@ -16,8 +16,6 @@ #include "exec/FunctionSequence.h" -#include "ir/Operation.h" -#include "backend/IDynamicTensorManager.h" #include "backend/ITensorRegistry.h" #include "util/logging.h" @@ -28,19 +26,19 @@ namespace exec void FunctionSequence::run() { - // TODO Find out when `_enable_dynamic_shape_inferer` is true but `_dynamic_tensor_ctx` is false if (_enable_dynamic_shape_inferer && _dynamic_tensor_ctx) { - if (_dynamic_tensor_ctx->op_seq->size() != _functions.size()) - throw std::runtime_error("operation and functions should be mapped one by one"); + // acl_cl and acl_neon backend don't support dynamic shape. + // _dynamic_tensor_ctx is always nullptr for acl_cl and acl_neon + // Thus, those two bakends cannot reach here. 
+ + // Do dynamic shape inference + _dynamic_tensor_ctx->op->accept(*_dynamic_tensor_ctx->dynamic_shape_inferer); - auto op_seq_iter = _dynamic_tensor_ctx->op_seq->begin(); for (const auto &function : _functions) { - // set shape of output and allocate memory when needed - auto &op = _dynamic_tensor_ctx->operations->at(*op_seq_iter); - op.accept(*_dynamic_tensor_ctx->dynamic_shape_inferer); - + // NOTE the function could be also FunctionSequence so we do this + // TODO Remove this or do this recursively auto *sub_func_seq = dynamic_cast<FunctionSequence *>(function.get()); if (sub_func_seq != nullptr) { @@ -50,22 +48,12 @@ void FunctionSequence::run() // run kernel function->run(); - - // deallocate input tensors which is no longer used - _dynamic_tensor_ctx->dynamic_tensor_manager->deallocInput(*op_seq_iter); - - op_seq_iter++; } } else { for (const auto &function : _functions) { - auto *sub_func_seq = dynamic_cast<FunctionSequence *>(function.get()); - if (sub_func_seq != nullptr) - { - sub_func_seq->enableDynamicShapeInferer(false); - } function->run(); } } diff --git a/runtime/onert/core/src/exec/IPermuteFunction.cc b/runtime/onert/core/src/exec/IPermuteFunction.cc new file mode 100644 index 000000000..9d548e6dc --- /dev/null +++ b/runtime/onert/core/src/exec/IPermuteFunction.cc @@ -0,0 +1,320 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "IPermuteFunction.h" + +#include <cker/operation/Quantize.h> +#include <cker/operation/Dequantize.h> +#include "backend/IPortableTensor.h" +#include "exec/IFunction.h" +#include "ir/Index.h" +#include "ir/Shape.h" +#include <memory> +#include <misc/polymorphic_downcast.h> +#include <typeinfo> +#include "util/Utils.h" +#include <vector> +#include <unordered_map> + +namespace +{ +using namespace onert; + +inline nnfw::cker::Shape getShape(const backend::ITensor *tensor) +{ + const ir::Shape shape = tensor->getShape(); + + assert(tensor->layout() == ir::Layout::NHWC); + + auto rank = shape.rank(); + nnfw::cker::Shape ret(rank); + auto data = ret.DimsData(); + for (int i = 0; i < rank; ++i) + { + data[i] = shape.dim(i); + } + return ret; +} + +// Quantize per element +template <typename InputT, typename OutputT> +void elementwiseQuantize(const backend::ITensor *src_tensor, backend::ITensor *dst_tensor) +{ + const auto scale = dst_tensor->data_scale(); + const auto zero_point = dst_tensor->data_zero_point(); + + int min_val = std::numeric_limits<OutputT>::min(); + int max_val = std::numeric_limits<OutputT>::max(); + + auto loop_shape = src_tensor->getShape(); + const auto src_layout = src_tensor->layout(); + const auto dst_layout = dst_tensor->layout(); + const bool is_permutation = src_layout != dst_layout && loop_shape.rank() == 4; + ShapeLoop(loop_shape, [&](const onert::ir::Coordinates &coords) { + const InputT *input_data = + reinterpret_cast<const InputT *>(src_tensor->buffer() + src_tensor->calcOffset(coords)); + int32_t unclamped = static_cast<int32_t>(round(*input_data / scale)) + zero_point; + int32_t clamped = std::min(std::max(unclamped, min_val), max_val); + + ir::Coordinates dst_coords = + is_permutation ? 
ir::convertCoordinates(coords, src_layout, dst_layout) : coords; + OutputT *output_data = + reinterpret_cast<OutputT *>(dst_tensor->buffer() + dst_tensor->calcOffset(dst_coords)); + *output_data = clamped; + }); +} + +// TODO Optimize the case where tensors has the same layout +template <typename InputT, typename OutputT> +void quantize(const backend::ITensor *src_tensor, backend::ITensor *dst_tensor) +{ + if (!src_tensor->has_padding() && !dst_tensor->has_padding() && + src_tensor->layout() == dst_tensor->layout() && !src_tensor->is_dynamic()) + { + assert(!dst_tensor->is_dynamic()); + + // Call optimized neon kernel + nnfw::cker::Quantize(getShape(src_tensor), + reinterpret_cast<const InputT *>(src_tensor->buffer()), + getShape(dst_tensor), reinterpret_cast<OutputT *>(dst_tensor->buffer()), + dst_tensor->data_scale(), dst_tensor->data_zero_point()); + } + else + { + elementwiseQuantize<InputT, OutputT>(src_tensor, dst_tensor); + } +} + +// Dequantize per element +template <typename InputT, typename OutputT> +void elementwiseDequantize(const backend::ITensor *src_tensor, backend::ITensor *dst_tensor) +{ + const auto scale = src_tensor->data_scale(); + const auto zero_point = src_tensor->data_zero_point(); + + auto loop_shape = src_tensor->getShape(); + const auto src_layout = src_tensor->layout(); + const auto dst_layout = dst_tensor->layout(); + const bool is_permutation = src_layout != dst_layout && loop_shape.rank() == 4; + ShapeLoop(loop_shape, [&](const onert::ir::Coordinates &coords) { + const InputT *input_data = + reinterpret_cast<const InputT *>(src_tensor->buffer() + src_tensor->calcOffset(coords)); + const OutputT result = static_cast<OutputT>(scale * (*input_data - zero_point)); + + ir::Coordinates dst_coords = + is_permutation ? 
ir::convertCoordinates(coords, src_layout, dst_layout) : coords; + OutputT *output_data = + reinterpret_cast<OutputT *>(dst_tensor->buffer() + dst_tensor->calcOffset(dst_coords)); + *output_data = result; + }); +} + +// TODO Optimize the case where tensors has the same layout +template <typename InputT, typename OutputT> +void dequantize(const backend::ITensor *src_tensor, backend::ITensor *dst_tensor) +{ + if (!src_tensor->has_padding() && !dst_tensor->has_padding() && + src_tensor->layout() == dst_tensor->layout() && !src_tensor->is_dynamic()) + { + assert(!dst_tensor->is_dynamic()); + + // Call optimized neon kernel + nnfw::cker::Dequantize(getShape(src_tensor), + reinterpret_cast<const InputT *>(src_tensor->buffer()), + getShape(dst_tensor), reinterpret_cast<OutputT *>(dst_tensor->buffer()), + src_tensor->data_scale(), src_tensor->data_zero_point()); + } + else + { + elementwiseDequantize<InputT, OutputT>(src_tensor, dst_tensor); + } +} + +template <typename SRC_T, typename DST_T, + std::enable_if_t<std::is_base_of<backend::ITensor, SRC_T>::value && + std::is_base_of<backend::ITensor, DST_T>::value, + bool> = true> +void typeAwareQuantize(const SRC_T *src_tensor, DST_T *dst_tensor) +{ + // TODO Support other types + if (src_tensor->data_type() == ir::DataType::FLOAT32) + { + switch (dst_tensor->data_type()) + { + case ir::DataType::QUANT_UINT8_ASYMM: + { + quantize<float, uint8_t>(src_tensor, dst_tensor); + break; + } + case ir::DataType::QUANT_INT8_SYMM: + { + quantize<float, int8_t>(src_tensor, dst_tensor); + break; + } + case ir::DataType::QUANT_INT16_SYMM: + { + quantize<float, int16_t>(src_tensor, dst_tensor); + break; + } + default: + { + throw std::runtime_error("IPermuteFunction: Unsupported quantization type"); + break; + } + } + } + else if (dst_tensor->data_type() == ir::DataType::FLOAT32) + { + switch (src_tensor->data_type()) + { + case ir::DataType::QUANT_UINT8_ASYMM: + { + dequantize<uint8_t, float>(src_tensor, dst_tensor); + break; + } + case 
ir::DataType::QUANT_INT8_SYMM: + { + dequantize<int8_t, float>(src_tensor, dst_tensor); + break; + } + case ir::DataType::QUANT_INT16_SYMM: + { + dequantize<int16_t, float>(src_tensor, dst_tensor); + break; + } + default: + { + throw std::runtime_error("IPermuteFunction: Unsupported dequantization type"); + break; + } + } + } + else + { + throw std::runtime_error("IPermuteFunction: Unsupported type for type-aware quantization yet"); + } +} + +} // namespace + +namespace onert +{ +namespace exec +{ + +void IPermuteFunction::IPermuteFunction::run() +{ + // TODO Optimization : Make control does not reach here? when (_src_tensors.size() == 0) + assert(_src_tensors.size() == _dst_tensors.size()); + if (_src_tensors_offsets.size() == 0) + { + _src_tensors_offsets.resize(_src_tensors.size()); + _dst_tensors_offsets.resize(_dst_tensors.size()); + } + assert(_src_tensors.size() == _src_tensors_offsets.size()); + assert(_src_tensors_offsets.size() == _dst_tensors_offsets.size()); + + for (size_t i = 0; i < _src_tensors.size(); ++i) + { + auto src_tensor = _src_tensors.at(i); + auto dst_tensor = _dst_tensors.at(i); + auto &src_offsets = _src_tensors_offsets.at(i); + auto &dst_offsets = _dst_tensors_offsets.at(i); + if (src_tensor != dst_tensor) + { + const auto rank = src_tensor->getShape().rank(); + permute(src_tensor, dst_tensor, rank, src_offsets, dst_offsets); + } + } +} + +void IPermuteFunction::permute(backend::ITensor *src_tensor, backend::ITensor *dst_tensor, + size_t rank, std::vector<size_t> &src_offsets, + std::vector<size_t> &dst_offsets) +{ + if (src_tensor->total_size() == 0) + { + assert(dst_tensor->total_size() == 0); + return; + } + + assert(src_tensor != dst_tensor); + if (underlying_type(src_tensor->data_type()) != underlying_type(dst_tensor->data_type())) + { + typeAwareQuantize(src_tensor, dst_tensor); + return; + } + + switch (src_tensor->data_type()) + { + case ir::DataType::FLOAT32: + permute<float>(src_tensor, dst_tensor, rank, src_offsets, 
dst_offsets); + break; + case ir::DataType::INT32: + permute<int32_t>(src_tensor, dst_tensor, rank, src_offsets, dst_offsets); + break; + case ir::DataType::UINT32: + permute<uint32_t>(src_tensor, dst_tensor, rank, src_offsets, dst_offsets); + break; + case ir::DataType::BOOL8: + case ir::DataType::QUANT_UINT8_ASYMM: + case ir::DataType::UINT8: + permute<uint8_t>(src_tensor, dst_tensor, rank, src_offsets, dst_offsets); + break; + case ir::DataType::QUANT_INT8_ASYMM: + case ir::DataType::QUANT_INT8_SYMM: + permute<int8_t>(src_tensor, dst_tensor, rank, src_offsets, dst_offsets); + break; + case ir::DataType::INT64: + permute<int64_t>(src_tensor, dst_tensor, rank, src_offsets, dst_offsets); + break; + case ir::DataType::QUANT_INT16_SYMM: + permute<int16_t>(src_tensor, dst_tensor, rank, src_offsets, dst_offsets); + break; + default: + throw std::runtime_error("IPermuteFunction: Not supported data type"); + break; + } +} + +const std::type_info &IPermuteFunction::underlying_type(ir::DataType type) const +{ + switch (type) + { + case ir::DataType::FLOAT32: + return typeid(float); + case ir::DataType::INT32: + return typeid(int32_t); + case ir::DataType::UINT32: + return typeid(uint32_t); + case ir::DataType::INT64: + return typeid(int64_t); + case ir::DataType::BOOL8: + case ir::DataType::QUANT_UINT8_ASYMM: + case ir::DataType::UINT8: + return typeid(uint8_t); + case ir::DataType::QUANT_INT8_ASYMM: + case ir::DataType::QUANT_INT8_SYMM: + return typeid(int8_t); + case ir::DataType::QUANT_INT16_SYMM: + return typeid(int16_t); + default: + throw std::runtime_error("IPermuteFunction: Not supported data type"); + } +} + +} // namespace exec +} // namespace onert diff --git a/runtime/onert/core/src/exec/IPermuteFunction.h b/runtime/onert/core/src/exec/IPermuteFunction.h index 6b4d15380..ccac66cad 100644 --- a/runtime/onert/core/src/exec/IPermuteFunction.h +++ b/runtime/onert/core/src/exec/IPermuteFunction.h @@ -25,21 +25,48 @@ #include "backend/ITensor.h" #include 
"exec/IFunction.h" -#include "ir/Index.h" -#include "ir/Shape.h" #include <memory> -#include <typeinfo> -#include "util/Utils.h" #include <vector> +#include <unordered_map> namespace onert { namespace exec { +inline void UpdateOffsets(::onert::backend::ITensor *src, ::onert::backend::ITensor *dst, + const ::onert::ir::Shape &loop_shape, std::vector<size_t> &src_offsets, + std::vector<size_t> &dst_offsets) +{ + ShapeLoop(loop_shape, [&](const onert::ir::Coordinates &coords) { + src_offsets.emplace_back(src->calcOffset(coords)); + dst_offsets.emplace_back(dst->calcOffset(coords)); + }); +} + +inline void CopyStatic(const uint8_t *src_buffer, uint8_t *dst_buffer, + const std::vector<size_t> &src_offsets, + const std::vector<size_t> &dst_offsets, size_t copy_len) +{ + assert(src_offsets.size() == dst_offsets.size()); + for (size_t i = 0; i < src_offsets.size(); ++i) + { + memcpy(dst_buffer + dst_offsets.at(i), src_buffer + src_offsets.at(i), copy_len); + } +} + +inline void CopyDynamic(const ::onert::backend::ITensor *src, const ::onert::backend::ITensor *dst, + uint8_t *dst_buffer, const ::onert::ir::Shape &loop_shape, size_t copy_len) +{ + ShapeLoop(loop_shape, [&](const onert::ir::Coordinates &coords) { + // Copy src tensor's data to dst_buffer with calculated offset of dst tensor + memcpy(dst_buffer + dst->calcOffset(coords), src->buffer() + src->calcOffset(coords), copy_len); + }); +} + class IPermuteFunction : public IFunction { -private: +protected: enum class PermuteType { NHWC_TO_NCHW, @@ -48,63 +75,69 @@ private: }; public: - virtual void run() override + virtual void run() override; + + virtual void prepare() override { optimize(); } + + virtual void optimize() = 0; + +protected: + void permute(backend::ITensor *src_tensor, backend::ITensor *dst_tensor, size_t rank, + std::vector<size_t> &src_offsets, std::vector<size_t> &dst_offsets); + +private: + // TODO make src const by proving const access() + template <class T> + void permute(backend::ITensor *src, 
backend::ITensor *dst, size_t rank, + std::vector<size_t> &src_offsets, std::vector<size_t> &dst_offsets) { - assert(_src_tensors.size() > 0); - assert(_src_tensors.size() == _dst_tensors.size()); - auto src_it = _src_tensors.begin(); - auto dst_it = _dst_tensors.begin(); - while (src_it != _src_tensors.end()) + assert(src->total_size() != 0 && dst->total_size() != 0); + // If dst is subtensor, we have to use clEnqueueMapBuffer instead of clEnqueueWirteBuffer + if (dst->needMemoryMap() && !dst->is_subtensor()) { - const auto src_tensor = *src_it; - auto dst_tensor = *dst_it; - if (src_tensor != dst_tensor) + // A assertion to check mapping without calling map() + // Now there is no case where both src and dst have cl buffer. + assert(!src->needMemoryMap()); + + if (!src->has_padding() && !dst->has_padding() && src->layout() == dst->layout()) { - // TODO Change to permute in parallel - assert(underlying_type(src_tensor->data_type()) == - underlying_type(dst_tensor->data_type())); - const auto rank = src_tensor->num_dimensions(); - switch (src_tensor->data_type()) - { - case ir::DataType::FLOAT32: - permute<float>(src_tensor, dst_tensor, rank); - break; - case ir::DataType::INT32: - permute<int32_t>(src_tensor, dst_tensor, rank); - break; - case ir::DataType::UINT32: - permute<uint32_t>(src_tensor, dst_tensor, rank); - break; - case ir::DataType::BOOL8: - case ir::DataType::QUANT_UINT8_ASYMM: - case ir::DataType::UINT8: - permute<uint8_t>(src_tensor, dst_tensor, rank); - break; - case ir::DataType::QUANT_INT8_SYMM: - permute<int8_t>(src_tensor, dst_tensor, rank); - break; - case ir::DataType::INT64: - permute<int64_t>(src_tensor, dst_tensor, rank); - break; - default: - throw std::runtime_error("IPermuteFunction: Not supported data type"); - break; - } + src->access([&](backend::ITensor &) { dst->enqueueWriteBuffer(src->buffer(), false); }); + } + else + { + // TODO Optimize this block in case of that padding size of dst is big. 
+ _buffers_map[dst].reserve(dst->total_size()); + auto dst_buffer = _buffers_map[dst].data(); + src->access([&](backend::ITensor &) { + permute<T>(src, dst, rank, dst_buffer, dst->total_size(), src_offsets, dst_offsets); + }); + dst->enqueueWriteBuffer(dst_buffer, false); } - src_it++; - dst_it++; + } + else if (src->needMemoryMap() && !src->is_subtensor() && !src->has_padding() && + !dst->has_padding() && src->layout() == dst->layout()) + { + assert(!dst->needMemoryMap()); + dst->access([&](backend::ITensor &) { src->enqueueReadBuffer(dst->buffer(), true); }); + } + else + { + auto fn = [&](backend::ITensor &) { + dst->access([&](backend::ITensor &) { + permute<T>(src, dst, rank, dst->buffer(), dst->total_size(), src_offsets, dst_offsets); + }); + }; + src->access(fn); } } - virtual void prepare() override { optimize(); } - - virtual void optimize() = 0; - -private: template <class T> - void permute(const std::shared_ptr<backend::ITensor> &src, std::shared_ptr<backend::ITensor> &dst, - size_t rank) + void permute(backend::ITensor *src, backend::ITensor *dst, size_t rank, uint8_t *dst_buffer, + size_t dst_size, std::vector<size_t> &src_offsets, std::vector<size_t> &dst_offsets) { + assert(dst_buffer != nullptr); + assert(dst_size == dst->total_size()); + const auto permute_type = [&]() -> PermuteType { if (src->layout() == ir::Layout::NHWC && dst->layout() == ir::Layout::NCHW) { @@ -119,166 +152,130 @@ private: return PermuteType::COPY; } }(); - auto fn = [&](backend::ITensor &src_tensor) { - dst->access([&](backend::ITensor &dst_tensor) { - auto src_buffer = src_tensor.buffer(); - auto src_size = src_tensor.total_size(); - auto dst_buffer = dst_tensor.buffer(); - if (permute_type == PermuteType::COPY) + if (rank == 4 && permute_type != PermuteType::COPY) + { + switch (permute_type) + { + case PermuteType::NHWC_TO_NCHW: { - assert(src_tensor.layout() == dst_tensor.layout()); - if (!src_tensor.has_padding() && !dst_tensor.has_padding()) - { - assert(src_size <= 
dst_tensor.total_size()); - memcpy(dst_buffer, src_buffer, src_size); - return; - } + ir::FeatureShape shape; + auto dst_shape = dst->getShape(); + shape.N = dst_shape.dim(0); + shape.C = dst_shape.dim(1); + shape.H = dst_shape.dim(2); + shape.W = dst_shape.dim(3); + + typename feature::nchw::View<T>::Strides strides; + const auto start_offset = dst->calcOffset({0, 0, 0, 0}); + strides.W = dst_shape.dim(3) == 1 ? 0 : dst->calcOffset({0, 0, 0, 1}) - start_offset; + strides.H = dst_shape.dim(2) == 1 ? 0 : dst->calcOffset({0, 0, 1, 0}) - start_offset; + strides.C = dst_shape.dim(1) == 1 ? 0 : dst->calcOffset({0, 1, 0, 0}) - start_offset; + strides.N = dst_shape.dim(0) == 1 ? 0 : dst->calcOffset({1, 0, 0, 0}) - start_offset; + + const feature::nhwc::Reader<T> from(src); + feature::nchw::View<T> into(shape, strides, + reinterpret_cast<T *>(dst_buffer + start_offset), dst_size); + feature::iterate(shape) << [&](uint32_t batch, uint32_t ch, uint32_t row, uint32_t col) { + const auto value = from.at(batch, row, col, ch); + into.at(batch, ch, row, col) = value; + }; + break; } - switch (rank) + case PermuteType::NCHW_TO_NHWC: { - case 0: - case 1: - { - const int32_t copy_len = dst_tensor.dimension(0); + ir::FeatureShape shape; + auto dst_shape = dst->getShape(); + shape.N = dst_shape.dim(0); + shape.H = dst_shape.dim(1); + shape.W = dst_shape.dim(2); + shape.C = dst_shape.dim(3); - memcpy(dst_buffer, src_buffer, copy_len * sizeof(T)); - break; - } - case 2: - { - const int32_t dim_0 = dst_tensor.dimension(0); - const int32_t copy_len = dst_tensor.dimension(1); + typename feature::nhwc::View<T>::Strides strides; + const auto start_offset = dst->calcOffset({0, 0, 0, 0}); + strides.C = dst_shape.dim(3) == 1 ? 0 : dst->calcOffset({0, 0, 0, 1}) - start_offset; + strides.W = dst_shape.dim(2) == 1 ? 0 : dst->calcOffset({0, 0, 1, 0}) - start_offset; + strides.H = dst_shape.dim(1) == 1 ? 0 : dst->calcOffset({0, 1, 0, 0}) - start_offset; + strides.N = dst_shape.dim(0) == 1 ? 
0 : dst->calcOffset({1, 0, 0, 0}) - start_offset; - for (int32_t i = 0; i < dim_0; ++i) - { - ir::Coordinates coords{i, 0}; - memcpy(dst_buffer + dst_tensor.calcOffset(coords), - src_buffer + src_tensor.calcOffset(coords), copy_len * sizeof(T)); - } - break; - } - case 3: - { - const int32_t dim_0 = dst_tensor.dimension(0); - const int32_t dim_1 = dst_tensor.dimension(1); - const int32_t copy_len = dst_tensor.dimension(2); + const feature::nchw::Reader<T> from(src); + feature::nhwc::View<T> into(shape, strides, + reinterpret_cast<T *>(dst_buffer + start_offset), dst_size); + feature::iterate(shape) << [&](uint32_t batch, uint32_t ch, uint32_t row, uint32_t col) { + const auto value = from.at(batch, ch, row, col); + into.at(batch, row, col, ch) = value; + }; + break; + } + default: + { + throw std::runtime_error("Unsupported Permutation"); + break; + } + } + } + else if (!src->has_padding() && !dst->has_padding()) + { + auto src_size = src->total_size(); + assert(src_size <= dst->total_size()); + memcpy(dst_buffer, src->buffer(), src_size); + } + else + { + auto loop_shape = src->getShape(); + const auto copy_axis = loop_shape.rank() - 1; + const auto copy_len = loop_shape.dim(copy_axis) * sizeof(T); + loop_shape.dim(copy_axis) = 1; - for (auto i = 0; i < dim_0; ++i) - { - for (auto j = 0; j < dim_1; ++j) - { - ir::Coordinates coords{i, j, 0}; - memcpy(dst_buffer + dst_tensor.calcOffset(coords), - src_buffer + src_tensor.calcOffset(coords), copy_len * sizeof(T)); - } - } - break; - } - case 4: - { - switch (permute_type) - { - case PermuteType::NHWC_TO_NCHW: - { - ir::FeatureShape shape; - shape.N = dst_tensor.dimension(0); - shape.C = dst_tensor.dimension(1); - shape.H = dst_tensor.dimension(2); - shape.W = dst_tensor.dimension(3); - const feature::nhwc::Reader<T> from(&src_tensor); - feature::nchw::View<T> into(&dst_tensor); - feature::iterate(shape) - << [&](uint32_t batch, uint32_t ch, uint32_t row, uint32_t col) { - const auto value = from.at(batch, row, col, 
ch); - into.at(batch, ch, row, col) = value; - }; - break; - } - case PermuteType::NCHW_TO_NHWC: - { - ir::FeatureShape shape; - shape.N = src_tensor.dimension(0); - shape.C = src_tensor.dimension(1); - shape.H = src_tensor.dimension(2); - shape.W = src_tensor.dimension(3); - const feature::nchw::Reader<T> from(&src_tensor); - feature::nhwc::View<T> into(&dst_tensor); - feature::iterate(shape) - << [&](uint32_t batch, uint32_t ch, uint32_t row, uint32_t col) { - const auto value = from.at(batch, ch, row, col); - into.at(batch, row, col, ch) = value; - }; - break; - } - case PermuteType::COPY: - { - const int32_t dim_0 = dst_tensor.dimension(0); - const int32_t dim_1 = dst_tensor.dimension(1); - const int32_t dim_2 = dst_tensor.dimension(2); - const int32_t copy_len = dst_tensor.dimension(3); + if (src->is_dynamic()) + { + assert(dst->is_dynamic()); + CopyDynamic(src, dst, dst_buffer, loop_shape, copy_len); + } + else + { + // TODO Uncomment the assertion below + // assert(!dst->is_dynamic() || dst is output of graph); + if (src_offsets.size() == 0) + { + assert(dst_offsets.size() == 0); - for (auto i = 0; i < dim_0; ++i) - { - for (auto j = 0; j < dim_1; ++j) - { - for (auto k = 0; k < dim_2; ++k) - { - ir::Coordinates coords{i, j, k, 0}; - memcpy(dst_buffer + dst_tensor.calcOffset(coords), - src_buffer + src_tensor.calcOffset(coords), copy_len * sizeof(T)); - } - } - } - break; - } - default: - { - throw std::runtime_error("Unsupported Permutation"); - break; - } - } - break; - } - default: - throw std::runtime_error("Unsupported rank in permutation"); - break; + auto loop_shape = src->getShape(); + const auto copy_axis = loop_shape.rank() - 1; + loop_shape.dim(copy_axis) = 1; + UpdateOffsets(src, dst, loop_shape, src_offsets, dst_offsets); } - }); - }; - src->access(fn); + CopyStatic(src->buffer(), dst_buffer, src_offsets, dst_offsets, copy_len); + } + } } +protected: // NOTE The typeid expression is lvalue expression which refers to an object with static storage 
// duration, of the polymorphic type const std::type_info or of some type derived from it. // So std::type_info is non-copyable - const std::type_info &underlying_type(ir::DataType type) const - { - switch (type) - { - case ir::DataType::FLOAT32: - return typeid(float); - case ir::DataType::INT32: - return typeid(int32_t); - case ir::DataType::UINT32: - return typeid(uint32_t); - case ir::DataType::INT64: - return typeid(int64_t); - case ir::DataType::BOOL8: - case ir::DataType::QUANT_UINT8_ASYMM: - case ir::DataType::UINT8: - return typeid(uint8_t); - case ir::DataType::QUANT_INT8_SYMM: - return typeid(int8_t); - default: - throw std::runtime_error("IPermuteFunction: Not supported data type"); - } - } + const std::type_info &underlying_type(ir::DataType type) const; protected: - std::vector<std::shared_ptr<backend::ITensor>> _src_tensors; - std::vector<std::shared_ptr<backend::ITensor>> _dst_tensors; - // TODO Remove this member if it is possible - std::vector<size_t> _ranks; + std::vector<backend::ITensor *> _src_tensors; + std::vector<backend::ITensor *> _dst_tensors; + std::vector<std::vector<size_t>> _src_tensors_offsets; + std::vector<std::vector<size_t>> _dst_tensors_offsets; + std::unordered_map<const backend::ITensor *, std::vector<uint8_t>> _buffers_map; +}; + +// Simple PermuteLayer +class PermuteLayer : public onert::exec::IPermuteFunction +{ +public: + PermuteLayer(const std::vector<onert::backend::ITensor *> &inputs, + const std::vector<onert::backend::ITensor *> &outputs) + { + assert(inputs.size() == outputs.size()); + _src_tensors = inputs; + _dst_tensors = outputs; + } + virtual ~PermuteLayer() {} + void optimize() override {} }; } // namespace exec diff --git a/runtime/onert/core/src/exec/IPermuteFunction.test.cc b/runtime/onert/core/src/exec/IPermuteFunction.test.cc new file mode 100644 index 000000000..fb2dd3b95 --- /dev/null +++ b/runtime/onert/core/src/exec/IPermuteFunction.test.cc @@ -0,0 +1,920 @@ +/* + * Copyright (c) 2023 Samsung 
Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "IPermuteFunction.h" + +#include <ir/Layout.h> +#include <ir/Shape.h> +#include <ir/TypeInfo.h> + +#include <cmath> +#include <gtest/gtest.h> + +namespace +{ +using namespace onert; +using namespace ir; +using namespace backend; +using namespace exec; + +class MockUpTensor : public ITensor +{ +public: + MockUpTensor(const Shape &shape, const TypeInfo &type_info, Layout layout, size_t pad) + : _shape(shape), _type_info(type_info), _data(nullptr), _layout(layout) + { + _strides.resize(shape.rank()); + + std::vector<size_t> pads(shape.rank(), 0); + pads[shape.rank() - 1] = pad; + size_t stride = 1; + for (int32_t i = _shape.rank() - 1; i >= 0; --i) + { + _strides.at(i) = stride; + stride = stride * (_shape.dim(i) + pads.at(i)); + } + } + virtual ~MockUpTensor() {} + + void setBuffer(uint8_t *data) { _data = data; } + + size_t total_size() const override + { + size_t total_size = _strides[0] * _shape.dim(0); + total_size *= sizeOfDataType(data_type()); + return total_size; + } + + size_t calcOffset(const ir::Coordinates &coords) const override + { + size_t offset = 0; + for (size_t i = 0; i < _shape.rank(); ++i) + { + offset += (_strides[i] * coords[i]); + } + offset *= sizeOfDataType(data_type()); + return offset; + } + + uint8_t *buffer() const override { return _data; } + + ir::Layout layout() const override { return _layout; } + ir::DataType 
data_type() const override { return _type_info.type(); } + float data_scale() const override { return _type_info.scale(); } + int32_t data_zero_point() const override { return _type_info.zero_point(); } + const std::vector<float> &data_scales() const override { return _type_info.scales(); } + const std::vector<int32_t> &data_zero_points() const override { return _type_info.zero_points(); } + bool has_padding() const override + { + return total_size() / sizeOfDataType(data_type()) != _shape.num_elements(); + } + void access(const std::function<void(ITensor &tensor)> &fn) final { fn(*this); } + + bool is_dynamic() const override { return false; } + Shape getShape() const override { return _shape; } + +private: + Shape _shape; + TypeInfo _type_info; + Layout _layout; + uint8_t *_data; + std::vector<size_t> _strides; +}; + +class MockUpLayer : public IPermuteFunction +{ +public: + MockUpLayer(const std::vector<ITensor *> &inputs, const std::vector<ITensor *> &outputs) + { + assert(inputs.size() == outputs.size()); + _src_tensors = inputs; + _dst_tensors = outputs; + } + virtual ~MockUpLayer() {} + void optimize() override {} +}; + +TEST(IPermuteFunction, float_to_float) +{ + // rank 1 + { + const size_t input_pads[4] = {0, 1, 0, 2}; + const size_t output_pads[4] = {0, 0, 2, 1}; + const std::vector<Shape> shapes{{1}, {4}, {5}, {2}}; + float expected_buffer[] = {1, 0, -1, -2, 3}; + const auto type_info = TypeInfo(DataType::FLOAT32); + + std::vector<std::unique_ptr<MockUpTensor>> inputs(4); + std::vector<std::unique_ptr<MockUpTensor>> outputs(4); + + std::vector<std::unique_ptr<uint8_t[]>> output_buffers(4); + for (size_t i = 0; i < 4; ++i) + { + inputs[i] = std::make_unique<MockUpTensor>(shapes[i], type_info, Layout::NHWC, input_pads[i]); + inputs[i]->setBuffer(reinterpret_cast<uint8_t *>(expected_buffer)); + + outputs[i] = + std::make_unique<MockUpTensor>(shapes[i], type_info, Layout::NHWC, output_pads[i]); + output_buffers[i] = 
std::make_unique<uint8_t[]>(outputs[i]->total_size()); + outputs[i]->setBuffer(output_buffers[i].get()); + } + + auto mockup_layer = std::make_unique<MockUpLayer>( + std::vector<ITensor *>{inputs[0].get(), inputs[1].get(), inputs[2].get(), inputs[3].get()}, + std::vector<ITensor *>{outputs[0].get(), outputs[1].get(), outputs[2].get(), + outputs[3].get()}); + mockup_layer->run(); + + for (size_t i = 0; i < 4; ++i) + { + for (int32_t j = 0; j < shapes[i].dim(0); ++j) + { + Coordinates coords{j}; + float result = + *reinterpret_cast<float *>(outputs[i]->buffer() + outputs[i]->calcOffset(coords)); + float expected = + *reinterpret_cast<float *>(inputs[i]->buffer() + inputs[i]->calcOffset(coords)); + EXPECT_EQ(result, expected); + } + } + } + + // rank 2 + { + const size_t input_pads[4] = {0, 1, 0, 2}; + const size_t output_pads[4] = {0, 0, 2, 1}; + const std::vector<Shape> shapes{{1, 4}, {2, 2}, {1, 5}, {2, 3}}; + float expected_buffer[] = {1, 0, -1, -2, 3, -4, 5, -6, 7, -8}; + const auto type_info = TypeInfo(DataType::FLOAT32); + + std::vector<std::unique_ptr<MockUpTensor>> inputs(4); + std::vector<std::unique_ptr<MockUpTensor>> outputs(4); + std::vector<std::unique_ptr<uint8_t[]>> output_buffers(4); + for (size_t i = 0; i < 4; ++i) + { + inputs[i] = std::make_unique<MockUpTensor>(shapes[i], type_info, Layout::NHWC, input_pads[i]); + inputs[i]->setBuffer(reinterpret_cast<uint8_t *>(expected_buffer)); + + outputs[i] = + std::make_unique<MockUpTensor>(shapes[i], type_info, Layout::NHWC, output_pads[i]); + output_buffers[i] = std::make_unique<uint8_t[]>(outputs[i]->total_size()); + outputs[i]->setBuffer(output_buffers[i].get()); + } + + auto mockup_layer = std::make_unique<MockUpLayer>( + std::vector<ITensor *>{inputs[0].get(), inputs[1].get(), inputs[2].get(), inputs[3].get()}, + std::vector<ITensor *>{outputs[0].get(), outputs[1].get(), outputs[2].get(), + outputs[3].get()}); + mockup_layer->run(); + + for (size_t i = 0; i < 4; ++i) + { + for (int32_t j = 0; j < 
shapes[i].dim(0); ++j) + { + for (int32_t k = 0; k < shapes[i].dim(1); ++k) + { + Coordinates coords{j, k}; + float result = + *reinterpret_cast<float *>(outputs[i]->buffer() + outputs[i]->calcOffset(coords)); + float expected = + *reinterpret_cast<float *>(inputs[i]->buffer() + inputs[i]->calcOffset(coords)); + EXPECT_EQ(result, expected); + } + } + } + } + + // rank 3 + { + const size_t input_pads[4] = {0, 5, 0, 2}; + const size_t output_pads[4] = {0, 3, 2, 1}; + const std::vector<Shape> shapes{{1, 4, 1}, {1, 2, 1}, {2, 1, 5}, {1, 2, 3}}; + float expected_buffer[] = {1, 0, -1, -2, 3, -4, 5, -6, 7, -8, 9, -10}; + const auto type_info = TypeInfo(DataType::FLOAT32); + + std::vector<std::unique_ptr<MockUpTensor>> inputs(4); + std::vector<std::unique_ptr<MockUpTensor>> outputs(4); + std::vector<std::unique_ptr<uint8_t[]>> output_buffers(4); + for (size_t i = 0; i < 4; ++i) + { + inputs[i] = std::make_unique<MockUpTensor>(shapes[i], type_info, Layout::NHWC, input_pads[i]); + inputs[i]->setBuffer(reinterpret_cast<uint8_t *>(expected_buffer)); + + outputs[i] = + std::make_unique<MockUpTensor>(shapes[i], type_info, Layout::NHWC, output_pads[i]); + output_buffers[i] = std::make_unique<uint8_t[]>(outputs[i]->total_size()); + outputs[i]->setBuffer(output_buffers[i].get()); + } + + auto mockup_layer = std::make_unique<MockUpLayer>( + std::vector<ITensor *>{inputs[0].get(), inputs[1].get(), inputs[2].get(), inputs[3].get()}, + std::vector<ITensor *>{outputs[0].get(), outputs[1].get(), outputs[2].get(), + outputs[3].get()}); + mockup_layer->run(); + + for (size_t i = 0; i < 4; ++i) + { + for (int32_t j = 0; j < shapes[i].dim(0); ++j) + { + for (int32_t k = 0; k < shapes[i].dim(1); ++k) + { + for (int32_t l = 0; l < shapes[i].dim(2); ++l) + { + Coordinates coords{j, k, l}; + float result = + *reinterpret_cast<float *>(outputs[i]->buffer() + outputs[i]->calcOffset(coords)); + float expected = + *reinterpret_cast<float *>(inputs[i]->buffer() + inputs[i]->calcOffset(coords)); + 
EXPECT_EQ(result, expected); + } + } + } + } + } + + // rank 4 + { + const size_t input_pads[4] = {0, 0, 1, 2}; + const size_t output_pads[4] = {0, 3, 2, 1}; + const std::vector<Shape> shapes{{1, 1, 4, 1}, {2, 1, 2, 3}, {1, 2, 1, 2}, {1, 1, 2, 3}}; + float expected_buffer[] = {1, 0, -1, -2, 3, -4, 5, -6, 7, -8, 9, -10}; + const auto type_info = TypeInfo(DataType::FLOAT32); + + std::vector<std::unique_ptr<MockUpTensor>> inputs(4); + std::vector<std::unique_ptr<MockUpTensor>> outputs(4); + std::vector<std::unique_ptr<uint8_t[]>> output_buffers(4); + for (size_t i = 0; i < 4; ++i) + { + inputs[i] = std::make_unique<MockUpTensor>(shapes[i], type_info, Layout::NHWC, input_pads[i]); + inputs[i]->setBuffer(reinterpret_cast<uint8_t *>(expected_buffer)); + + outputs[i] = + std::make_unique<MockUpTensor>(shapes[i], type_info, Layout::NHWC, output_pads[i]); + output_buffers[i] = std::make_unique<uint8_t[]>(outputs[i]->total_size()); + outputs[i]->setBuffer(output_buffers[i].get()); + } + + auto mockup_layer = std::make_unique<MockUpLayer>( + std::vector<ITensor *>{inputs[0].get(), inputs[1].get(), inputs[2].get(), inputs[3].get()}, + std::vector<ITensor *>{outputs[0].get(), outputs[1].get(), outputs[2].get(), + outputs[3].get()}); + mockup_layer->run(); + + for (size_t i = 0; i < 4; ++i) + { + for (int32_t j = 0; j < shapes[i].dim(0); ++j) + { + for (int32_t k = 0; k < shapes[i].dim(1); ++k) + { + for (int32_t l = 0; l < shapes[i].dim(2); ++l) + { + for (int32_t m = 0; m < shapes[i].dim(3); ++m) + { + Coordinates coords{j, k, l, m}; + float result = + *reinterpret_cast<float *>(outputs[i]->buffer() + outputs[i]->calcOffset(coords)); + float expected = + *reinterpret_cast<float *>(inputs[i]->buffer() + inputs[i]->calcOffset(coords)); + EXPECT_EQ(result, expected); + } + } + } + } + } + } + + // rank4 layout + { + const size_t input_pads[4] = {0, 0, 1, 2}; + const size_t output_pads[4] = {0, 3, 2, 1}; + const std::vector<Shape> shapes{{1, 1, 4, 1}, {2, 1, 2, 3}, {1, 2, 1, 2}, 
{1, 1, 2, 3}}; + float expected_buffer[] = {1, 0, -1, -2, 3, -4, 5, -6, 7, + -8, 9, -10, 11, -12, 13, -14, 15, -16}; + const auto type_info = TypeInfo(DataType::FLOAT32); + + std::vector<std::unique_ptr<MockUpTensor>> inputs(4); + std::vector<std::unique_ptr<MockUpTensor>> outputs(4); + std::vector<std::unique_ptr<uint8_t[]>> output_buffers(4); + for (size_t i = 0; i < 4; ++i) + { + Layout layout = Layout::NHWC; + Shape shape = shapes[i]; + if (i % 2 == 1) + { + layout = Layout::NCHW; + shape = Shape{shapes[i].dim(0), shapes[i].dim(3), shapes[i].dim(1), shapes[i].dim(2)}; + } + inputs[i] = std::make_unique<MockUpTensor>(shape, type_info, layout, input_pads[i]); + inputs[i]->setBuffer(reinterpret_cast<uint8_t *>(expected_buffer)); + + if (layout == Layout::NHWC) + { + layout = Layout::NCHW; + shape = Shape{shapes[i].dim(0), shapes[i].dim(3), shapes[i].dim(1), shapes[i].dim(2)}; + } + else + { + layout = Layout::NHWC; + shape = shapes[i]; + } + outputs[i] = std::make_unique<MockUpTensor>(shape, type_info, layout, output_pads[i]); + output_buffers[i] = std::make_unique<uint8_t[]>(outputs[i]->total_size()); + outputs[i]->setBuffer(output_buffers[i].get()); + } + + auto mockup_layer = std::make_unique<MockUpLayer>( + std::vector<ITensor *>{inputs[0].get(), inputs[1].get(), inputs[2].get(), inputs[3].get()}, + std::vector<ITensor *>{outputs[0].get(), outputs[1].get(), outputs[2].get(), + outputs[3].get()}); + mockup_layer->run(); + + for (size_t i = 0; i < 4; ++i) + { + for (int32_t j = 0; j < shapes[i].dim(0); ++j) + { + for (int32_t k = 0; k < shapes[i].dim(1); ++k) + { + for (int32_t l = 0; l < shapes[i].dim(2); ++l) + { + for (int32_t m = 0; m < shapes[i].dim(3); ++m) + { + Coordinates input_coords; + Coordinates output_coords; + if (inputs[i]->layout() == Layout::NHWC) + { + input_coords = Coordinates{j, k, l, m}; + } + else + { + input_coords = Coordinates{j, m, k, l}; + } + if (outputs[i]->layout() == Layout::NHWC) + { + output_coords = Coordinates{j, k, l, m}; + 
} + else + { + output_coords = Coordinates{j, m, k, l}; + } + float result = *reinterpret_cast<float *>(outputs[i]->buffer() + + outputs[i]->calcOffset(output_coords)); + float expected = *reinterpret_cast<float *>(inputs[i]->buffer() + + inputs[i]->calcOffset(input_coords)); + EXPECT_EQ(result, expected); + } + } + } + } + } + } +} + +TEST(IPermuteFunction, float_to_qasymm8) +{ + const size_t input_pads[4] = {0, 0, 1, 2}; + const size_t output_pads[4] = {0, 3, 2, 1}; + const std::vector<Shape> shapes{{1, 1, 4, 1}, {2, 1, 2, 3}, {1, 2, 1, 2}, {1, 1, 2, 3}}; + float expected_buffer[] = {10, 0, -10, -20, 30, -40, 50, -60, 70, -80, 90, -100}; + float scale = 10; + int32_t zero_point = 128; + + std::vector<std::unique_ptr<MockUpTensor>> inputs(4); + std::vector<std::unique_ptr<MockUpTensor>> outputs(4); + std::vector<std::unique_ptr<uint8_t[]>> output_buffers(4); + for (size_t i = 0; i < 4; ++i) + { + inputs[i] = std::make_unique<MockUpTensor>(shapes[i], TypeInfo(DataType::FLOAT32), Layout::NHWC, + input_pads[i]); + inputs[i]->setBuffer(reinterpret_cast<uint8_t *>(expected_buffer)); + + TypeInfo type_info{DataType::QUANT_UINT8_ASYMM, scale, zero_point}; + outputs[i] = std::make_unique<MockUpTensor>(shapes[i], type_info, Layout::NHWC, output_pads[i]); + output_buffers[i] = std::make_unique<uint8_t[]>(outputs[i]->total_size()); + outputs[i]->setBuffer(output_buffers[i].get()); + } + + auto mockup_layer = std::make_unique<MockUpLayer>( + std::vector<ITensor *>{inputs[0].get(), inputs[1].get(), inputs[2].get(), inputs[3].get()}, + std::vector<ITensor *>{outputs[0].get(), outputs[1].get(), outputs[2].get(), outputs[3].get()}); + mockup_layer->run(); + + for (size_t i = 0; i < 4; ++i) + { + for (int32_t j = 0; j < shapes[i].dim(0); ++j) + { + for (int32_t k = 0; k < shapes[i].dim(1); ++k) + { + for (int32_t l = 0; l < shapes[i].dim(2); ++l) + { + for (int32_t m = 0; m < shapes[i].dim(3); ++m) + { + Coordinates coords{j, k, l, m}; + uint8_t qasymm8 = + 
*reinterpret_cast<uint8_t *>(outputs[i]->buffer() + outputs[i]->calcOffset(coords)); + float result = (qasymm8 - zero_point) * scale; + float expected = + *reinterpret_cast<float *>(inputs[i]->buffer() + inputs[i]->calcOffset(coords)); + EXPECT_EQ(result, expected); + } + } + } + } + } +} + +TEST(IPermuteFunction, float_to_qsymm8) +{ + const size_t input_pads[4] = {0, 0, 1, 2}; + const size_t output_pads[4] = {0, 3, 2, 1}; + const std::vector<Shape> shapes{{1, 1, 4, 1}, {2, 1, 2, 3}, {1, 2, 1, 2}, {1, 1, 2, 3}}; + float expected_buffer[] = {10, 0, -10, -20, 30, -40, 50, -60, 70, -80, 90, -100}; + float scale = 10; + int32_t zero_point = 0; + + std::vector<std::unique_ptr<MockUpTensor>> inputs(4); + std::vector<std::unique_ptr<MockUpTensor>> outputs(4); + std::vector<std::unique_ptr<uint8_t[]>> output_buffers(4); + for (size_t i = 0; i < 4; ++i) + { + inputs[i] = std::make_unique<MockUpTensor>(shapes[i], TypeInfo(DataType::FLOAT32), Layout::NHWC, + input_pads[i]); + inputs[i]->setBuffer(reinterpret_cast<uint8_t *>(expected_buffer)); + + TypeInfo type_info{DataType::QUANT_INT8_SYMM, scale, zero_point}; + outputs[i] = std::make_unique<MockUpTensor>(shapes[i], type_info, Layout::NHWC, output_pads[i]); + output_buffers[i] = std::make_unique<uint8_t[]>(outputs[i]->total_size()); + outputs[i]->setBuffer(output_buffers[i].get()); + } + + auto mockup_layer = std::make_unique<MockUpLayer>( + std::vector<ITensor *>{inputs[0].get(), inputs[1].get(), inputs[2].get(), inputs[3].get()}, + std::vector<ITensor *>{outputs[0].get(), outputs[1].get(), outputs[2].get(), outputs[3].get()}); + mockup_layer->run(); + + for (size_t i = 0; i < 4; ++i) + { + for (int32_t j = 0; j < shapes[i].dim(0); ++j) + { + for (int32_t k = 0; k < shapes[i].dim(1); ++k) + { + for (int32_t l = 0; l < shapes[i].dim(2); ++l) + { + for (int32_t m = 0; m < shapes[i].dim(3); ++m) + { + Coordinates coords{j, k, l, m}; + int8_t qsymm8 = + *reinterpret_cast<int8_t *>(outputs[i]->buffer() + 
outputs[i]->calcOffset(coords)); + float result = (qsymm8 - zero_point) * scale; + float expected = + *reinterpret_cast<float *>(inputs[i]->buffer() + inputs[i]->calcOffset(coords)); + EXPECT_EQ(result, expected); + } + } + } + } + } +} + +TEST(IPermuteFunction, float_to_qsymm16) +{ + const size_t input_pads[4] = {0, 0, 1, 2}; + const size_t output_pads[4] = {0, 3, 2, 1}; + const std::vector<Shape> shapes{{1, 1, 4, 1}, {2, 1, 2, 3}, {1, 2, 1, 2}, {1, 1, 2, 3}}; + float expected_buffer[] = {10, 0, -10, -20, 30, -40, 50, -60, 70, -80, 90, -100}; + float scale = 10; + int32_t zero_point = 0; + + std::vector<std::unique_ptr<MockUpTensor>> inputs(4); + std::vector<std::unique_ptr<MockUpTensor>> outputs(4); + std::vector<std::unique_ptr<uint8_t[]>> output_buffers(4); + for (size_t i = 0; i < 4; ++i) + { + inputs[i] = std::make_unique<MockUpTensor>(shapes[i], TypeInfo(DataType::FLOAT32), Layout::NHWC, + input_pads[i]); + inputs[i]->setBuffer(reinterpret_cast<uint8_t *>(expected_buffer)); + + TypeInfo type_info{DataType::QUANT_INT16_SYMM, scale, zero_point}; + outputs[i] = std::make_unique<MockUpTensor>(shapes[i], type_info, Layout::NHWC, output_pads[i]); + output_buffers[i] = std::make_unique<uint8_t[]>(outputs[i]->total_size()); + outputs[i]->setBuffer(output_buffers[i].get()); + } + + auto mockup_layer = std::make_unique<MockUpLayer>( + std::vector<ITensor *>{inputs[0].get(), inputs[1].get(), inputs[2].get(), inputs[3].get()}, + std::vector<ITensor *>{outputs[0].get(), outputs[1].get(), outputs[2].get(), outputs[3].get()}); + mockup_layer->run(); + + for (size_t i = 0; i < 4; ++i) + { + for (int32_t j = 0; j < shapes[i].dim(0); ++j) + { + for (int32_t k = 0; k < shapes[i].dim(1); ++k) + { + for (int32_t l = 0; l < shapes[i].dim(2); ++l) + { + for (int32_t m = 0; m < shapes[i].dim(3); ++m) + { + Coordinates coords{j, k, l, m}; + int16_t qsymm16 = + *reinterpret_cast<int16_t *>(outputs[i]->buffer() + outputs[i]->calcOffset(coords)); + float result = (qsymm16 - zero_point) 
* scale; + float expected = + *reinterpret_cast<float *>(inputs[i]->buffer() + inputs[i]->calcOffset(coords)); + EXPECT_EQ(result, expected); + } + } + } + } + } +} + +TEST(IPermuteFunction, qasymm8_to_float) +{ + const size_t input_pads[4] = {0, 0, 1, 2}; + const size_t output_pads[4] = {0, 3, 2, 1}; + const std::vector<Shape> shapes{{1, 1, 4, 1}, {2, 1, 2, 3}, {1, 2, 1, 2}, {1, 1, 2, 3}}; + float expected_buffer[] = {10, 0, -10, -20, 30, -40, 50, -60, 70, -80, 90, -100}; + float scale = 10; + int32_t zero_point = 128; + uint8_t input_buffer[12]; + + int32_t min_val = std::numeric_limits<uint8_t>::min(); + int32_t max_val = std::numeric_limits<uint8_t>::max(); + for (int32_t i = 0; i < sizeof(expected_buffer) / sizeof(float); ++i) + { + int32_t unclamped = static_cast<int32_t>(std::round(expected_buffer[i] / scale)) + zero_point; + input_buffer[i] = std::min(std::max(unclamped, min_val), max_val); + } + + std::vector<std::unique_ptr<MockUpTensor>> inputs(4); + std::vector<std::unique_ptr<MockUpTensor>> outputs(4); + std::vector<std::unique_ptr<uint8_t[]>> output_buffers(4); + for (size_t i = 0; i < 4; ++i) + { + TypeInfo type_info{DataType::QUANT_UINT8_ASYMM, scale, zero_point}; + inputs[i] = std::make_unique<MockUpTensor>(shapes[i], type_info, Layout::NHWC, input_pads[i]); + inputs[i]->setBuffer(reinterpret_cast<uint8_t *>(input_buffer)); + + outputs[i] = std::make_unique<MockUpTensor>(shapes[i], TypeInfo(DataType::FLOAT32), + Layout::NHWC, output_pads[i]); + output_buffers[i] = std::make_unique<uint8_t[]>(outputs[i]->total_size()); + outputs[i]->setBuffer(output_buffers[i].get()); + } + + auto mockup_layer = std::make_unique<MockUpLayer>( + std::vector<ITensor *>{inputs[0].get(), inputs[1].get(), inputs[2].get(), inputs[3].get()}, + std::vector<ITensor *>{outputs[0].get(), outputs[1].get(), outputs[2].get(), outputs[3].get()}); + mockup_layer->run(); + + for (size_t i = 0; i < 4; ++i) + { + for (int32_t j = 0; j < shapes[i].dim(0); ++j) + { + for (int32_t k = 0; 
k < shapes[i].dim(1); ++k) + { + for (int32_t l = 0; l < shapes[i].dim(2); ++l) + { + for (int32_t m = 0; m < shapes[i].dim(3); ++m) + { + Coordinates coords{j, k, l, m}; + float result = + *reinterpret_cast<float *>(outputs[i]->buffer() + outputs[i]->calcOffset(coords)); + uint8_t qasymm8 = + *reinterpret_cast<uint8_t *>(inputs[i]->buffer() + inputs[i]->calcOffset(coords)); + float expected = (qasymm8 - zero_point) * scale; + EXPECT_EQ(result, expected); + } + } + } + } + } +} + +TEST(IPermuteFunction, qsymm8_to_float) +{ + const size_t input_pads[4] = {0, 0, 1, 2}; + const size_t output_pads[4] = {0, 3, 2, 1}; + const std::vector<Shape> shapes{{1, 1, 4, 1}, {2, 1, 2, 3}, {1, 2, 1, 2}, {1, 1, 2, 3}}; + float expected_buffer[] = {10, 0, -10, -20, 30, -40, 50, -60, 70, -80, 90, -100}; + float scale = 10; + int32_t zero_point = 0; + uint8_t input_buffer[12]; + + int32_t min_val = std::numeric_limits<int8_t>::min(); + int32_t max_val = std::numeric_limits<int8_t>::max(); + for (int32_t i = 0; i < sizeof(expected_buffer) / sizeof(float); ++i) + { + int32_t unclamped = static_cast<int32_t>(std::round(expected_buffer[i] / scale)) + zero_point; + input_buffer[i] = std::min(std::max(unclamped, min_val), max_val); + } + + std::vector<std::unique_ptr<MockUpTensor>> inputs(4); + std::vector<std::unique_ptr<MockUpTensor>> outputs(4); + std::vector<std::unique_ptr<uint8_t[]>> output_buffers(4); + for (size_t i = 0; i < 4; ++i) + { + TypeInfo type_info{DataType::QUANT_INT8_SYMM, scale, zero_point}; + inputs[i] = std::make_unique<MockUpTensor>(shapes[i], type_info, Layout::NHWC, input_pads[i]); + inputs[i]->setBuffer(reinterpret_cast<uint8_t *>(input_buffer)); + + outputs[i] = std::make_unique<MockUpTensor>(shapes[i], TypeInfo(DataType::FLOAT32), + Layout::NHWC, output_pads[i]); + output_buffers[i] = std::make_unique<uint8_t[]>(outputs[i]->total_size()); + outputs[i]->setBuffer(output_buffers[i].get()); + } + + auto mockup_layer = std::make_unique<MockUpLayer>( + 
std::vector<ITensor *>{inputs[0].get(), inputs[1].get(), inputs[2].get(), inputs[3].get()}, + std::vector<ITensor *>{outputs[0].get(), outputs[1].get(), outputs[2].get(), outputs[3].get()}); + mockup_layer->run(); + + for (size_t i = 0; i < 4; ++i) + { + for (int32_t j = 0; j < shapes[i].dim(0); ++j) + { + for (int32_t k = 0; k < shapes[i].dim(1); ++k) + { + for (int32_t l = 0; l < shapes[i].dim(2); ++l) + { + for (int32_t m = 0; m < shapes[i].dim(3); ++m) + { + Coordinates coords{j, k, l, m}; + float result = + *reinterpret_cast<float *>(outputs[i]->buffer() + outputs[i]->calcOffset(coords)); + int8_t qasymm8 = + *reinterpret_cast<int8_t *>(inputs[i]->buffer() + inputs[i]->calcOffset(coords)); + float expected = (qasymm8 - zero_point) * scale; + EXPECT_EQ(result, expected); + } + } + } + } + } +} + +TEST(IPermuteFunction, qsymm16_to_float) +{ + const size_t input_pads[4] = {0, 0, 1, 2}; + const size_t output_pads[4] = {0, 3, 2, 1}; + const std::vector<Shape> shapes{{1, 1, 4, 1}, {2, 1, 2, 3}, {1, 2, 1, 2}, {1, 1, 2, 3}}; + float expected_buffer[] = {10, 0, -10, -20, 30, -40, 50, -60, 70, -80, 90, -100}; + float scale = 10; + int32_t zero_point = 0; + uint8_t input_buffer[12]; + + int32_t min_val = std::numeric_limits<int16_t>::min(); + int32_t max_val = std::numeric_limits<int16_t>::max(); + for (int32_t i = 0; i < sizeof(expected_buffer) / sizeof(float); ++i) + { + int32_t unclamped = static_cast<int32_t>(std::round(expected_buffer[i] / scale)) + zero_point; + input_buffer[i] = std::min(std::max(unclamped, min_val), max_val); + } + + std::vector<std::unique_ptr<MockUpTensor>> inputs(4); + std::vector<std::unique_ptr<MockUpTensor>> outputs(4); + std::vector<std::unique_ptr<uint8_t[]>> output_buffers(4); + for (size_t i = 0; i < 4; ++i) + { + TypeInfo type_info{DataType::QUANT_INT16_SYMM, scale, zero_point}; + inputs[i] = std::make_unique<MockUpTensor>(shapes[i], type_info, Layout::NHWC, input_pads[i]); + inputs[i]->setBuffer(reinterpret_cast<uint8_t 
*>(input_buffer)); + + outputs[i] = std::make_unique<MockUpTensor>(shapes[i], TypeInfo(DataType::FLOAT32), + Layout::NHWC, output_pads[i]); + output_buffers[i] = std::make_unique<uint8_t[]>(outputs[i]->total_size()); + outputs[i]->setBuffer(output_buffers[i].get()); + } + + auto mockup_layer = std::make_unique<MockUpLayer>( + std::vector<ITensor *>{inputs[0].get(), inputs[1].get(), inputs[2].get(), inputs[3].get()}, + std::vector<ITensor *>{outputs[0].get(), outputs[1].get(), outputs[2].get(), outputs[3].get()}); + mockup_layer->run(); + + for (size_t i = 0; i < 4; ++i) + { + for (int32_t j = 0; j < shapes[i].dim(0); ++j) + { + for (int32_t k = 0; k < shapes[i].dim(1); ++k) + { + for (int32_t l = 0; l < shapes[i].dim(2); ++l) + { + for (int32_t m = 0; m < shapes[i].dim(3); ++m) + { + Coordinates coords{j, k, l, m}; + float result = + *reinterpret_cast<float *>(outputs[i]->buffer() + outputs[i]->calcOffset(coords)); + int16_t qasymm8 = + *reinterpret_cast<int16_t *>(inputs[i]->buffer() + inputs[i]->calcOffset(coords)); + float expected = (qasymm8 - zero_point) * scale; + EXPECT_EQ(result, expected); + } + } + } + } + } +} + +TEST(IPermuteFunction, float_qasymm8_layout) +{ + // float -> quasymm8 + { + const size_t input_pads[4] = {0, 0, 1, 2}; + const size_t output_pads[4] = {0, 3, 2, 1}; + const std::vector<Shape> shapes{{1, 1, 4, 1}, {2, 1, 2, 3}, {1, 2, 1, 2}, {1, 1, 2, 3}}; + float expected_buffer[] = {10, 0, -10, -20, 30, -40, 50, -60, 70, + -80, 90, -100, 110, -120, 130, -140, 150, -160}; + float scale = 10; + int32_t zero_point = 128; + + std::vector<std::unique_ptr<MockUpTensor>> inputs(4); + std::vector<std::unique_ptr<MockUpTensor>> outputs(4); + std::vector<std::unique_ptr<uint8_t[]>> output_buffers(4); + for (size_t i = 0; i < 4; ++i) + { + Layout layout = Layout::NHWC; + Shape shape = shapes[i]; + if (i % 2 == 1) + { + layout = Layout::NCHW; + shape = Shape{shapes[i].dim(0), shapes[i].dim(3), shapes[i].dim(1), shapes[i].dim(2)}; + } + inputs[i] = + 
std::make_unique<MockUpTensor>(shape, TypeInfo(DataType::FLOAT32), layout, input_pads[i]); + inputs[i]->setBuffer(reinterpret_cast<uint8_t *>(expected_buffer)); + + if (layout == Layout::NHWC) + { + layout = Layout::NCHW; + shape = Shape{shapes[i].dim(0), shapes[i].dim(3), shapes[i].dim(1), shapes[i].dim(2)}; + } + else + { + layout = Layout::NHWC; + shape = shapes[i]; + } + TypeInfo type_info{DataType::QUANT_UINT8_ASYMM, scale, zero_point}; + outputs[i] = std::make_unique<MockUpTensor>(shape, type_info, layout, output_pads[i]); + output_buffers[i] = std::make_unique<uint8_t[]>(outputs[i]->total_size()); + outputs[i]->setBuffer(output_buffers[i].get()); + } + + auto mockup_layer = std::make_unique<MockUpLayer>( + std::vector<ITensor *>{inputs[0].get(), inputs[1].get(), inputs[2].get(), inputs[3].get()}, + std::vector<ITensor *>{outputs[0].get(), outputs[1].get(), outputs[2].get(), + outputs[3].get()}); + mockup_layer->run(); + + for (size_t i = 0; i < 4; ++i) + { + for (int32_t j = 0; j < shapes[i].dim(0); ++j) + { + for (int32_t k = 0; k < shapes[i].dim(1); ++k) + { + for (int32_t l = 0; l < shapes[i].dim(2); ++l) + { + for (int32_t m = 0; m < shapes[i].dim(3); ++m) + { + Coordinates input_coords; + Coordinates output_coords; + if (inputs[i]->layout() == Layout::NHWC) + { + input_coords = Coordinates{j, k, l, m}; + } + else + { + input_coords = Coordinates{j, m, k, l}; + } + if (outputs[i]->layout() == Layout::NHWC) + { + output_coords = Coordinates{j, k, l, m}; + } + else + { + output_coords = Coordinates{j, m, k, l}; + } + uint8_t qasymm8 = *reinterpret_cast<uint8_t *>(outputs[i]->buffer() + + outputs[i]->calcOffset(output_coords)); + float result = (qasymm8 - zero_point) * scale; + float expected = *reinterpret_cast<float *>(inputs[i]->buffer() + + inputs[i]->calcOffset(input_coords)); + EXPECT_EQ(result, expected); + } + } + } + } + } + } + + // qasymm8 -> float + { + const size_t input_pads[4] = {0, 0, 1, 2}; + const size_t output_pads[4] = {0, 3, 2, 1}; + 
const std::vector<Shape> shapes{{1, 1, 4, 1}, {2, 1, 2, 3}, {1, 2, 1, 2}, {1, 1, 2, 3}}; + float expected_buffer[] = {10, 0, -10, -20, 30, -40, 50, -60, 70, + -80, 90, -100, 110, -120, 130, -140, 150, -160}; + float scale = 10; + int32_t zero_point = 128; + uint8_t input_buffer[18]; + + int32_t min_val = std::numeric_limits<int16_t>::min(); + int32_t max_val = std::numeric_limits<int16_t>::max(); + for (int32_t i = 0; i < sizeof(expected_buffer) / sizeof(float); ++i) + { + int32_t unclamped = static_cast<int32_t>(std::round(expected_buffer[i] / scale)) + zero_point; + input_buffer[i] = std::min(std::max(unclamped, min_val), max_val); + } + + std::vector<std::unique_ptr<MockUpTensor>> inputs(4); + std::vector<std::unique_ptr<MockUpTensor>> outputs(4); + std::vector<std::unique_ptr<uint8_t[]>> output_buffers(4); + for (size_t i = 0; i < 4; ++i) + { + Layout layout = Layout::NHWC; + Shape shape = shapes[i]; + if (i % 2 == 1) + { + layout = Layout::NCHW; + shape = Shape{shapes[i].dim(0), shapes[i].dim(3), shapes[i].dim(1), shapes[i].dim(2)}; + } + TypeInfo type_info{DataType::QUANT_UINT8_ASYMM, scale, zero_point}; + inputs[i] = std::make_unique<MockUpTensor>(shape, type_info, layout, input_pads[i]); + inputs[i]->setBuffer(reinterpret_cast<uint8_t *>(expected_buffer)); + + if (layout == Layout::NHWC) + { + layout = Layout::NCHW; + shape = Shape{shapes[i].dim(0), shapes[i].dim(3), shapes[i].dim(1), shapes[i].dim(2)}; + } + else + { + layout = Layout::NHWC; + shape = shapes[i]; + } + outputs[i] = + std::make_unique<MockUpTensor>(shape, TypeInfo(DataType::FLOAT32), layout, output_pads[i]); + output_buffers[i] = std::make_unique<uint8_t[]>(outputs[i]->total_size()); + outputs[i]->setBuffer(output_buffers[i].get()); + } + + auto mockup_layer = std::make_unique<MockUpLayer>( + std::vector<ITensor *>{inputs[0].get(), inputs[1].get(), inputs[2].get(), inputs[3].get()}, + std::vector<ITensor *>{outputs[0].get(), outputs[1].get(), outputs[2].get(), + outputs[3].get()}); + 
mockup_layer->run(); + + for (size_t i = 0; i < 4; ++i) + { + for (int32_t j = 0; j < shapes[i].dim(0); ++j) + { + for (int32_t k = 0; k < shapes[i].dim(1); ++k) + { + for (int32_t l = 0; l < shapes[i].dim(2); ++l) + { + for (int32_t m = 0; m < shapes[i].dim(3); ++m) + { + Coordinates input_coords; + Coordinates output_coords; + if (inputs[i]->layout() == Layout::NHWC) + { + input_coords = Coordinates{j, k, l, m}; + } + else + { + input_coords = Coordinates{j, m, k, l}; + } + if (outputs[i]->layout() == Layout::NHWC) + { + output_coords = Coordinates{j, k, l, m}; + } + else + { + output_coords = Coordinates{j, m, k, l}; + } + float result = *reinterpret_cast<float *>(outputs[i]->buffer() + + outputs[i]->calcOffset(output_coords)); + uint8_t qasymm8 = *reinterpret_cast<uint8_t *>(inputs[i]->buffer() + + inputs[i]->calcOffset(input_coords)); + float expected = (qasymm8 - zero_point) * scale; + EXPECT_EQ(result, expected); + } + } + } + } + } + } +} + +} // namespace diff --git a/runtime/onert/core/src/exec/JSONExecTime.cc b/runtime/onert/core/src/exec/JSONExecTime.cc index 72a18def1..d149345fd 100644 --- a/runtime/onert/core/src/exec/JSONExecTime.cc +++ b/runtime/onert/core/src/exec/JSONExecTime.cc @@ -14,8 +14,8 @@ * limitations under the License. 
*/ -#include "exec/JSONExecTime.h" -#include "backend/IConfig.h" +#include "JSONExecTime.h" + #include <fstream> namespace onert @@ -135,7 +135,7 @@ void JSON::printOperation(const std::map<uint32_t, int64_t> &operation_info, stream.seekp(-2, std::ofstream::end); } -void JSON::uploadOperationsExecTime() const +void JSON::storeOperationsExecTime() const { std::ofstream stream(_measurement_file); if (!stream.is_open()) diff --git a/runtime/onert/core/src/exec/JSONExecTime.h b/runtime/onert/core/src/exec/JSONExecTime.h index a64cb3133..e01723611 100644 --- a/runtime/onert/core/src/exec/JSONExecTime.h +++ b/runtime/onert/core/src/exec/JSONExecTime.h @@ -37,15 +37,15 @@ namespace exec * _measurements[Backend*]["string"][bool][uint32_t] = int64_t */ using MeasurementData = std::unordered_map< - const backend::Backend *, - std::unordered_map<std::string, std::unordered_map<bool, std::map<uint32_t, int64_t>>>>; + const backend::Backend *, + std::unordered_map<std::string, std::unordered_map<bool, std::map<uint32_t, int64_t>>>>; class JSON { public: explicit JSON(const std::vector<const backend::Backend *> &backends, MeasurementData &measurements) - : _measurement_file("exec_time.json"), _backends(), _measurements(measurements) + : _measurement_file("exec_time.json"), _backends(), _measurements(measurements) { for (const auto b : backends) { @@ -54,18 +54,16 @@ public: loadOperationsExecTime(); }; /** - * @brief Update _operations_exec_time_file with new data. + * @brief Update _measurement_file with new data. 
*/ - void uploadOperationsExecTime() const; + void storeOperationsExecTime() const; private: ///@brief file containing measurements std::string _measurement_file; std::unordered_map<std::string, const backend::Backend *> _backends; - std::unordered_map< - const backend::Backend *, - std::unordered_map<std::string, std::unordered_map<bool, std::map<uint32_t, int64_t>>>> - &_measurements; + MeasurementData &_measurements; + /** * @brief Helper function for inserting data to OperationExecTimes * @@ -86,7 +84,7 @@ private: void printOperation(const std::map<uint32_t, int64_t> &operation_info, std::ofstream &stream) const; /** - * @brief Parse and load operations_exec_time from _operations_exec_time_file. + * @brief Parse and load _measurements from _measurement_file. */ void loadOperationsExecTime(); }; diff --git a/runtime/onert/core/src/exec/LinearExecutor.cc b/runtime/onert/core/src/exec/LinearExecutor.cc index 69dfe9b9b..228c4d3c0 100644 --- a/runtime/onert/core/src/exec/LinearExecutor.cc +++ b/runtime/onert/core/src/exec/LinearExecutor.cc @@ -24,41 +24,54 @@ namespace onert namespace exec { -#ifdef RUY_PROFILER -namespace -{ -char *seq_to_label(const onert::ir::OpSequence *op_seq, const onert::ir::Operations &operations) +void LinearExecutor::executeImpl(const ExecutionObservee &subject) { - auto node_name = operations.at(*op_seq->begin()).name(); - char *cstr = new char[node_name.length() + 1]; - std::strcpy(cstr, node_name.c_str()); - return cstr; -} -} // namespace + if (!subject.isEmpty() && _tracing_ctx) + { + auto profiling_subg_index = _tracing_ctx->getSubgraphIndex(&_graph); + + subject.notifySubgraphBegin(profiling_subg_index); + for (auto &&code : _code) + { + const auto backend = code.lower_info->backend(); +// TODO : Move ruy profiler into ExecutionObserver +#ifdef RUY_PROFILER + ruy::profiler::ScopeLabel label(code.op->name()); #endif + subject.notifyJobBegin(this, profiling_subg_index, code.op_ind, backend); -void LinearExecutor::executeImpl() -{ - 
_subject.notifyModelBegin(this); - for (auto &&code : _code) + auto &fn_seq = code.fn_seq; + + fn_seq->initRunning(); + + bool handle_dynamic_tensor = + _lowered_graph->getHasDynamicTensor(code.op_ind) || hasDynamicInput(); + fn_seq->enableDynamicShapeInferer(handle_dynamic_tensor); + fn_seq->run(); + + subject.notifyJobEnd(this, profiling_subg_index, code.op_ind, backend); + } + subject.notifySubgraphEnd(profiling_subg_index); + } + else { - const auto op_seq = code.op_seq; - const auto backend = code.lower_info->backend(); + for (auto &&code : _code) + { // TODO : Move ruy profiler into ExecutionObserver #ifdef RUY_PROFILER - ruy::profiler::ScopeLabel label(seq_to_label(op_seq, _graph.operations())); + ruy::profiler::ScopeLabel label(code.op->name()); #endif - _subject.notifyJobBegin(this, op_seq, backend); - auto &fn_seq = code.fn_seq; - bool handle_dynamic_tensor = op_seq->has_dynamic_tensor() || hasDynamicInput(); + auto &fn_seq = code.fn_seq; - fn_seq->enableDynamicShapeInferer(handle_dynamic_tensor); - fn_seq->run(); + fn_seq->initRunning(); - _subject.notifyJobEnd(this, op_seq, backend); + bool handle_dynamic_tensor = + _lowered_graph->getHasDynamicTensor(code.op_ind) || hasDynamicInput(); + fn_seq->enableDynamicShapeInferer(handle_dynamic_tensor); + fn_seq->run(); + } } - _subject.notifyModelEnd(this); } } // namespace exec diff --git a/runtime/onert/core/src/exec/LinearExecutor.h b/runtime/onert/core/src/exec/LinearExecutor.h index c224d3f4f..853632a4e 100644 --- a/runtime/onert/core/src/exec/LinearExecutor.h +++ b/runtime/onert/core/src/exec/LinearExecutor.h @@ -22,11 +22,11 @@ #ifndef __ONERT_EXEC_EXECUTOR_H_ #define __ONERT_EXEC_EXECUTOR_H_ -#include "ir/Index.h" #include "ExecutorBase.h" -#include "compiler/Linear.h" -#include "exec/FunctionSequence.h" + #include "compiler/CodeMap.h" +#include "ir/Index.h" +#include "util/TracingCtx.h" namespace onert { @@ -44,25 +44,22 @@ public: * @brief Construct a new LinearExecutor object * @param lowered_graph 
LoweredGraph object * @param tensor_builders Tensor builders that are currently used - * @param code_map OpSequence and its code map + * @param code_map @c ir::Operation and its code map */ LinearExecutor(std::unique_ptr<compiler::LoweredGraph> lowered_graph, - const std::vector<std::shared_ptr<backend::ITensor>> &input_tensors, - const std::vector<std::shared_ptr<backend::ITensor>> &output_tensors, - const compiler::TensorRegistries &tensor_regs, - backend::TensorManagerSet &&tensor_mgrs, compiler::CodeMap &&code_map, - const std::vector<ir::OpSequenceIndex> &order) - : ExecutorBase{std::move(lowered_graph), input_tensors, output_tensors, tensor_regs, - std::move(tensor_mgrs)} + backend::BackendContexts &&backend_contexts, + const compiler::TensorRegistries &tensor_regs, compiler::CodeMap &&code_map, + const std::vector<ir::OperationIndex> &order, const util::TracingCtx *tracing_ctx) + : ExecutorBase{std::move(lowered_graph), std::move(backend_contexts), tensor_regs, tracing_ctx} { - for (auto index : order) + for (auto &&index : order) { _code.emplace_back(std::move(code_map.at(index))); } } public: - void executeImpl(void) override; + void executeImpl(const ExecutionObservee &subject) override; private: std::vector<compiler::CodeAndInfo> _code; diff --git a/runtime/onert/core/src/exec/MinMaxData.cc b/runtime/onert/core/src/exec/MinMaxData.cc new file mode 100644 index 000000000..1d18252e8 --- /dev/null +++ b/runtime/onert/core/src/exec/MinMaxData.cc @@ -0,0 +1,135 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "MinMaxData.h" + +#include <iostream> + +namespace onert +{ +namespace exec +{ + +RawMinMaxDumper::RawMinMaxDumper(const std::string &filename) : _filename(filename) {} + +void RawMinMaxDumper::dump(const exec::IOMinMaxMap &input_minmax, + const exec::OpMinMaxMap &op_minmax) const +{ + // Find file is already exist for modifying + auto file = std::fopen(_filename.c_str(), "rb+"); + uint32_t runs = 1; + + // Magic code and version + // Match with runtime/onert/odc/MinMaxReader.cc + // TODO Use util to share code and version + const uint32_t MAGIC_CODE = 0x4F4D4D44; + const uint32_t VERSION = 1; + if (!file) + { + // If file is not exist, create new file + file = std::fopen(_filename.c_str(), "wb+"); + if (!file) + throw std::runtime_error{"RawMinMaxDumper: Failed to open minmax file " + _filename}; + + // Write magic code and version + std::fwrite(&MAGIC_CODE, sizeof(uint32_t), 1, file); + std::fwrite(&VERSION, sizeof(uint32_t), 1, file); + } + else + { + // Check magic code and version + std::fseek(file, 0, SEEK_SET); + uint32_t read_magic_code = 0; + uint32_t read_version = 0; + bool rewrite = true; + if (std::fread(&read_magic_code, sizeof(uint32_t), 1, file) == 1 && + read_magic_code == MAGIC_CODE && + std::fread(&read_version, sizeof(uint32_t), 1, file) == 1 && read_version == VERSION) + rewrite = false; + + // Destroy and create if file is not valid + if (rewrite) + { + std::fclose(file); + file = std::fopen(_filename.c_str(), "wb+"); + if (!file) + throw std::runtime_error{"RawMinMaxDumper: Failed to rewrite minmax file " + 
_filename}; + + // Write magic code and version + std::fwrite(&MAGIC_CODE, sizeof(uint32_t), 1, file); + std::fwrite(&VERSION, sizeof(uint32_t), 1, file); + } + } + + // Read run count + if (std::fread(&runs, sizeof(uint32_t), 1, file) == 1) + runs++; + else + runs = 1; + + // TODO Verify file size + + // Overwrite run count + std::fseek(file, sizeof(MAGIC_CODE) + sizeof(VERSION), SEEK_SET); + std::fwrite(&runs, sizeof(uint32_t), 1, file); + + // Go to end of file to append new data + std::fseek(file, 0, SEEK_END); + + uint32_t input_count = input_minmax.size(); + uint32_t op_count = op_minmax.size(); + + // Write op_count and input_count + std::fwrite(&op_count, sizeof(uint32_t), 1, file); + std::fwrite(&input_count, sizeof(uint32_t), 1, file); + + // For each op + for (auto &&elem : op_minmax) + { + const uint32_t model_idx = 0; + const uint32_t subg_idx = elem.first.first.value(); + const uint32_t op_idx = elem.first.second.value(); + + // Write model/subg/op index + std::fwrite(&model_idx, sizeof(uint32_t), 1, file); + std::fwrite(&subg_idx, sizeof(uint32_t), 1, file); + std::fwrite(&op_idx, sizeof(uint32_t), 1, file); + + // Write min/max + std::fwrite(elem.second.data, sizeof(float), 2, file); + } + + // For each input + for (auto &&elem : input_minmax) + { + const uint32_t model_idx = 0; + const uint32_t subg_idx = elem.first.first.value(); + const uint32_t input_idx = elem.first.second.value(); + + // Write model/subg/input index + std::fwrite(&model_idx, sizeof(uint32_t), 1, file); + std::fwrite(&subg_idx, sizeof(uint32_t), 1, file); + std::fwrite(&input_idx, sizeof(uint32_t), 1, file); + + // Write min/max + std::fwrite(elem.second.data, sizeof(float), 2, file); + } + + std::fclose(file); +} + +} // namespace exec +} // namespace onert diff --git a/runtime/onert/core/src/exec/MinMaxData.h b/runtime/onert/core/src/exec/MinMaxData.h new file mode 100644 index 000000000..2538d444c --- /dev/null +++ b/runtime/onert/core/src/exec/MinMaxData.h @@ -0,0 +1,75 @@ 
+/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __ONERT_EXEC_MINMAX_DATA_H__ +#define __ONERT_EXEC_MINMAX_DATA_H__ + +#include "exec/MinMaxMap.h" + +#include <string> + +namespace onert +{ +namespace exec +{ + +// Because IOMinMaxMap and OpMinMaxMap does not have the ordering and size information, +// we need to dump model, subgraph id for each minmax + +// File structure +// uint32_t magic code +// uint32_t version +// uint32_t num of runs + +// For each run +// uint32_t num of operations +// uint32_t num of inputs + +// For each operation +// uint32_t model id +// uint32_t subgraph id +// uint32_t operation id +// float min +// float max + +// For each input +// uint32_t model id +// uint32_t subgraph id +// uint32_t input id +// float min +// float max + +class RawMinMaxDumper +{ +public: + RawMinMaxDumper(const std::string &filename); + /** + * @brief Dump input minmax map + * + * @param[in] in_minmax input minmax map + * @param[in] op_minmax op minmax map + */ + + void dump(const exec::IOMinMaxMap &in_minmax, const exec::OpMinMaxMap &op_minmax) const; + +private: + std::string _filename; +}; + +} // namespace exec +} // namespace onert + +#endif // __ONERT_EXEC_MINMAX_DATA_H__ diff --git a/runtime/onert/core/src/exec/MinMaxRecorder.cc b/runtime/onert/core/src/exec/MinMaxRecorder.cc new file mode 100644 index 000000000..179800011 --- /dev/null +++ 
b/runtime/onert/core/src/exec/MinMaxRecorder.cc @@ -0,0 +1,161 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "MinMaxRecorder.h" +#if MINMAX_H5DUMPER +#include "../dumper/h5/MinMaxDumper.h" +#else +#include "MinMaxData.h" +#endif +#include "backend/ITensor.h" + +#include <cassert> +#include <cmath> + +namespace onert +{ +namespace exec +{ + +MinMaxRecorder::MinMaxRecorder(const std::string &workspace_dir, const ir::Graph &graph, + const backend::BackendContexts &backend_contexts) + : _graph{graph}, _backend_contexts{backend_contexts}, _workspace_dir(workspace_dir) +{ + // DO NOTHING +} + +std::pair<float, float> minmaxFrom(const backend::ITensor *tensor) +{ + const auto data = reinterpret_cast<float *>(tensor->buffer()); + const auto num_elements = tensor->total_size() / sizeof(float); + + float max = std::numeric_limits<float>::lowest(); + float min = std::numeric_limits<float>::max(); + + bool all_nan = true; + for (size_t i = 0; i < num_elements; ++i) + { + const float number = data[i]; + if (std::isnan(number)) + continue; + + if (number == std::numeric_limits<float>::lowest()) + continue; + + all_nan = false; + + if (number > max) + max = number; + + if (number < min) + min = number; + } + + if (all_nan) + throw std::runtime_error("All values are NaN(Not a Number)"); + + return {min, max}; +} + +void MinMaxRecorder::handleJobEnd(IExecutor *, 
ir::SubgraphIndex subg_idx, + ir::OperationIndex op_idx, const backend::Backend *backend) +{ + const auto &tensor_reg = _backend_contexts.at(backend)->tensor_registry; + const auto &op = _graph.operations().at(op_idx); + const auto &outputs = op.getOutputs(); + // TODO: Support multiple output + if (outputs.size() != 1) + throw std::runtime_error("Only 1 output operator is supported for recording minmax."); + + auto tensor = tensor_reg->getITensor(outputs.at(0)); + + // Logic copied from MinMaxObserver.cpp. + + // Filter Ops + if (tensor->is_constant()) + return; + + if (tensor->data_type() != ir::DataType::FLOAT32) + return; + + switch (op.opcode()) + { + // Operators with multiple outputs + case ir::OpCode::If: + case ir::OpCode::Split: + case ir::OpCode::SplitV: + case ir::OpCode::TopKV2: + case ir::OpCode::Unpack: + case ir::OpCode::While: + return; + // NOTE: Sin, Cos, Tanh's output is in [-1, 1] + // We may not need to dump those operators. + default:; // Do Nothing + } + + // Otherwise, dump! 
+ assert(tensor->data_type() == ir::DataType::FLOAT32); + auto minmax = minmaxFrom(tensor); + _op_minmax.append({subg_idx, op_idx}, minmax.first, minmax.second); +} + +void MinMaxRecorder::handleSubgraphBegin(ir::SubgraphIndex subg_idx) +{ + // Make sure there is only cpu backend except for builtin backend + std::set<std::string> backend_names; + backend::ITensorRegistry *tensor_reg = nullptr; + for (const auto &pair : _backend_contexts) + { + backend_names.insert(pair.first->config()->id()); + if (pair.first->config()->id() == "cpu") + { + tensor_reg = pair.second->tensor_registry.get(); + } + } + if (backend_names != std::set<std::string>{"builtin", "cpu"}) + throw std::runtime_error("MinMaxRecorder must have cpu backend only."); + + const auto &inputs = _graph.getInputs(); //.at(op_idx); + for (uint32_t i = 0; i < inputs.size(); ++i) + { + auto input_idx = inputs.at(i); + auto tensor = tensor_reg->getITensor(input_idx); + + if (tensor->is_constant()) + return; + if (tensor->data_type() != ir::DataType::FLOAT32) + return; + + auto minmax = minmaxFrom(tensor); + _input_minmax.append({subg_idx, ir::IOIndex{i}}, minmax.first, minmax.second); + } +} + +void MinMaxRecorder::handleSubgraphEnd(ir::SubgraphIndex) +{ + // It would be better to dump at the end of model execution, not subgraph + // But it requires more changes than subgraph. +#if MINMAX_H5DUMPER + auto h5dumper = dumper::h5::MinMaxDumper(_workspace_dir + "/minmax.h5"); + h5dumper.dump(_input_minmax, _op_minmax); +#else + auto raw_dumper = RawMinMaxDumper(_workspace_dir + "/minmax.bin"); + raw_dumper.dump(_input_minmax, _op_minmax); +#endif +} + +} // namespace exec +} // namespace onert diff --git a/runtime/onert/core/src/exec/MinMaxRecorder.h b/runtime/onert/core/src/exec/MinMaxRecorder.h new file mode 100644 index 000000000..ed5163972 --- /dev/null +++ b/runtime/onert/core/src/exec/MinMaxRecorder.h @@ -0,0 +1,58 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. 
All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __ONERT_EXEC_MINMAX_RECORDER__ +#define __ONERT_EXEC_MINMAX_RECORDER__ + +#include "ExecutionObservers.h" +#include "ir/Index.h" +#include "exec/MinMaxMap.h" + +#include <string> + +namespace onert +{ +namespace exec +{ + +class MinMaxRecorder : public IExecutionObserver +{ +public: + MinMaxRecorder(const std::string &workspace_dir, const ir::Graph &graph, + const backend::BackendContexts &backend_contexts); + void handleJobBegin(IExecutor *, ir::SubgraphIndex, ir::OperationIndex, + const backend::Backend *) override + { + return; + } + void handleJobEnd(IExecutor *, ir::SubgraphIndex, ir::OperationIndex, + const backend::Backend *) override; + void handleSubgraphBegin(ir::SubgraphIndex) override; + void handleSubgraphEnd(ir::SubgraphIndex) override; + ObserverType type() const override { return ObserverType::MINMAX_DUMP; } + +private: + const ir::Graph &_graph; + const backend::BackendContexts &_backend_contexts; + std::string _workspace_dir; + OpMinMaxMap _op_minmax; + IOMinMaxMap _input_minmax; +}; + +} // namespace exec +} // namespace onert + +#endif // __ONERT_EXEC_MINMAX_RECORDER__ diff --git a/runtime/onert/core/src/exec/MultiModelExecutors.cc b/runtime/onert/core/src/exec/MultiModelExecutors.cc new file mode 100644 index 000000000..920b17d45 --- /dev/null +++ b/runtime/onert/core/src/exec/MultiModelExecutors.cc @@ -0,0 +1,589 @@ +/* + * Copyright (c) 2022 Samsung 
Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "MultiModelExecutors.h" + +namespace +{ + +using namespace onert; + +int32_t find_input_index(const std::vector<ir::IODesc> &pkg_inputs, + const ir::ModelIndex &model_index, const ir::SubgraphIndex &subg_index, + const ir::IOIndex &io_index) +{ + for (size_t i = 0; i < pkg_inputs.size(); i++) + { + auto &input_desc = pkg_inputs[i]; + if ((std::get<ir::ModelIndex>(input_desc) == model_index) && + (std::get<ir::SubgraphIndex>(input_desc) == subg_index) && + (std::get<ir::IOIndex>(input_desc) == io_index)) + return static_cast<int32_t>(i); + } + return -1; +} + +int32_t find_output_index(const std::vector<ir::IODesc> &pkg_outputs, + const ir::ModelIndex &model_index, const ir::SubgraphIndex &subg_index, + const ir::IOIndex &io_index) +{ + for (size_t i = 0; i < pkg_outputs.size(); i++) + { + auto &input_desc = pkg_outputs[i]; + if ((std::get<ir::ModelIndex>(input_desc) == model_index) && + (std::get<ir::SubgraphIndex>(input_desc) == subg_index) && + (std::get<ir::IOIndex>(input_desc) == io_index)) + return static_cast<int32_t>(i); + } + return -1; +} + +} // namespace + +namespace onert +{ +namespace exec +{ + +void MultiModelExecutors::emplace(const ir::ModelIndex &model_index, + const ir::SubgraphIndex &subg_index, + std::unique_ptr<IExecutor> exec) +{ + _executors.emplace(std::make_pair(model_index, subg_index), std::move(exec)); +} + +IExecutor 
*MultiModelExecutors::at(const ir::ModelIndex &model_index, + const ir::SubgraphIndex &subg_index) const +{ + return _executors.at(std::make_pair(model_index, subg_index)).get(); +} + +uint32_t MultiModelExecutors::inputSize() const { return _model_edges->pkg_inputs.size(); } + +uint32_t MultiModelExecutors::outputSize() const { return _model_edges->pkg_outputs.size(); } + +const ir::OperandInfo &MultiModelExecutors::inputInfo(const ir::IOIndex &index) const +{ + auto const desc = _model_edges->pkg_inputs[index.value()]; + auto const model_index = std::get<0>(desc); + auto const subg_index = std::get<1>(desc); + auto const io_index = std::get<2>(desc); + auto const executor = at(model_index, subg_index); + return executor->inputInfo(io_index.value()); +} + +const ir::OperandInfo &MultiModelExecutors::outputInfo(const ir::IOIndex &index) const +{ + auto const desc = _model_edges->pkg_outputs[index.value()]; + auto const model_index = std::get<0>(desc); + auto const subg_index = std::get<1>(desc); + auto const io_index = std::get<2>(desc); + auto const executor = at(model_index, subg_index); + return executor->outputInfo(io_index.value()); +} + +// Allow below edges only +// m1 < m2, s1 == 0 and s2 == 0 if m1:s1:o1 -> m2:s2:o2' +void MultiModelExecutors::checkSupportedMultimodel() const +{ + // If package includes no-connection model, model_count is less than real model count in package. 
+ // Then this method will throw exception based on model index + // 1st model: input assumption + // Otherwise: edges assumption + + // Assumption: edges + // m1 < m2, s1 == 0 and s2 == 0 if edge 'm1:s1:o1 -> m2:s2:o2' + for (auto &&edge : _model_edges->edges) + { + auto const model_from = std::get<ir::ModelIndex>(edge.from); + auto const model_to = std::get<ir::ModelIndex>(edge.to); + auto const subg_from = std::get<ir::SubgraphIndex>(edge.from); + auto const subg_to = std::get<ir::SubgraphIndex>(edge.to); + + if (model_from.value() == model_to.value()) + { + throw std::runtime_error{"Multi model's edge set has invalid edge"}; + } + + if ((model_from.value() > model_to.value()) || (subg_from != ir::SubgraphIndex{0}) || + (subg_to != ir::SubgraphIndex{0})) + throw std::runtime_error{"NYI: Multi model execution for this edge set is not supported yet"}; + } + + // Assumption: package inputs + // All 1st model inputs come from package input if always m1 < m2 + { + auto first_executor = at(ir::ModelIndex{0}, ir::SubgraphIndex{0}); + auto search_first_model = [&](const ir::IOIndex &input_index) { + for (const auto &input : _model_edges->pkg_inputs) + { + if ((std::get<ir::ModelIndex>(input) == ir::ModelIndex{0}) || + (std::get<ir::SubgraphIndex>(input) == ir::SubgraphIndex{0}) || + (std::get<ir::IOIndex>(input) == input_index)) + return true; + } + + return false; + }; + + for (uint32_t i = 0; i < first_executor->inputSize(); i++) + { + if (!search_first_model(ir::IOIndex{i})) + throw std::runtime_error{"Cannot find 1st model's input buffer"}; + } + } + + // Check whether nnpkg outputs and Edge `from` are duplicated + for (const auto &edge : _model_edges->edges) + { + if (std::find(_model_edges->pkg_outputs.begin(), _model_edges->pkg_outputs.end(), edge.from) != + _model_edges->pkg_outputs.end()) + { + throw std::runtime_error{"Multi model execution does not support duplicating nnpkg outputs " + "with `from` of edges yet"}; + } + } +} + +void 
MultiModelExecutors::createEdgeQuantLayers() +{ + if (_is_created_edge_quant_layers) + { + return; + } + + // Create EdgeTensor for edges between executors + for (const auto &pair : _edge_map) + { + const auto &from_iodesc = pair.first; + const auto &from_model_index = std::get<ir::ModelIndex>(from_iodesc); + const auto &from_subg_index = std::get<ir::SubgraphIndex>(from_iodesc); + const auto &from_io_index = std::get<ir::IOIndex>(from_iodesc); + + const auto from_executor = _executors.at({from_model_index, from_subg_index}).get(); + const auto &from_info = from_executor->inputInfo(from_io_index.value()); + const auto from_layout = from_executor->inputLayout(from_io_index.value()); + _edge_tensors[from_iodesc] = std::make_unique<EdgeTensor>(from_info, from_layout); + } + + // Append type-aware quantization layer for edges between executors + for (const auto &executor_pair : _executors) + { + const auto &executor_index = executor_pair.first; + const auto &model_index = executor_index.first; + const auto &subg_index = executor_index.second; + + std::vector<backend::ITensor *> inputs; + std::vector<backend::ITensor *> outputs; + for (const auto &pair : _edge_map) + { + const auto &from_iodesc = pair.first; + if (std::get<ir::ModelIndex>(from_iodesc) == model_index && + std::get<ir::SubgraphIndex>(from_iodesc) == subg_index) + { + const auto from_tensor = _edge_tensors[from_iodesc].get(); + const auto &to_list = pair.second; + + for (const auto &to_iodesc : to_list) + { + const auto &to_model_index = std::get<ir::ModelIndex>(to_iodesc); + const auto &to_subg_index = std::get<ir::SubgraphIndex>(to_iodesc); + const auto &to_io_index = std::get<ir::IOIndex>(to_iodesc); + + const auto to_executor = _executors.at({to_model_index, to_subg_index}).get(); + const auto &to_info = to_executor->inputInfo(to_io_index.value()); + const auto to_layout = to_executor->inputLayout(to_io_index.value()); + + // TODO Unify tensors with the same `from` tensor and same type + if 
(from_tensor->data_type() != to_info.typeInfo().type()) + { + assert(inputs.size() == outputs.size()); + inputs.emplace_back(from_tensor); + + auto type_aware_quant_tensor = std::make_unique<EdgeTensor>(to_info, to_layout); + outputs.emplace_back(type_aware_quant_tensor.get()); + + _edge_quant_tensors[to_iodesc] = std::move(type_aware_quant_tensor); + } + } + } + } + + auto layer = std::make_unique<PermuteLayer>(inputs, outputs); + layer->prepare(); + _edge_quant_layers[{model_index, subg_index}] = std::move(layer); + } + + _is_created_edge_quant_layers = true; +} + +void MultiModelExecutors::CreatePkgIOTensors(const IODescription &desc) +{ + for (const auto &pkg_input : _model_edges->pkg_inputs) + { + // Create IOTensor for nnpkg inputs + const auto &model_index = std::get<ir::ModelIndex>(pkg_input); + const auto &subg_index = std::get<ir::SubgraphIndex>(pkg_input); + const auto &io_index = std::get<ir::IOIndex>(pkg_input); + const auto input_pkg_index = + find_input_index(_model_edges->pkg_inputs, model_index, subg_index, io_index); + if (input_pkg_index == -1) + throw std::runtime_error{"Cannot find multi model input index"}; + auto input_desc = desc.inputs[input_pkg_index].get(); + // TODO Remove const_cast (we need const_cast as ITensor is writable) + _pkg_input_tensors[pkg_input] = std::make_unique<backend::builtin::UserTensor>( + input_desc->info, input_desc->layout, + const_cast<uint8_t *>(reinterpret_cast<const uint8_t *>(input_desc->buffer)), + input_desc->size); + } + + for (const auto &pkg_output : _model_edges->pkg_outputs) + { + // Create IOTensor for nnpkg outputs + const auto &model_index = std::get<ir::ModelIndex>(pkg_output); + const auto &subg_index = std::get<ir::SubgraphIndex>(pkg_output); + const auto &io_index = std::get<ir::IOIndex>(pkg_output); + const auto output_pkg_index = + find_output_index(_model_edges->pkg_outputs, model_index, subg_index, io_index); + if (output_pkg_index == -1) + throw std::runtime_error{"Cannot find multi model 
output index"}; + auto output_desc = desc.outputs[output_pkg_index].get(); + _pkg_output_tensors[pkg_output] = std::make_unique<backend::builtin::UserTensor>( + output_desc->info, output_desc->layout, reinterpret_cast<uint8_t *>(output_desc->buffer), + output_desc->size); + } +} + +void MultiModelExecutors::createPkgIOQuantLayers(const IODescription &desc) +{ + // Append type-aware quantization layer for nnpkg inputs/outputs between executors + for (const auto &pair : _executors) + { + const auto &executor_index = pair.first; + const auto &model_index = executor_index.first; + const auto &subg_index = executor_index.second; + const auto executor = pair.second.get(); + + // Find pkg inputs of current executor + std::vector<ir::IODesc> pkg_inputs; + for (const auto &pkg_input : _model_edges->pkg_inputs) + { + if (std::get<ir::ModelIndex>(pkg_input) == model_index && + std::get<ir::SubgraphIndex>(pkg_input) == subg_index) + { + pkg_inputs.emplace_back(pkg_input); + } + } + std::vector<backend::ITensor *> src_tensors; + std::vector<backend::ITensor *> dst_tensors; + for (const auto &pkg_input : pkg_inputs) + { + const auto &io_index = std::get<ir::IOIndex>(pkg_input); + const auto input_pkg_index = + find_input_index(_model_edges->pkg_inputs, model_index, subg_index, io_index); + if (input_pkg_index == -1) + throw std::runtime_error{"Cannot find multi model input index"}; + auto input_desc = desc.inputs[input_pkg_index].get(); + + // Create EdgeTensor for nnpkg input if type is different + const auto &orig_info = executor->inputInfo(io_index.value()); + const auto orig_layout = executor->inputLayout(io_index.value()); + if (input_desc->info.typeInfo().type() != orig_info.typeInfo().type()) + { + auto pkg_input_edge_tensor = std::make_unique<EdgeTensor>(orig_info, orig_layout); + _pkg_input_quant_tensors[pkg_input] = std::move(pkg_input_edge_tensor); + + // Append type-aware quantization layer's inputs/outputs + 
src_tensors.emplace_back(_pkg_input_tensors[pkg_input].get()); + dst_tensors.emplace_back(_pkg_input_quant_tensors[pkg_input].get()); + } + } + + // Create type-aware quantization layer for nnpkg inputs + auto pkg_input_layer = std::make_unique<PermuteLayer>(src_tensors, dst_tensors); + pkg_input_layer->prepare(); + _pkg_input_quant_layers[{model_index, subg_index}] = std::move(pkg_input_layer); + + // Find pkg outputs of current executor + std::vector<ir::IODesc> pkg_outputs; + for (const auto &pkg_output : _model_edges->pkg_outputs) + { + if (std::get<ir::ModelIndex>(pkg_output) == model_index && + std::get<ir::SubgraphIndex>(pkg_output) == subg_index) + { + pkg_outputs.emplace_back(pkg_output); + } + } + src_tensors.clear(); + dst_tensors.clear(); + // Create Tensors of nnpkg outputs for type-aware quantization + for (const auto &pkg_output : pkg_outputs) + { + const auto &io_index = std::get<ir::IOIndex>(pkg_output); + const auto output_pkg_index = + find_output_index(_model_edges->pkg_outputs, model_index, subg_index, io_index); + if (output_pkg_index == -1) + throw std::runtime_error{"Cannot find multi model output index"}; + auto output_desc = desc.outputs[output_pkg_index].get(); + + // Create EdgeTensor for nnpkg output if type is different + const auto &orig_info = executor->outputInfo(io_index.value()); + const auto orig_layout = executor->outputLayout(io_index.value()); + if (output_desc->info.typeInfo().type() != orig_info.typeInfo().type()) + { + auto pkg_output_edge_tensor = std::make_unique<EdgeTensor>(orig_info, orig_layout); + _pkg_output_quant_tensors[pkg_output] = std::move(pkg_output_edge_tensor); + + // Append type-aware quantization layer's inputs/outputs + src_tensors.emplace_back(_pkg_output_quant_tensors[pkg_output].get()); + dst_tensors.emplace_back(_pkg_output_tensors[pkg_output].get()); + } + } + + // Create type-aware quantization layer for nnpkg outputs + auto pkg_output_layer = std::make_unique<PermuteLayer>(src_tensors, 
dst_tensors); + pkg_output_layer->prepare(); + _pkg_output_quant_layers[{model_index, subg_index}] = std::move(pkg_output_layer); + } +} + +void MultiModelExecutors::execute(const ExecutionContext &ctx) +{ + auto &desc = ctx.desc; + + // Check supported multi model package + checkSupportedMultimodel(); + + // TODO Move creating type-aware quantization layers for edges in compilation stage + createEdgeQuantLayers(); + + // TODO Create IOTensors only once and recreate them only if nnpkg info changes + CreatePkgIOTensors(desc); + + // TODO Create type-aware quantization layers only once and recreate them only if type changes + createPkgIOQuantLayers(desc); + + // TODO Find better way to schedule order of executors + auto const model_count = modelCount(); + + auto find_from = [&](const ir::ModelIndex &model_index, const ir::SubgraphIndex &subg_index, + const ir::IOIndex &io_index) { + for (const auto &edge : _model_edges->edges) + { + if ((std::get<ir::ModelIndex>(edge.to) == model_index) && + (std::get<ir::SubgraphIndex>(edge.to) == subg_index) && + (std::get<ir::IOIndex>(edge.to) == io_index)) + return edge.from; + } + + throw std::runtime_error{"Cannot find edge for model input"}; + }; + + // Execute each model + // NOTE May be better to use vector instead of unordered_map for _executors + for (auto model_index = ir::ModelIndex{0}; model_index.value() < model_count; model_index++) + { + // Find executor + auto executor = at(model_index, ir::SubgraphIndex{0}); + + // Set IOTensors + // TODO Set internal IOTensors only once + std::vector<backend::IPortableTensor *> inputs_inter; + std::vector<backend::IPortableTensor *> outputs_inter; + auto const input_size = executor->inputSize(); + auto const output_size = executor->outputSize(); + inputs_inter.resize(input_size); + outputs_inter.resize(output_size); + + // Set inputs of executor + // TODO Create layer to allocate/deallocate buffers of EdgeTensor for each executor + for (uint32_t i = 0; i < input_size; i++) + { + 
const auto input_pkg_index = find_input_index(_model_edges->pkg_inputs, model_index, + ir::SubgraphIndex{0}, ir::IOIndex{i}); + const auto input_io_desc = ir::IODesc{model_index, ir::SubgraphIndex{0}, ir::IOIndex{i}}; + if (input_pkg_index != -1) + { + // Allocate type-aware quantization tensors for nnpkg inputs and set internal tensors + if (_pkg_input_quant_tensors.find(input_io_desc) != _pkg_input_quant_tensors.end()) + { + _pkg_input_quant_tensors[input_io_desc]->allocate_buffer(); + + inputs_inter[i] = _pkg_input_quant_tensors[input_io_desc].get(); + } + else + { + inputs_inter[i] = _pkg_input_tensors[input_io_desc].get(); + } + } + else + { + auto from_iodesc = find_from(model_index, ir::SubgraphIndex{0}, ir::IOIndex{i}); + + // Supported only sequantial execution of models + assert(std::get<ir::ModelIndex>(from_iodesc).value() < model_index.value()); + assert(std::get<ir::SubgraphIndex>(from_iodesc).value() == 0); + const auto to_iodesc = ir::IODesc{model_index, ir::SubgraphIndex{0}, ir::IOIndex{i}}; + if (_edge_quant_tensors.find(to_iodesc) == _edge_quant_tensors.end()) + { + inputs_inter[i] = _edge_tensors.at(from_iodesc).get(); + } + else + { + inputs_inter[i] = _edge_quant_tensors.at(to_iodesc).get(); + } + assert(inputs_inter[i]->buffer() != nullptr); + } + } + + // Set outputs of executor + for (uint32_t i = 0; i < output_size; i++) + { + const auto output_pkg_index = find_output_index(_model_edges->pkg_outputs, model_index, + ir::SubgraphIndex{0}, ir::IOIndex{i}); + const auto output_io_desc = ir::IODesc{model_index, ir::SubgraphIndex{0}, ir::IOIndex{i}}; + if (output_pkg_index != -1) + { + // Allocate type-aware quantization tensors for nnpkg outputs and set internal tensors + if (_pkg_output_quant_tensors.find(output_io_desc) != _pkg_output_quant_tensors.end()) + { + _pkg_output_quant_tensors[output_io_desc]->allocate_buffer(); + + outputs_inter[i] = _pkg_output_quant_tensors[output_io_desc].get(); + } + else + { + outputs_inter[i] = 
_pkg_output_tensors[output_io_desc].get(); + } + } + else + { + // Allocate buffer of `from` tensors + const auto from_iodesc = ir::IODesc{model_index, ir::SubgraphIndex{0}, ir::IOIndex{i}}; + _edge_tensors[from_iodesc]->allocate_buffer(); + outputs_inter[i] = _edge_tensors[from_iodesc].get(); + + // Allocate buffer of tensors for type-aware quantization + for (const auto &to_iodesc : _edge_map[from_iodesc]) + { + _edge_tensors[from_iodesc]->increase_ref(); + if (_edge_quant_tensors.find(to_iodesc) != _edge_quant_tensors.end()) + { + auto type_aware_quant_tensor = _edge_quant_tensors.at(to_iodesc).get(); + type_aware_quant_tensor->allocate_buffer(); + + _edge_tensors[from_iodesc]->decrease_ref(); + } + } + } + } + + _pkg_input_quant_layers[{model_index, ir::SubgraphIndex{0}}]->run(); + + executor->execute(inputs_inter, outputs_inter, ctx.options); + + _edge_quant_layers[{model_index, ir::SubgraphIndex{0}}]->run(); + _pkg_output_quant_layers[{model_index, ir::SubgraphIndex{0}}]->run(); + + // Release input buffers that are no longer needed + for (uint32_t i = 0; i < input_size; i++) + { + const auto input_pkg_index = find_input_index(_model_edges->pkg_inputs, model_index, + ir::SubgraphIndex{0}, ir::IOIndex{i}); + + const auto to_iodesc = ir::IODesc{model_index, ir::SubgraphIndex{0}, ir::IOIndex{i}}; + if (input_pkg_index == -1) + { + if (_edge_quant_tensors.find(to_iodesc) != _edge_quant_tensors.end()) + { + // Decrease reference count of tensor for type-aware quantization if input tensor is the + // tensor + const auto to_iodesc = ir::IODesc{model_index, ir::SubgraphIndex{0}, ir::IOIndex{i}}; + if (_edge_quant_tensors.find(to_iodesc) != _edge_quant_tensors.end()) + { + _edge_quant_tensors[to_iodesc]->decrease_ref(); + } + } + else + { + // Decrease reference count of `from` tensor if input tensor is the `from` tensor + const auto from_iodesc = find_from(model_index, ir::SubgraphIndex{0}, ir::IOIndex{i}); + _edge_tensors[from_iodesc]->decrease_ref(); + + // 
Decrease reference count of nnpkg inputs + if (_pkg_input_quant_tensors.find(to_iodesc) != _pkg_input_quant_tensors.end()) + { + _pkg_input_quant_tensors[to_iodesc]->decrease_ref(); + } + } + } + } + + // Release output buffers if those buffers are no longer used other executors because of + // type-aware quantization + // FIXME if tensors for type-aware quantization unified for the same `from` tensor and same type + for (uint32_t i = 0; i < output_size; i++) + { + auto from_iodesc = ir::IODesc{model_index, ir::SubgraphIndex{0}, ir::IOIndex{i}}; + + // Check if other executors will use the buffer of edge tensor + const auto &to_list = _edge_map[from_iodesc]; + if (to_list.size() == 0) + { + // This condition means `from_iodesc` tensor is an output of nnpkg + continue; + } + + bool to_be_release = + !std::any_of(to_list.begin(), to_list.end(), [&](const ir::IODesc &to_iodesc) { + // This condition means another executor uses the buffer of edge tensor + return _edge_quant_tensors.find(to_iodesc) == _edge_quant_tensors.end(); + }); + + if (to_be_release) + { + // This edge tensor's buffer won't be used in other executors + // Tensors for type-aware quantization take over the role of this edge tensor instead + _edge_tensors[from_iodesc]->decrease_ref(); + } + + // Decrease reference count of nnpkg outputs + if (_pkg_output_quant_tensors.find(from_iodesc) != _pkg_output_quant_tensors.end()) + { + _pkg_output_quant_tensors[from_iodesc]->decrease_ref(); + } + } + } +} + +// modelCount() iterates _executors. +// It assumes that Compiler will generate Executor for all models and _executors includes all +// generated Executor. +// If nnpackage includes model(s) which has no connection and Compiler does not +// generate Executor for them, modelCount() return less value than real model count. 
+uint16_t MultiModelExecutors::modelCount() const +{ + uint16_t model_count = 0; + for (; _executors.find(std::make_pair(ir::ModelIndex{model_count}, ir::SubgraphIndex{0})) != + _executors.end(); + model_count++) + ; + + return model_count; +} + +} // namespace exec +} // namespace onert diff --git a/runtime/onert/core/src/exec/MultiModelExecutors.h b/runtime/onert/core/src/exec/MultiModelExecutors.h new file mode 100644 index 000000000..0bd9f1143 --- /dev/null +++ b/runtime/onert/core/src/exec/MultiModelExecutors.h @@ -0,0 +1,152 @@ +/* + * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef __ONERT_EXEC_EXECUTORS_H__ +#define __ONERT_EXEC_EXECUTORS_H__ + +#include "exec/IExecutors.h" +#include "ir/NNPkg.h" +#include "IPermuteFunction.h" +#include "EdgeTensor.h" +#include "../backend/builtin/UserTensor.h" + +namespace std +{ + +template <> struct hash<std::pair<::onert::ir::ModelIndex, ::onert::ir::SubgraphIndex>> +{ + size_t operator()( + const std::pair<::onert::ir::ModelIndex, ::onert::ir::SubgraphIndex> &pair) const noexcept + { + return (hash<uint32_t>()(pair.first.value()) << 16) ^ hash<uint32_t>()(pair.second.value()); + } +}; + +} // namespace std + +namespace onert +{ +namespace exec +{ + +/** + * @brief Class to gather executors + */ +class MultiModelExecutors : public IExecutors +{ +public: + MultiModelExecutors(void) = delete; + MultiModelExecutors(std::unique_ptr<ir::ModelEdges> model_edges) + : _executors{}, _model_edges{std::move(model_edges)}, _edge_quant_layers{}, + _edge_quant_tensors{}, _edge_tensors{}, _is_created_edge_quant_layers{false}, + _pkg_input_quant_layers{}, _pkg_output_quant_layers{}, _pkg_input_quant_tensors{}, + _pkg_output_quant_tensors{}, _pkg_input_tensors{}, _pkg_output_tensors{} + { + for (const auto &edge : _model_edges->edges) + { + _edge_map[edge.from].emplace_back(edge.to); + } + } + MultiModelExecutors(const MultiModelExecutors &) = delete; + MultiModelExecutors(MultiModelExecutors &&) = default; + ~MultiModelExecutors() = default; + + // TODO Use Executor index + void emplace(const ir::ModelIndex &model_index, const ir::SubgraphIndex &subg_index, + std::unique_ptr<IExecutor> exec) override; + + IExecutor *at(const ir::ModelIndex &model_index, + const ir::SubgraphIndex &subg_index) const override; + + uint32_t inputSize() const override; + + uint32_t outputSize() const override; + + const ir::OperandInfo &inputInfo(const ir::IOIndex &index) const override; + + const ir::OperandInfo &outputInfo(const ir::IOIndex &index) const override; + + void execute(const ExecutionContext &ctx) override; + 
+private: + void checkSupportedMultimodel() const; + void createEdgeQuantLayers(); + void CreatePkgIOTensors(const IODescription &desc); + void createPkgIOQuantLayers(const IODescription &desc); + uint16_t modelCount() const; + +private: + std::unordered_map<std::pair<ir::ModelIndex, ir::SubgraphIndex>, std::unique_ptr<IExecutor>> + _executors; + + // NOTE _model_edges may use different struct type for executor implementation + std::unique_ptr<ir::ModelEdges> _model_edges; + std::unordered_map<ir::IODesc, std::vector<ir::IODesc>> _edge_map; + + /** + * @brief Type-aware quantization layers for edges between executors + * + */ + // TODO Move variables related to type-aware quantization for edges into compilation stage + // TODO Replace PermuteLayer with backend::builtin::kernel::PermuteLayer + std::unordered_map<std::pair<ir::ModelIndex, ir::SubgraphIndex>, std::unique_ptr<PermuteLayer>> + _edge_quant_layers; + + /** + * @brief Tensors for type-aware quantization of edges + * Key: `to` IODesc, Value: EdgeTensor + */ + // + // Q: Why is Key `to` IODesc + // A: these tensors are currently created depending on the type of `to` + // TODO Unify tensors with the same `from` tensor and same type + // NOTE The incomplete type 'EdgeTensor' cannot be declared as unique_ptr. + std::unordered_map<ir::IODesc, std::shared_ptr<EdgeTensor>> _edge_quant_tensors; + + /** + * @brief Tensors for edges between executors that are not related to type-aware quantization + * Key: `from` IODesc, Value: EdgeTensor + */ + // Q: Why is Key `from` IODesc + // A: `from` can be connected to multiple `to` + // NOTE The incomplete type 'EdgeTensor' cannot be declared as unique_ptr. 
+ std::unordered_map<ir::IODesc, std::shared_ptr<EdgeTensor>> _edge_tensors; + /** + * @brief Whether type-aware quantization layers for edges between executors are created + * + */ + // TODO Remove this member after the creation of type-aware quantization layers for edges + // is moved into compilation stage + bool _is_created_edge_quant_layers; + + // TODO Replace PermuteLayer with backend::builtin::kernel::PermuteLayer + std::unordered_map<std::pair<ir::ModelIndex, ir::SubgraphIndex>, std::unique_ptr<PermuteLayer>> + _pkg_input_quant_layers; + // TODO Replace PermuteLayer with backend::builtin::kernel::PermuteLayer + std::unordered_map<std::pair<ir::ModelIndex, ir::SubgraphIndex>, std::unique_ptr<PermuteLayer>> + _pkg_output_quant_layers; + // Edge tensors of nnpkg inputs/outputs for type-aware quantization + std::unordered_map<ir::IODesc, std::shared_ptr<EdgeTensor>> _pkg_input_quant_tensors; + std::unordered_map<ir::IODesc, std::shared_ptr<EdgeTensor>> _pkg_output_quant_tensors; + // IOTensors for user buffer + std::unordered_map<ir::IODesc, std::unique_ptr<backend::builtin::UserTensor>> _pkg_input_tensors; + std::unordered_map<ir::IODesc, std::unique_ptr<backend::builtin::UserTensor>> _pkg_output_tensors; +}; + +} // namespace exec +} // namespace onert + +#endif // __ONERT_EXEC_EXECUTORS_H__ diff --git a/runtime/onert/core/src/exec/ParallelExecutor.cc b/runtime/onert/core/src/exec/ParallelExecutor.cc index ab234aacd..152fa7cd3 100644 --- a/runtime/onert/core/src/exec/ParallelExecutor.cc +++ b/runtime/onert/core/src/exec/ParallelExecutor.cc @@ -31,7 +31,7 @@ class HookFunction : public IFunction public: HookFunction(IFunction *fn, const std::function<void()> &setup, const std::function<void()> &teardown) - : _fn{fn}, _setup{setup}, _teardown{teardown} + : _fn{fn}, _setup{setup}, _teardown{teardown} { } @@ -59,29 +59,28 @@ void ParallelExecutor::notify(uint32_t finished_job_id) _cv_jobs.notify_all(); } -ParallelExecutor::ParallelExecutor( - 
std::unique_ptr<compiler::LoweredGraph> lowered_graph, - const std::vector<std::shared_ptr<backend::ITensor>> &input_tensors, - const std::vector<std::shared_ptr<backend::ITensor>> &output_tensors, - const compiler::TensorRegistries &tensor_regs, backend::TensorManagerSet &&tensor_mgrs, - compiler::CodeMap &&code_map) - : DataflowExecutor{std::move(lowered_graph), input_tensors, output_tensors, tensor_regs, - std::move(tensor_mgrs), std::move(code_map)} +ParallelExecutor::ParallelExecutor(std::unique_ptr<compiler::LoweredGraph> lowered_graph, + backend::BackendContexts &&backend_contexts, + const compiler::TensorRegistries &tensor_regs, + compiler::CodeMap &&code_map, + const util::TracingCtx *tracing_ctx) + : DataflowExecutor{std::move(lowered_graph), std::move(backend_contexts), tensor_regs, + std::move(code_map), tracing_ctx} { VERBOSE(ParallelExecutor) << "Constructing Parallel Executor" << std::endl; } -void ParallelExecutor::executeImpl() +void ParallelExecutor::executeImpl(const ExecutionObservee &subject) { bool dynamic_input_exists = hasDynamicInput(); // Init scheduler - // TODO Consider to have distinct backend set in LowerInfoMap + // TODO Consider to have distinct backend set in GraphLowerInfo BackendSet backends; - for (auto &itr : _lowered_graph->getLowerInfo()->op_seq) - { - backends.add(itr.second->backend()); - } + _lowered_graph->lower_info().operation.iterate( + [&](const ir::OperationIndex &, const compiler::OperationLowerInfo &lower_info) { + backends.add(lower_info.backend()); + }); _scheduler = std::make_unique<ParallelScheduler>(backends); assert(noWaitingJobs()); @@ -101,7 +100,10 @@ void ParallelExecutor::executeImpl() VERBOSE(ParallelExecutor) << "INITIAL JOBS : " << _ready_jobs.size() << std::endl; - _subject.notifyModelBegin(this); + auto profiling_subg_index = _tracing_ctx->getSubgraphIndex(&_graph); + + subject.notifySubgraphBegin(profiling_subg_index); + while (true) { std::unique_lock<std::mutex> lock{_mu_jobs}; @@ -121,20 +123,24 
@@ void ParallelExecutor::executeImpl() lock.unlock(); - VERBOSE(ParallelExecutor) << "Assigning fn #" << job->index() << std::endl; + VERBOSE(ParallelExecutor) << "Assigning fn " << job->index() << std::endl; auto job_index = job->index(); - auto op_sequence_index = _job_to_op_seq[job_index]; - auto op_seq = &_lowered_graph->op_seqs().at(op_sequence_index); - auto backend = _lowered_graph->getLowerInfo()->op_seq.at(op_sequence_index)->backend(); - auto setup = [&, op_seq, backend]() { _subject.notifyJobBegin(this, op_seq, backend); }; - auto teardown = [&, job_index, op_seq, backend]() { - _subject.notifyJobEnd(this, op_seq, backend); + auto op_ind = _job_to_op[job_index]; + auto backend = _lowered_graph->lower_info().operation.at(op_ind).backend(); + auto setup = [&, op_ind, backend]() { + subject.notifyJobBegin(this, profiling_subg_index, op_ind, backend); + }; + auto teardown = [&, job_index, op_ind, backend]() { + subject.notifyJobEnd(this, profiling_subg_index, op_ind, backend); notify(job_index); }; + job->fn_seq()->initRunning(); + // dynamic tensor setting - bool handle_dynamic_tensor = op_seq->has_dynamic_tensor() || dynamic_input_exists; + bool handle_dynamic_tensor = + _lowered_graph->getHasDynamicTensor(op_ind) || dynamic_input_exists; job->fn_seq()->enableDynamicShapeInferer(handle_dynamic_tensor); _scheduler->assign(std::make_unique<HookFunction>(job->fn_seq(), setup, teardown), backend); @@ -145,7 +151,7 @@ void ParallelExecutor::executeImpl() // Wait for all the jobs done _scheduler->finish(); - _subject.notifyModelEnd(this); + subject.notifySubgraphEnd(profiling_subg_index); // Reset input info for the next execution _input_info = _initial_input_info; diff --git a/runtime/onert/core/src/exec/ParallelExecutor.h b/runtime/onert/core/src/exec/ParallelExecutor.h index 929edfce9..3162d865f 100644 --- a/runtime/onert/core/src/exec/ParallelExecutor.h +++ b/runtime/onert/core/src/exec/ParallelExecutor.h @@ -17,17 +17,12 @@ #ifndef 
__ONERT_EXEC_PARALLEL_EXECUTOR_H__ #define __ONERT_EXEC_PARALLEL_EXECUTOR_H__ -#include <list> -#include <queue> -#include <unordered_map> +#include "DataflowExecutor.h" +#include "ParallelScheduler.h" + +#include "util/TracingCtx.h" -#include "exec/FunctionSequence.h" -#include "Job.h" -#include "ir/OperandIndexSequence.h" -#include "ir/Index.h" #include <memory> -#include "exec/DataflowExecutor.h" -#include "ParallelScheduler.h" namespace onert { @@ -48,15 +43,14 @@ public: * * @param lowered_graph LoweredGraph object * @param tensor_builders Tensor builders that are currently used - * @param code_map OpSequence and its code map + * @param code_map @c ir::Operation and its code map */ ParallelExecutor(std::unique_ptr<compiler::LoweredGraph> lowered_graph, - const std::vector<std::shared_ptr<backend::ITensor>> &input_tensors, - const std::vector<std::shared_ptr<backend::ITensor>> &output_tensors, - const compiler::TensorRegistries &tensor_regs, - backend::TensorManagerSet &&tensor_mgrs, compiler::CodeMap &&code_map); + backend::BackendContexts &&backend_contexts, + const compiler::TensorRegistries &tensor_regs, compiler::CodeMap &&code_map, + const util::TracingCtx *tracing_ctx); - void executeImpl() override; + void executeImpl(const ExecutionObservee &subject) override; private: std::condition_variable _cv_jobs; diff --git a/runtime/onert/core/src/exec/ParallelScheduler.cc b/runtime/onert/core/src/exec/ParallelScheduler.cc index 70c9c3dd6..538945631 100644 --- a/runtime/onert/core/src/exec/ParallelScheduler.cc +++ b/runtime/onert/core/src/exec/ParallelScheduler.cc @@ -30,7 +30,7 @@ ParallelScheduler::ParallelScheduler(const BackendSet &backends) { assert(!backends.empty()); - for (auto backend : backends) + for (auto &&backend : backends) { _thread_pools[backend] = std::make_unique<ThreadPool>(); } @@ -45,7 +45,7 @@ void ParallelScheduler::assign(std::unique_ptr<IFunction> &&fn, const backend::B void ParallelScheduler::finish() { - for (auto &itr : 
_thread_pools) + for (auto &&itr : _thread_pools) { itr.second->finish(); } diff --git a/runtime/onert/core/src/exec/SingleModelExecutors.cc b/runtime/onert/core/src/exec/SingleModelExecutors.cc new file mode 100644 index 000000000..44c5e5742 --- /dev/null +++ b/runtime/onert/core/src/exec/SingleModelExecutors.cc @@ -0,0 +1,170 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "SingleModelExecutors.h" + +#include "EdgeTensor.h" +#include "IPermuteFunction.h" +#include "../backend/builtin/UserTensor.h" + +namespace onert +{ +namespace exec +{ + +void SingleModelExecutors::emplace(const ir::ModelIndex &, const ir::SubgraphIndex &subg_index, + std::unique_ptr<IExecutor> exec) +{ + _executors.emplace(subg_index, std::move(exec)); +} + +IExecutor *SingleModelExecutors::at(const ir::ModelIndex &, + const ir::SubgraphIndex &subg_index) const +{ + return _executors.at(subg_index).get(); +} + +uint32_t SingleModelExecutors::inputSize() const { return entryExecutor()->inputSize(); } + +uint32_t SingleModelExecutors::outputSize() const { return entryExecutor()->outputSize(); } + +const ir::OperandInfo &SingleModelExecutors::inputInfo(const ir::IOIndex &index) const +{ + return entryExecutor()->inputInfo(index.value()); +} + +const ir::OperandInfo &SingleModelExecutors::outputInfo(const ir::IOIndex &index) const +{ + return entryExecutor()->outputInfo(index.value()); +} + 
+void SingleModelExecutors::execute(const ExecutionContext &ctx) +{ + // UserTensor for Input/Output + std::vector<std::unique_ptr<backend::builtin::UserTensor>> tensorpool; + + // EdgeTensor for Input Quantization / Output Dequantization + std::vector<std::unique_ptr<EdgeTensor>> qtensorpool; + + // Input/Output Tensor vector for executor + std::vector<backend::IPortableTensor *> inputs(ctx.desc.inputs.size()); + std::vector<backend::IPortableTensor *> outputs(ctx.desc.outputs.size()); + + // Vector for input quantization I/O + std::vector<backend::ITensor *> input_tensors; + std::vector<backend::ITensor *> input_qtensors; + + // Vector for output dequantization I/O + std::vector<backend::ITensor *> output_qtensors; + std::vector<backend::ITensor *> output_tensors; + + // Prepare UserTensor and EdgeTensor for input quantization + for (uint32_t i = 0; i < inputs.size(); i++) + { + auto &desc = ctx.desc.inputs[i]; + + // Input is optional if buffer is nullptr, and optional input's size is 0 + if (desc->buffer == nullptr && (desc->size != 0 || desc->info.total_size() != 0)) + throw std::runtime_error{"Input " + std::to_string(i) + "'s buffer is not set."}; + + tensorpool.emplace_back(std::make_unique<backend::builtin::UserTensor>( + desc->info, desc->layout, const_cast<uint8_t *>(static_cast<const uint8_t *>(desc->buffer)), + desc->size)); + + auto user_type = desc->info.typeInfo().type(); + auto &model_info = entryExecutor()->inputInfo(i).typeInfo(); + auto model_type = model_info.type(); + if (user_type != model_type && user_type == ir::DataType::FLOAT32) + { + auto quantized_info = desc->info; + quantized_info.typeInfo(model_info); + qtensorpool.emplace_back( + std::make_unique<EdgeTensor>(quantized_info, entryExecutor()->inputLayout(i))); + qtensorpool.back()->allocate_buffer(); + + input_tensors.push_back(tensorpool.back().get()); + input_qtensors.push_back(qtensorpool.back().get()); + inputs[i] = qtensorpool.back().get(); + } + else + inputs[i] = 
tensorpool.back().get(); + } + + // Prepare UserTensor and EdgeTensor for output dequantization + for (uint32_t i = 0; i < outputs.size(); i++) + { + auto &desc = ctx.desc.outputs[i]; + + // Output is optional if buffer is nullptr, and optional output's size is 0 + if (desc->buffer == nullptr && (desc->size != 0 || desc->info.total_size() != 0)) + throw std::runtime_error{"Output " + std::to_string(i) + "'s buffer is not set."}; + + tensorpool.emplace_back(std::make_unique<backend::builtin::UserTensor>( + desc->info, desc->layout, static_cast<uint8_t *>(desc->buffer), desc->size)); + + auto user_type = desc->info.typeInfo().type(); + auto &model_info = entryExecutor()->outputInfo(i).typeInfo(); + auto model_type = model_info.type(); + if (user_type != model_type && user_type == ir::DataType::FLOAT32) + { + auto quantized_info = desc->info; + quantized_info.typeInfo(model_info); + qtensorpool.emplace_back( + std::make_unique<EdgeTensor>(quantized_info, entryExecutor()->outputLayout(i))); + qtensorpool.back()->allocate_buffer(); + + output_qtensors.push_back(qtensorpool.back().get()); + output_tensors.push_back(tensorpool.back().get()); + outputs[i] = qtensorpool.back().get(); + } + else + outputs[i] = tensorpool.back().get(); + } + + // Run quantization + if (input_tensors.size() > 0) + { + auto input_quantize_layer = PermuteLayer(input_tensors, input_qtensors); + input_quantize_layer.prepare(); + input_quantize_layer.run(); + } + + // Executor + entryExecutor()->execute(inputs, outputs, ctx.options); + + // Run dequantization + if (output_tensors.size() != 0) + { + auto output_dequantize_layer = PermuteLayer(output_qtensors, output_tensors); + output_dequantize_layer.prepare(); + output_dequantize_layer.run(); + } + + // Get dynamic shape inference result + for (uint32_t i = 0; i < outputs.size(); i++) + { + if (ctx.desc.outputs[i]->buffer == nullptr) + { + // Output is optional if buffer is nullptr + continue; + } + + 
ctx.desc.outputs[i]->info.shape(outputs[i]->getShape()); + } +} + +} // namespace exec +} // namespace onert diff --git a/runtime/onert/core/src/exec/SingleModelExecutors.h b/runtime/onert/core/src/exec/SingleModelExecutors.h new file mode 100644 index 000000000..66dce6077 --- /dev/null +++ b/runtime/onert/core/src/exec/SingleModelExecutors.h @@ -0,0 +1,70 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef __ONERT_EXEC_SINGLE_MODEL_EXECUTORS_H__ +#define __ONERT_EXEC_SINGLE_MODEL_EXECUTORS_H__ + +#include "exec/IExecutors.h" +#include "ir/NNPkg.h" + +namespace onert +{ +namespace exec +{ + +/** + * @brief Class to gather executor set for single model NN package + */ +class SingleModelExecutors : public IExecutors +{ +public: + /** + * @brief Construct a new SingleModelExecutors object + */ + SingleModelExecutors(void) = default; + SingleModelExecutors(const SingleModelExecutors &) = delete; + SingleModelExecutors(SingleModelExecutors &&) = default; + + /** + * @brief Destroy the SingleModelExecutors object + */ + ~SingleModelExecutors() = default; + +public: + void emplace(const ir::ModelIndex &model_index, const ir::SubgraphIndex &subg_index, + std::unique_ptr<IExecutor> exec) override; + + IExecutor *at(const ir::ModelIndex &model_index, + const ir::SubgraphIndex &subg_index) const override; + + uint32_t inputSize() const override; + + uint32_t outputSize() const override; + + const ir::OperandInfo &inputInfo(const ir::IOIndex &index) const override; + + const ir::OperandInfo &outputInfo(const ir::IOIndex &index) const override; + + void execute(const ExecutionContext &ctx) override; + +private: + std::unordered_map<ir::SubgraphIndex, std::unique_ptr<IExecutor>> _executors; +}; + +} // namespace exec +} // namespace onert + +#endif // __ONERT_EXEC_SINGLE_MODEL_EXECUTORS_H__ diff --git a/runtime/onert/core/src/exec/Sink.h b/runtime/onert/core/src/exec/Sink.h deleted file mode 100644 index 6a99efe60..000000000 --- a/runtime/onert/core/src/exec/Sink.h +++ /dev/null @@ -1,199 +0,0 @@ -/* - * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef __ONERT_EXEC_SINK_H__ -#define __ONERT_EXEC_SINK_H__ - -#include "feature/nchw/Reader.h" -#include "feature/nchw/View.h" -#include "feature/nhwc/Reader.h" -#include "feature/nhwc/View.h" - -#include <cassert> -#include <memory> -#include "util/Utils.h" -#include <misc/feature/IndexIterator.h> - -namespace onert -{ -namespace exec -{ -struct ISink -{ - virtual ~ISink() = default; - - virtual void pull(::onert::backend::ITensor &tensor) const = 0; -}; - -// Create second lever inheritance: the first lever is used as a reference type in use-case places -template <typename T> class ITemplSink : public ISink -{ -public: - ITemplSink(void *output_buffer, const size_t &output_size, const ir::Shape &shape, - const bool copy, ir::Layout io_layout) - : _output_buffer{reinterpret_cast<T *>(output_buffer)}, _output_size{output_size}, - _shape{shape}, _copy{copy}, _io_layout{io_layout} - { - } - -protected: - void pullUnif(onert::backend::ITensor &tensor) const - { - assert(((_io_layout == ir::Layout::NHWC && tensor.layout() == ir::Layout::NCHW) || - (_io_layout == ir::Layout::NCHW && tensor.layout() == ir::Layout::NHWC)) || - _copy); - auto input_buffer = tensor.buffer(); - auto rank = _shape.rank(); - - if (!tensor.has_padding() && rank < 4 + _copy) - { - memcpy(_output_buffer, input_buffer, _output_size); - return; - } - - switch (rank) - { - case 0: - case 1: - { - memcpy(_output_buffer, input_buffer, _output_size); - break; - } - case 2: - { - const int32_t copy_len = _shape.dim(1); - - for (auto i = 0; i < _shape.dim(0); ++i) - { - 
ir::Coordinates coords{i, 0}; - memcpy(_output_buffer + i * copy_len, input_buffer + tensor.calcOffset(coords), - copy_len * sizeof(T)); - } - break; - } - case 3: - { - const int32_t dim1 = _shape.dim(1); - const int32_t dim2 = _shape.dim(2); - - for (auto i = 0; i < _shape.dim(0); ++i) - { - for (auto j = 0; j < _shape.dim(1); ++j) - { - ir::Coordinates coords{i, j, 0}; - memcpy(_output_buffer + i * dim1 * dim2 + j * dim2, - input_buffer + tensor.calcOffset(coords), dim2 * sizeof(T)); - } - } - break; - } - case 4: - { - if (_copy) - { - const int32_t dim1 = _shape.dim(1); - const int32_t dim2 = _shape.dim(2); - const int32_t dim3 = _shape.dim(3); - - for (auto i = 0; i < _shape.dim(0); ++i) - { - for (auto j = 0; j < _shape.dim(1); ++j) - { - for (auto k = 0; k < _shape.dim(2); ++k) - { - ir::Coordinates coords{i, j, k, 0}; - memcpy(_output_buffer + i * dim1 * dim2 * dim3 + j * dim2 * dim3 + k * dim3, - input_buffer + tensor.calcOffset(coords), dim3 * sizeof(T)); - } - } - } - } - else - { - const auto shape = _shape.asFeature(_io_layout); - - if (_io_layout == ir::Layout::NHWC) - { - const exec::feature::nchw::Reader<T> from(&tensor); - exec::feature::nhwc::View<T> into(shape, _output_buffer, _output_size); - feature::iterate(shape) - << [&](uint32_t batch, uint32_t ch, uint32_t row, uint32_t col) { - const auto value = from.at(batch, ch, row, col); - into.at(batch, row, col, ch) = value; - }; - } - else if (_io_layout == ir::Layout::NCHW) - { - const exec::feature::nhwc::Reader<T> from(&tensor); - exec::feature::nchw::View<T> into(shape, _output_buffer, _output_size); - feature::iterate(shape) - << [&](uint32_t batch, uint32_t ch, uint32_t row, uint32_t col) { - const auto value = from.at(batch, row, col, ch); - into.at(batch, ch, row, col) = value; - }; - } - else - { - throw std::runtime_error("Wrong Layout"); - } - } - break; - } - default: - throw std::runtime_error("NYI: rank > 4"); - break; - } - } - -private: - T *_output_buffer; - const size_t 
_output_size; - const ir::Shape _shape; - const bool _copy; - const ir::Layout _io_layout; -}; - -template <typename T> class PermutateSink final : public ITemplSink<T> -{ -public: - PermutateSink(void *output_buffer, const size_t &output_size, const ir::Shape &shape, - ir::Layout io_layout) - : ITemplSink<T>(output_buffer, output_size, shape, false, io_layout) - { - } - -public: - void pull(onert::backend::ITensor &tensor) const override { ITemplSink<T>::pullUnif(tensor); } -}; - -// Only supports NHWC format front-end(NNAPI) now -template <typename T> class CopySink final : public ITemplSink<T> -{ -public: - CopySink(void *output_buffer, const size_t &output_size, const ir::Shape &shape, - ir::Layout io_layout = ir::Layout::UNKNOWN) - : ITemplSink<T>(output_buffer, output_size, shape, true, io_layout) - { - } - -public: - void pull(onert::backend::ITensor &tensor) const override { ITemplSink<T>::pullUnif(tensor); } -}; - -} // namespace exec -} // namespace onert - -#endif // __ONERT_EXEC_SINK_H__ diff --git a/runtime/onert/core/src/exec/Source.h b/runtime/onert/core/src/exec/Source.h deleted file mode 100644 index fb2be4dd8..000000000 --- a/runtime/onert/core/src/exec/Source.h +++ /dev/null @@ -1,208 +0,0 @@ -/* - * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef __ONERT_EXEC_SOURCE_H__ -#define __ONERT_EXEC_SOURCE_H__ - -#include "feature/IndexIterator.h" -#include "feature/nchw/Reader.h" -#include "feature/nchw/View.h" -#include "feature/nhwc/Reader.h" -#include "feature/nhwc/View.h" - -#include <cassert> -#include <memory> -#include "util/Utils.h" -#include <ir/Layout.h> -#include "ir/Shape.h" - -namespace onert -{ -namespace exec -{ - -struct ISource -{ - virtual ~ISource() = default; - - virtual void push(::onert::backend::ITensor &tensor) const = 0; -}; - -// Create second lever inheritance: the first lever is used as a reference type in use-case places -template <typename T> class ITemplSource : public ISource -{ -public: - ITemplSource(const void *input_buffer, const size_t &input_size, const ir::Shape &shape, - const bool copy, ir::Layout io_layout) - : _input_buffer{reinterpret_cast<const T *>(input_buffer)}, _input_size{input_size}, - _shape{shape}, _copy(copy), _io_layout{io_layout} - { - } - - virtual void push(::onert::backend::ITensor &tensor) const = 0; - -protected: - void pushUnif(onert::backend::ITensor &tensor) const - { - assert(((_io_layout == ir::Layout::NHWC && tensor.layout() == ir::Layout::NCHW) || - (_io_layout == ir::Layout::NCHW && tensor.layout() == ir::Layout::NHWC)) || - _copy); - auto output_buffer = tensor.buffer(); - auto rank = _shape.rank(); - - if (!tensor.has_padding() && rank < 4 + _copy) - { - memcpy(output_buffer, _input_buffer, _input_size); - return; - } - - switch (rank) - { - case 0: - case 1: - { - memcpy(output_buffer, _input_buffer, _input_size); - break; - } - case 2: - { - const int32_t copy_len = _shape.dim(1); - - for (auto i = 0; i < _shape.dim(0); ++i) - { - ir::Coordinates coords{i, 0}; - memcpy(output_buffer + tensor.calcOffset(coords), _input_buffer + i * copy_len, - copy_len * sizeof(T)); - } - break; - } - case 3: - { - const int32_t dim1 = _shape.dim(1); - const int32_t dim2 = _shape.dim(2); - - for (auto i = 0; i < _shape.dim(0); ++i) - { - for 
(auto j = 0; j < _shape.dim(1); ++j) - { - ir::Coordinates coords{i, j, 0}; - memcpy(output_buffer + tensor.calcOffset(coords), - _input_buffer + i * dim1 * dim2 + j * dim2, dim2 * sizeof(T)); - } - } - break; - } - case 4: - { - if (_copy) - { - const int32_t dim1 = _shape.dim(1); - const int32_t dim2 = _shape.dim(2); - const int32_t dim3 = _shape.dim(3); - for (auto i = 0; i < _shape.dim(0); ++i) - { - for (auto j = 0; j < _shape.dim(1); ++j) - { - for (auto k = 0; k < _shape.dim(2); ++k) - { - ir::Coordinates coords{i, j, k, 0}; - memcpy(output_buffer + tensor.calcOffset(coords), - _input_buffer + i * dim1 * dim2 * dim3 + j * dim2 * dim3 + k * dim3, - dim3 * sizeof(T)); - } - } - } - } - else - { - const auto shape = _shape.asFeature(_io_layout); - - if (_io_layout == ir::Layout::NCHW) - { - const exec::feature::nchw::Reader<T> from(shape, _input_buffer, _input_size); - exec::feature::nhwc::View<T> into(&tensor); - feature::iterate(shape) - << [&](uint32_t batch, uint32_t ch, uint32_t row, uint32_t col) { - const auto value = from.at(batch, ch, row, col); - into.at(batch, row, col, ch) = value; - }; - } - else if (_io_layout == ir::Layout::NHWC) - { - const exec::feature::nhwc::Reader<T> from(shape, _input_buffer, _input_size); - exec::feature::nchw::View<T> into(&tensor); - feature::iterate(shape) - << [&](uint32_t batch, uint32_t ch, uint32_t row, uint32_t col) { - const auto value = from.at(batch, row, col, ch); - into.at(batch, ch, row, col) = value; - }; - } - else - { - throw std::runtime_error("Wrong Layout"); - } - } - - break; - } - default: - throw std::runtime_error("NYI: rank > 4"); - break; - } - } - -private: - const T *_input_buffer; - const size_t _input_size; - const ir::Shape _shape; - const bool _copy; - const ir::Layout _io_layout; -}; - -template <typename T> class PermutateSource final : public ITemplSource<T> -{ -public: - PermutateSource(const void *input_buffer, const size_t &input_size, const ir::Shape &shape, - ir::Layout io_layout) - 
: ITemplSource<T>(input_buffer, input_size, shape, false, io_layout) - { - } - -public: - void push(onert::backend::ITensor &tensor) const override - { - // do NHWC_TO_NCHW or NCHW_TO_NHWC permutation - ITemplSource<T>::pushUnif(tensor); - } -}; - -template <typename T> class CopySource final : public ITemplSource<T> -{ -public: - CopySource(const void *input_buffer, const size_t &input_size, const ir::Shape &shape, - ir::Layout io_layout = ir::Layout::UNKNOWN) - : ITemplSource<T>(input_buffer, input_size, shape, true, io_layout) - { - } - -public: - void push(onert::backend::ITensor &tensor) const override { ITemplSource<T>::pushUnif(tensor); } -}; - -} // namespace exec -} // namespace onert - -#endif // __ONERT_EXEC_SOURCE_H__ diff --git a/runtime/onert/core/src/exec/ThreadPool.cc b/runtime/onert/core/src/exec/ThreadPool.cc index c8e0e3265..bf85e59f6 100644 --- a/runtime/onert/core/src/exec/ThreadPool.cc +++ b/runtime/onert/core/src/exec/ThreadPool.cc @@ -48,7 +48,7 @@ uint32_t ThreadPool::numJobsInQueue() { return _worker.numJobsInQueue(); } void ThreadPool::join() { - for (auto &thread : _threads) + for (auto &&thread : _threads) { thread.join(); } diff --git a/runtime/onert/core/src/exec/feature/MockTensor.test.h b/runtime/onert/core/src/exec/feature/MockTensor.test.h new file mode 100644 index 000000000..1d2d375e2 --- /dev/null +++ b/runtime/onert/core/src/exec/feature/MockTensor.test.h @@ -0,0 +1,66 @@ + +/* + * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/ITensor.h" + +template <typename T> class MockTensor : public onert::backend::ITensor +{ +public: + MockTensor<T>(onert::ir::Shape &shape, T *buf, onert::ir::Layout layout) + : _buf(reinterpret_cast<uint8_t *>(buf)), _shape(shape), _layout(layout) + { + } + +public: + uint8_t *buffer() const override { return _buf; } + + size_t calcOffset(const onert::ir::Coordinates &coords) const override + { + size_t rank = _shape.rank(); + rank = rank == 0 ? 1 : rank; + size_t offset = 0; + for (size_t i = 0; i < rank; ++i) + { + auto dim = _shape.rank() == 0 ? 1 : _shape.dim(i); + offset = offset * dim + coords[i]; + } + offset *= sizeof(T); + + return offset; + } + + onert::ir::Shape getShape() const override { return _shape; } + +public: // DUMMY methods + size_t total_size() const override { return 0; } + onert::ir::Layout layout() const override { return _layout; } + onert::ir::DataType data_type() const override { return onert::ir::DataType::UINT8; } + float data_scale() const override { return 0; } + int32_t data_zero_point() const override { return 0; } + const std::vector<float> &data_scales() const override { return _dummy_scales; } + const std::vector<int32_t> &data_zero_points() const override { return _dummy_zerops; } + bool has_padding() const override { return false; } + void access(const std::function<void(ITensor &tensor)> &fn) override {} + bool is_dynamic() const override { return false; } + +private: + uint8_t *_buf = nullptr; + onert::ir::Shape _shape; + onert::ir::Layout _layout = onert::ir::Layout::UNKNOWN; + std::vector<float> _dummy_scales; + std::vector<int32_t> _dummy_zerops; +}; diff --git a/runtime/onert/core/src/exec/feature/nchw/Reader.h b/runtime/onert/core/src/exec/feature/nchw/Reader.h index 7be9df4d5..e1a963cbd 100644 --- a/runtime/onert/core/src/exec/feature/nchw/Reader.h +++ 
b/runtime/onert/core/src/exec/feature/nchw/Reader.h @@ -36,35 +36,35 @@ namespace nchw template <typename T> class Reader : public feature::Reader<T> { public: - // Construct for buffer of model inputs - Reader(const ir::FeatureShape &shape, const T *ptr, size_t len) - : _shape{shape}, _ptr{reinterpret_cast<const uint8_t *>(ptr)}, _len{len} + using Strides = ir::FeatureShape; + // Construct for buffer and strides + Reader(const ir::FeatureShape &shape, const Strides &strides, const T *ptr, size_t len) + : _shape{shape}, _strides{strides}, _ptr{reinterpret_cast<const uint8_t *>(ptr)}, _len{len} { - assert(shape.N * shape.C * shape.H * shape.W * sizeof(T) == len); - - // No padding - _strides.W = sizeof(T); - _strides.H = shape.W * sizeof(T); - _strides.C = shape.W * shape.H * sizeof(T); - _strides.N = shape.W * shape.H * shape.C * sizeof(T); + UNUSED_RELEASE(len); // Workaround for unused variable in release mode + assert(len == static_cast<size_t>(strides.N != 0 ? shape.N * strides.N + : strides.C != 0 ? shape.C * strides.C + : strides.H != 0 ? shape.H * strides.H + : shape.W * strides.W)); } // Construct for backend tensor Reader(backend::ITensor *tensor) - : _ptr{tensor->buffer() + tensor->calcOffset({0, 0, 0, 0})}, _len{tensor->total_size()} + : _ptr{tensor->buffer() + tensor->calcOffset({0, 0, 0, 0})}, _len{tensor->total_size()} { assert(tensor->layout() == ir::Layout::NCHW); const auto start_offset = tensor->calcOffset({0, 0, 0, 0}); - _strides.W = tensor->dimension(3) == 1 ? 0 : tensor->calcOffset({0, 0, 0, 1}) - start_offset; - _strides.H = tensor->dimension(2) == 1 ? 0 : tensor->calcOffset({0, 0, 1, 0}) - start_offset; - _strides.C = tensor->dimension(1) == 1 ? 0 : tensor->calcOffset({0, 1, 0, 0}) - start_offset; - _strides.N = tensor->dimension(0) == 1 ? 
0 : tensor->calcOffset({1, 0, 0, 0}) - start_offset; - - _shape.W = tensor->dimension(3); - _shape.H = tensor->dimension(2); - _shape.C = tensor->dimension(1); - _shape.N = tensor->dimension(0); + auto shape = tensor->getShape(); + _strides.W = shape.dim(3) == 1 ? 0 : tensor->calcOffset({0, 0, 0, 1}) - start_offset; + _strides.H = shape.dim(2) == 1 ? 0 : tensor->calcOffset({0, 0, 1, 0}) - start_offset; + _strides.C = shape.dim(1) == 1 ? 0 : tensor->calcOffset({0, 1, 0, 0}) - start_offset; + _strides.N = shape.dim(0) == 1 ? 0 : tensor->calcOffset({1, 0, 0, 0}) - start_offset; + + _shape.W = shape.dim(3); + _shape.H = shape.dim(2); + _shape.C = shape.dim(1); + _shape.N = shape.dim(0); } public: @@ -104,7 +104,6 @@ private: private: // TODO Remove _shape ir::FeatureShape _shape; - using Strides = ir::FeatureShape; Strides _strides; const uint8_t *_ptr; size_t _len; diff --git a/runtime/onert/core/src/exec/feature/nchw/Reader.test.cc b/runtime/onert/core/src/exec/feature/nchw/Reader.test.cc new file mode 100644 index 000000000..c405190f7 --- /dev/null +++ b/runtime/onert/core/src/exec/feature/nchw/Reader.test.cc @@ -0,0 +1,85 @@ +/* + * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "Reader.h" + +#include "../MockTensor.test.h" + +#include <gtest/gtest.h> + +using namespace onert::exec::feature; + +template <typename T> class Reader_nchw : public testing::Test +{ +public: + void setData(std::initializer_list<T> list) { _data = std::make_shared<std::vector<T>>(list); } + + void setShape(int32_t batch, int32_t depth, int32_t height, int32_t width) + { + _shape = onert::ir::FeatureShape(batch, depth, height, width); + } + + void setStride(int32_t batch, int32_t depth, int32_t height, int32_t width) + { + auto elem_size = sizeof(T); + _stride = onert::ir::FeatureShape(batch * elem_size, depth * elem_size, height * elem_size, + width * elem_size); + } + + void createReader() + { + _reader = + std::make_shared<nchw::Reader<T>>(_shape, _stride, _data->data(), _data->size() * sizeof(T)); + } + + void createUsingMockTensor() + { + onert::ir::Shape shape = {_shape.N, _shape.H, _shape.W, _shape.C}; + _tensor = std::make_shared<MockTensor<T>>(shape, _data->data(), onert::ir::Layout::NCHW); + _reader = std::make_shared<nchw::Reader<T>>(_tensor.get()); + } + + std::shared_ptr<Reader<T>> _reader = nullptr; + +private: + std::shared_ptr<std::vector<T>> _data = nullptr; + onert::ir::FeatureShape _shape; + onert::ir::FeatureShape _stride; + std::shared_ptr<MockTensor<T>> _tensor = nullptr; +}; + +using ReaderTypes = ::testing::Types<float, int32_t, uint8_t, int8_t, int16_t>; +TYPED_TEST_SUITE(Reader_nchw, ReaderTypes); + +TYPED_TEST(Reader_nchw, basic_reader) +{ + this->setData({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}); + this->setShape(1, 2, 3, 2); + this->setStride(12, 6, 2, 1); + this->createReader(); + + // Data: NCHW + // Shape: NCHW + ASSERT_EQ(this->_reader->at(0, 1, 1, 0), 8); + ASSERT_EQ(this->_reader->at(1, 1, 0), 8); + + // Data: NCHW + // Shape: NCHW + this->createUsingMockTensor(); + + ASSERT_EQ(this->_reader->at(0, 1, 1, 0), 6); + ASSERT_EQ(this->_reader->at(1, 1, 0), 6); +} diff --git 
a/runtime/onert/core/src/exec/feature/nchw/View.h b/runtime/onert/core/src/exec/feature/nchw/View.h index dbaf1a91e..cdbb0cd7c 100644 --- a/runtime/onert/core/src/exec/feature/nchw/View.h +++ b/runtime/onert/core/src/exec/feature/nchw/View.h @@ -37,8 +37,10 @@ namespace nchw template <typename T> class View final : public Reader<T> { public: + using Strides = typename Reader<T>::Strides; // Construct for buffer of model inputs - View(const ir::FeatureShape &shape, T *ptr, size_t len) : Reader<T>{shape, ptr, len} + View(const ir::FeatureShape &shape, const Strides &strides, T *ptr, size_t len) + : Reader<T>{shape, strides, ptr, len} { // DO NOTHING } diff --git a/runtime/onert/core/src/exec/feature/nchw/View.test.cc b/runtime/onert/core/src/exec/feature/nchw/View.test.cc new file mode 100644 index 000000000..d21a8b784 --- /dev/null +++ b/runtime/onert/core/src/exec/feature/nchw/View.test.cc @@ -0,0 +1,85 @@ +/* + * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "View.h" + +#include "../MockTensor.test.h" + +#include <gtest/gtest.h> + +using namespace onert::exec::feature; + +template <typename T> class View_nchw : public testing::Test +{ +public: + void setData(std::initializer_list<T> list) { _data = std::make_shared<std::vector<T>>(list); } + + void setShape(int32_t batch, int32_t depth, int32_t height, int32_t width) + { + _shape = onert::ir::FeatureShape(batch, depth, height, width); + } + + void setStride(int32_t batch, int32_t depth, int32_t height, int32_t width) + { + auto elem_size = sizeof(T); + _stride = onert::ir::FeatureShape(batch * elem_size, depth * elem_size, height * elem_size, + width * elem_size); + } + + void createView() + { + _view = + std::make_shared<nchw::View<T>>(_shape, _stride, _data->data(), _data->size() * sizeof(T)); + } + + void createUsingMockTensor() + { + onert::ir::Shape shape = {_shape.N, _shape.H, _shape.W, _shape.C}; + _tensor = std::make_shared<MockTensor<T>>(shape, _data->data(), onert::ir::Layout::NCHW); + _view = std::make_shared<nchw::View<T>>(_tensor.get()); + } + + std::shared_ptr<nchw::View<T>> _view = nullptr; + +private: + std::shared_ptr<std::vector<T>> _data = nullptr; + onert::ir::FeatureShape _shape; + onert::ir::FeatureShape _stride; + std::shared_ptr<MockTensor<T>> _tensor = nullptr; +}; + +using ViewTypes = ::testing::Types<float, int32_t, uint8_t, int8_t, int16_t>; +TYPED_TEST_SUITE(View_nchw, ViewTypes); + +TYPED_TEST(View_nchw, basic_view) +{ + this->setData({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}); + this->setShape(1, 2, 3, 2); + this->setStride(12, 6, 2, 1); + this->createView(); + + // Data: NCHW + // Shape: NCHW + ASSERT_EQ(this->_view->at(0, 1, 1, 0), 8); + ASSERT_EQ(this->_view->at(1, 1, 0), 8); + + // Data: NCHW + // Shape: NCHW + this->createUsingMockTensor(); + + ASSERT_EQ(this->_view->at(0, 1, 1, 0), 6); + ASSERT_EQ(this->_view->at(1, 1, 0), 6); +} diff --git a/runtime/onert/core/src/exec/feature/nhwc/Reader.h 
b/runtime/onert/core/src/exec/feature/nhwc/Reader.h index 7730cee72..3e3c431bf 100644 --- a/runtime/onert/core/src/exec/feature/nhwc/Reader.h +++ b/runtime/onert/core/src/exec/feature/nhwc/Reader.h @@ -37,36 +37,35 @@ namespace nhwc template <typename T> class Reader : public feature::Reader<T> { public: - // Construct for buffer of model inputs - Reader(const ir::FeatureShape &shape, const T *ptr, size_t len) - : _shape{shape}, _ptr{reinterpret_cast<const uint8_t *>(ptr)}, _len{len} + using Strides = ir::FeatureShape; + // Construct for buffer and strides + Reader(const ir::FeatureShape &shape, const Strides &strides, const T *ptr, size_t len) + : _shape{shape}, _strides{strides}, _ptr{reinterpret_cast<const uint8_t *>(ptr)}, _len{len} { UNUSED_RELEASE(len); // Workaround for unused variable in release mode - assert(shape.N * shape.C * shape.H * shape.W * sizeof(T) == len); - - // No padding - _strides.C = sizeof(T); - _strides.W = shape.C * sizeof(T); - _strides.H = shape.C * shape.W * sizeof(T); - _strides.N = shape.C * shape.W * shape.H * sizeof(T); + assert(len == static_cast<size_t>(strides.N != 0 ? shape.N * strides.N + : strides.H != 0 ? shape.H * strides.H + : strides.W != 0 ? shape.W * strides.W + : shape.C * strides.C)); } // Construct for backend tensor Reader(const backend::ITensor *tensor) - : _ptr{tensor->buffer() + tensor->calcOffset({0, 0, 0, 0})}, _len{tensor->total_size()} + : _ptr{tensor->buffer() + tensor->calcOffset({0, 0, 0, 0})}, _len{tensor->total_size()} { assert(tensor->layout() == ir::Layout::NHWC); const auto start_offset = tensor->calcOffset({0, 0, 0, 0}); - _strides.C = tensor->dimension(3) == 1 ? 0 : tensor->calcOffset({0, 0, 0, 1}) - start_offset; - _strides.W = tensor->dimension(2) == 1 ? 0 : tensor->calcOffset({0, 0, 1, 0}) - start_offset; - _strides.H = tensor->dimension(1) == 1 ? 0 : tensor->calcOffset({0, 1, 0, 0}) - start_offset; - _strides.N = tensor->dimension(0) == 1 ? 
0 : tensor->calcOffset({1, 0, 0, 0}) - start_offset; - - _shape.C = tensor->dimension(3); - _shape.W = tensor->dimension(2); - _shape.H = tensor->dimension(1); - _shape.N = tensor->dimension(0); + auto shape = tensor->getShape(); + _strides.C = shape.dim(3) == 1 ? 0 : tensor->calcOffset({0, 0, 0, 1}) - start_offset; + _strides.W = shape.dim(2) == 1 ? 0 : tensor->calcOffset({0, 0, 1, 0}) - start_offset; + _strides.H = shape.dim(1) == 1 ? 0 : tensor->calcOffset({0, 1, 0, 0}) - start_offset; + _strides.N = shape.dim(0) == 1 ? 0 : tensor->calcOffset({1, 0, 0, 0}) - start_offset; + + _shape.C = shape.dim(3); + _shape.W = shape.dim(2); + _shape.H = shape.dim(1); + _shape.N = shape.dim(0); } public: @@ -106,7 +105,6 @@ private: private: // TODO Remove _shape ir::FeatureShape _shape; - using Strides = ir::FeatureShape; Strides _strides; const uint8_t *_ptr; size_t _len; diff --git a/runtime/onert/core/src/exec/feature/nhwc/Reader.test.cc b/runtime/onert/core/src/exec/feature/nhwc/Reader.test.cc new file mode 100644 index 000000000..1f3a4dd06 --- /dev/null +++ b/runtime/onert/core/src/exec/feature/nhwc/Reader.test.cc @@ -0,0 +1,86 @@ +/* + * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "Reader.h" + +#include "../MockTensor.test.h" + +#include <gtest/gtest.h> + +using namespace onert::exec::feature; + +template <typename T> class Reader_nhwc : public testing::Test +{ +public: + void setData(std::initializer_list<T> list) { _data = std::make_shared<std::vector<T>>(list); } + + void setShape(int32_t batch, int32_t depth, int32_t height, int32_t width) + { + _shape = onert::ir::FeatureShape(batch, depth, height, width); + } + + void setStride(int32_t batch, int32_t depth, int32_t height, int32_t width) + { + auto elem_size = sizeof(T); + _stride = onert::ir::FeatureShape(batch * elem_size, depth * elem_size, height * elem_size, + width * elem_size); + } + + void createReader() + { + _reader = + std::make_shared<nhwc::Reader<T>>(_shape, _stride, _data->data(), _data->size() * sizeof(T)); + } + + void createUsingMockTensor() + { + onert::ir::Shape shape = {_shape.N, _shape.H, _shape.W, _shape.C}; + _tensor = std::make_shared<MockTensor<T>>(shape, _data->data(), onert::ir::Layout::NHWC); + _reader = std::make_shared<nhwc::Reader<T>>(_tensor.get()); + } + + std::shared_ptr<nhwc::Reader<T>> _reader = nullptr; + +private: + std::shared_ptr<std::vector<T>> _data = nullptr; + onert::ir::FeatureShape _shape; + onert::ir::FeatureShape _stride; + std::shared_ptr<MockTensor<T>> _tensor = nullptr; +}; + +using ReaderTypes = ::testing::Types<float, int32_t, uint8_t, int8_t, int16_t>; +TYPED_TEST_SUITE(Reader_nhwc, ReaderTypes); +TYPED_TEST_SUITE(MockTensorReader_nhwc, ReaderTypes); + +TYPED_TEST(Reader_nhwc, basic_reader) +{ + this->setData({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}); + this->setShape(1, 2, 3, 2); + this->setStride(12, 1, 6, 2); + this->createReader(); + + // Data: NCHW + // Shape: NHWC + ASSERT_EQ(this->_reader->at(0, 1, 1, 0), 8); + ASSERT_EQ(this->_reader->at(1, 1, 0), 8); + + // Data: NHWC + // Shape: NHWC + this->createUsingMockTensor(); + + ASSERT_EQ(this->_reader->at(0, 1, 1, 0), 6); + ASSERT_EQ(this->_reader->at(1, 1, 0), 6); 
+} diff --git a/runtime/onert/core/src/exec/feature/nhwc/View.h b/runtime/onert/core/src/exec/feature/nhwc/View.h index 72c8c3415..c98d050c3 100644 --- a/runtime/onert/core/src/exec/feature/nhwc/View.h +++ b/runtime/onert/core/src/exec/feature/nhwc/View.h @@ -17,7 +17,7 @@ #ifndef __ONERT_EXEC_FEATURE_NHWC_VIEW_H__ #define __ONERT_EXEC_FEATURE_NHWC_VIEW_H__ -#include "../Reader.h" +#include "Reader.h" #include <cassert> #include <cstddef> @@ -38,8 +38,10 @@ namespace nhwc template <typename T> class View final : public Reader<T> { public: - // Construct for buffer of model inputs - View(const ir::FeatureShape &shape, T *ptr, size_t len) : Reader<T>{shape, ptr, len} + using Strides = typename Reader<T>::Strides; + // Construct for buffer and strides + View(const ir::FeatureShape &shape, const Strides &strides, T *ptr, size_t len) + : Reader<T>{shape, strides, ptr, len} { // DO NOTHING } diff --git a/runtime/onert/core/src/exec/feature/nhwc/View.test.cc b/runtime/onert/core/src/exec/feature/nhwc/View.test.cc new file mode 100644 index 000000000..c9018660c --- /dev/null +++ b/runtime/onert/core/src/exec/feature/nhwc/View.test.cc @@ -0,0 +1,86 @@ +/* + * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "View.h" + +#include "../MockTensor.test.h" + +#include <gtest/gtest.h> + +using namespace onert::exec::feature; + +template <typename T> class View_nhwc : public testing::Test +{ +public: + void setData(std::initializer_list<T> list) { _data = std::make_shared<std::vector<T>>(list); } + + void setShape(int32_t batch, int32_t depth, int32_t height, int32_t width) + { + _shape = onert::ir::FeatureShape(batch, depth, height, width); + } + + void setStride(int32_t batch, int32_t depth, int32_t height, int32_t width) + { + auto elem_size = sizeof(T); + _stride = onert::ir::FeatureShape(batch * elem_size, depth * elem_size, height * elem_size, + width * elem_size); + } + + void createView() + { + _view = + std::make_shared<nhwc::View<T>>(_shape, _stride, _data->data(), _data->size() * sizeof(T)); + } + + void createUsingMockTensor() + { + onert::ir::Shape shape = {_shape.N, _shape.H, _shape.W, _shape.C}; + _tensor = std::make_shared<MockTensor<T>>(shape, _data->data(), onert::ir::Layout::NHWC); + _view = std::make_shared<nhwc::View<T>>(_tensor.get()); + } + + std::shared_ptr<nhwc::View<T>> _view = nullptr; + +private: + std::shared_ptr<std::vector<T>> _data = nullptr; + onert::ir::FeatureShape _shape; + onert::ir::FeatureShape _stride; + std::shared_ptr<MockTensor<T>> _tensor = nullptr; +}; + +using ViewTypes = ::testing::Types<float, int32_t, uint8_t, int8_t, int16_t>; +TYPED_TEST_SUITE(View_nhwc, ViewTypes); +TYPED_TEST_SUITE(MockTensorView_nhwc, ViewTypes); + +TYPED_TEST(View_nhwc, basic_view) +{ + this->setData({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}); + this->setShape(1, 2, 3, 2); + this->setStride(12, 1, 6, 2); + this->createView(); + + // Data: NCHW + // Shape: NHWC + ASSERT_EQ(this->_view->at(0, 1, 1, 0), 8); + ASSERT_EQ(this->_view->at(1, 1, 0), 8); + + // Data: NHWC + // Shape: NHWC + this->createUsingMockTensor(); + + ASSERT_EQ(this->_view->at(0, 1, 1, 0), 6); + ASSERT_EQ(this->_view->at(1, 1, 0), 6); +} diff --git 
a/runtime/onert/core/src/exec/train/TrainableExecutor.cc b/runtime/onert/core/src/exec/train/TrainableExecutor.cc new file mode 100644 index 000000000..5d7c4f3f7 --- /dev/null +++ b/runtime/onert/core/src/exec/train/TrainableExecutor.cc @@ -0,0 +1,225 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "TrainableExecutor.h" +#ifdef RUY_PROFILER +#include "ruy/profiler/instrumentation.h" +#endif + +#include <misc/polymorphic_downcast.h> + +namespace onert +{ +namespace exec +{ +namespace train +{ + +TrainableExecutor::TrainableExecutor( + std::unique_ptr<compiler::train::LoweredTrainableGraph> lowered_graph, + backend::train::TrainableBackendContexts &&backend_contexts, + const compiler::train::TensorRegistries &tensor_regs, + compiler::train::TrainableCodeMap &&code_map, + const std::vector<ir::OperationIndex> &forward_order, + const std::vector<ir::OperationIndex> &backward_order, const util::TracingCtx *tracing_ctx, + const ir::train::LossInfo &loss_info) + : _code_map{std::move(code_map)}, _forward_order{std::move(forward_order)}, + _backward_order{std::move(backward_order)}, _lowered_graph{std::move(lowered_graph)}, + _backend_contexts{std::move(backend_contexts)}, + _trainable_graph{_lowered_graph->trainable_graph()}, _tensor_regs{std::move(tensor_regs)}, + _mutex(), _tracing_ctx(tracing_ctx), _loss_info(loss_info) +{ + auto build_tensor_list = [&](const auto 
&ind_seq, auto &tensors) { + assert(tensors.empty()); + for (auto &&ind : ind_seq) + { + backend::ITensor *tensor = _tensor_regs.getITensor(ind); + assert(tensor != nullptr); + auto io_tensor = nnfw::misc::polymorphic_downcast<backend::builtin::IOTensor *>(tensor); + tensors.push_back(io_tensor); + } + }; + build_tensor_list(_trainable_graph.getInputs(), _input_tensors); + build_tensor_list(_trainable_graph.getOutputs(), _output_tensors); +} + +void TrainableExecutor::forward(const std::vector<backend::IPortableTensor *> &inputs, + const std::vector<backend::IPortableTensor *> &outputs, + const ExecutionOptions &options, bool training) +{ + // For thread-safe, use mutex + // TODO: if all used backends on this executor are thread-safe, + // do not need to use mutex (otherwise, use mutex) + std::lock_guard<std::mutex> lock(_mutex); + _current_options = options; + + assert(_input_tensors.size() == inputs.size()); + for (uint32_t i = 0; i < _input_tensors.size(); ++i) + { + auto tensor = _input_tensors[i]; + const auto input = inputs[i]; + assert(input->buffer() != nullptr || input->get_info().total_size() == 0); + assert(tensor != nullptr); + tensor->setTensor(input); + } + + // Set output(s) + assert(_output_tensors.size() == outputs.size()); + for (uint32_t i = 0; i < _output_tensors.size(); ++i) + { + auto tensor = _output_tensors[i]; + const auto output = outputs[i]; + // Output may not be used on training, so don't check optional + assert(tensor != nullptr); + tensor->setTensor(output); + } + + // Create observee + ExecutionObservee subject(_observers, options); + + forwardImpl(subject, training); + + // TODO Update output(s) desc if desc has dynamic input +} + +void TrainableExecutor::forwardImpl(const ExecutionObservee &subject, bool training) +{ + if (!subject.isEmpty() && _tracing_ctx) + { + auto profiling_subg_index = _tracing_ctx->getSubgraphIndex(&_trainable_graph.graph()); + + subject.notifySubgraphBegin(profiling_subg_index); + for (auto &&index : 
_forward_order) + { + const auto &code = _code_map.at(index); + const auto backend = code.lower_info->backend(); +// TODO : Move ruy profiler into ExecutionObserver +#ifdef RUY_PROFILER + ruy::profiler::ScopeLabel label(code.op->name()); +#endif + subject.notifyJobBegin(this, profiling_subg_index, code.op_ind, backend); + + auto &tn_seq = code.tn_seq; + tn_seq->forward(training && code.op->isRequiredForBackward()); + + subject.notifyJobEnd(this, profiling_subg_index, code.op_ind, backend); + } + subject.notifySubgraphEnd(profiling_subg_index); + } + else + { + for (auto &&index : _forward_order) + { + const auto &code = _code_map.at(index); +// TODO : Move ruy profiler into ExecutionObserver +#ifdef RUY_PROFILER + ruy::profiler::ScopeLabel label(code.op->name()); +#endif + auto &tn_seq = code.tn_seq; + tn_seq->forward(training && code.op->isRequiredForBackward()); + } + } +} + +void TrainableExecutor::backward(const ExecutionOptions &options, uint32_t training_step) +{ + // For thread-safe, use mutex + // TODO: if all used backends on this executor are thread-safe, + // do not need to use mutex (otherwise, use mutex) + std::lock_guard<std::mutex> lock(_mutex); + _current_options = options; + + // Create observee + ExecutionObservee subject(_observers, options); + + backwardImpl(subject, training_step); +} + +void TrainableExecutor::backwardImpl(const ExecutionObservee &subject, uint32_t training_step) +{ + if (!subject.isEmpty() && _tracing_ctx) + { + auto profiling_subg_index = _tracing_ctx->getSubgraphIndex(&_trainable_graph.graph()); + + subject.notifySubgraphBegin(profiling_subg_index); + for (auto &&index : _backward_order) + { + const auto &code = _code_map.at(index); + if (!code.op->isRequiredForBackward()) + { + continue; + } + const auto backend = code.lower_info->backend(); +// TODO : Move ruy profiler into ExecutionObserver +#ifdef RUY_PROFILER + ruy::profiler::ScopeLabel label(code.op->name()); +#endif + subject.notifyJobBegin(this, 
profiling_subg_index, code.op_ind, backend); + + auto &tn_seq = code.tn_seq; + tn_seq->backward(training_step, code.op->isWeightsUpdateEnabled()); + + subject.notifyJobEnd(this, profiling_subg_index, code.op_ind, backend); + } + subject.notifySubgraphEnd(profiling_subg_index); + } + else + { + for (auto &&index : _backward_order) + { + const auto &code = _code_map.at(index); + if (!code.op->isRequiredForBackward()) + { + continue; + } +// TODO : Move ruy profiler into ExecutionObserver +#ifdef RUY_PROFILER + ruy::profiler::ScopeLabel label(code.op->name()); +#endif + auto &tn_seq = code.tn_seq; + tn_seq->backward(training_step, code.op->isWeightsUpdateEnabled()); + } + } +} + +float TrainableExecutor::getLoss(const ir::IOIndex &pred_io_ind) const +{ + const auto &loss_ind = _trainable_graph.getLossIndex(pred_io_ind); + if (loss_ind.undefined()) + throw std::runtime_error{"Loss " + std::to_string(loss_ind.value()) + " is not defined."}; + backend::ITensor *tensor = _tensor_regs.getITensor(loss_ind); + long double sum = 0; + for (uint64_t i = 0; i < tensor->getShape().num_elements(); ++i) + { + sum += reinterpret_cast<float *>(tensor->buffer())[i]; + } + if (_loss_info.reduction_type == ir::train::LossReductionType::SumOverBatchSize) + { + sum /= tensor->getShape().num_elements(); + } + return static_cast<float>(sum); +} + +void TrainableExecutor::iterateTrainableTensors( + const std::function<void(const ir::OperandIndex &, const backend::train::ITrainableTensor *)> &fn) + const +{ + _tensor_regs.iterateTrainableTensors(fn); +} + +} // namespace train +} // namespace exec +} // namespace onert diff --git a/runtime/onert/core/src/exec/train/TrainableExecutor.h b/runtime/onert/core/src/exec/train/TrainableExecutor.h new file mode 100644 index 000000000..986c2236c --- /dev/null +++ b/runtime/onert/core/src/exec/train/TrainableExecutor.h @@ -0,0 +1,143 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. 
All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __ONERT_EXEC_TRAIN_TRAINABLE_EXECUTOR_H_ +#define __ONERT_EXEC_TRAIN_TRAINABLE_EXECUTOR_H_ + +#include "exec/IExecutor.h" + +#include "../ExecutionObservee.h" +#include "../../compiler/train/TensorRegistries.h" + +#include "backend/train/TrainableBackendContext.h" +#include "compiler/train/TrainableCodeMap.h" +#include "compiler/train/LoweredTrainableGraph.h" +#include "ir/train/LossInfo.h" +#include "ir/Index.h" +#include "util/TracingCtx.h" + +namespace onert +{ +namespace exec +{ +namespace train +{ + +class TrainableExecutor : public IExecutor +{ +public: + /** + * @brief Construct a new TrainableExecutor object + * @param lowered_graph LoweredTrainableGraph object + * @param tensor_builders Tensor builders that are currently used + * @param code_map @c ir::Operation and its code map + */ + TrainableExecutor(std::unique_ptr<compiler::train::LoweredTrainableGraph> lowered_graph, + backend::train::TrainableBackendContexts &&backend_contexts, + const compiler::train::TensorRegistries &tensor_regs, + compiler::train::TrainableCodeMap &&code_map, + const std::vector<ir::OperationIndex> &forward_order, + const std::vector<ir::OperationIndex> &backward_order, + const util::TracingCtx *tracing_ctx, const ir::train::LossInfo &training_info); + +public: + const ir::Graph &graph() const final { return _trainable_graph.graph(); } + + void execute(const 
std::vector<backend::IPortableTensor *> &inputs, + const std::vector<backend::IPortableTensor *> &outputs, + const ExecutionOptions &options) override + { + forward(inputs, outputs, options, false); + } + + uint32_t inputSize() const override { return _input_tensors.size(); } + + uint32_t outputSize() const override { return _output_tensors.size(); } + + const ir::OperandInfo &inputInfo(uint32_t index) const override + { + return _input_tensors[index]->get_info(); + } + + const ir::OperandInfo &outputInfo(uint32_t index) const override + { + return _output_tensors[index]->get_info(); + } + + ir::Layout inputLayout(uint32_t index) const override { return _input_tensors[index]->layout(); } + + ir::Layout outputLayout(uint32_t index) const override + { + return _output_tensors[index]->layout(); + } + + void forward(const std::vector<backend::IPortableTensor *> &inputs, + const std::vector<backend::IPortableTensor *> &outputs, + const ExecutionOptions &options, bool training); + void backward(const ExecutionOptions &options, uint32_t training_step); + + // Used only in Dataflow and Parallel Executors + void setIndexedRanks(std::shared_ptr<ir::OperationIndexMap<int64_t>> ranks) final + { + _indexed_ranks = std::move(ranks); + }; + + void addObserver(std::unique_ptr<IExecutionObserver> ref) { _observers.add(std::move(ref)); }; + + float getLoss(const ir::IOIndex &pred_io_ind) const; + + void iterateTrainableTensors( + const std::function<void(const ir::OperandIndex &, const backend::train::ITrainableTensor *)> + &fn) const; + + backend::train::TrainableBackendContexts &getBackendContexts() { return _backend_contexts; } + + const ExecutionOptions ¤tOptions() const override { return _current_options; } + +private: + void forwardImpl(const ExecutionObservee &subject, bool training); + void backwardImpl(const ExecutionObservee &subject, uint32_t training_step); + +private: + compiler::train::TrainableCodeMap _code_map; + std::vector<ir::OperationIndex> _forward_order; + 
std::vector<ir::OperationIndex> _backward_order; + ExecObservers _observers; + std::shared_ptr<ir::OperationIndexMap<int64_t>> _indexed_ranks; + std::unique_ptr<compiler::train::LoweredTrainableGraph> _lowered_graph; + backend::train::TrainableBackendContexts _backend_contexts; + const ir::train::TrainableGraph &_trainable_graph; + compiler::train::TensorRegistries _tensor_regs; + std::vector<backend::builtin::IOTensor *> _input_tensors; + std::vector<backend::builtin::IOTensor *> _output_tensors; + std::mutex _mutex; + const util::TracingCtx *_tracing_ctx; + const ir::train::LossInfo _loss_info; + /** + * It is set by execute() method only in thread-safe environment. + * It is used for non-primary executor call on builtin backend + * and accessed by entryExecutor's currentOptions() method. + * + * TODO: Find better way to pass config to non-primary executor + */ + ExecutionOptions _current_options; +}; + +} // namespace train +} // namespace exec +} // namespace onert + +#endif // __ONERT_EXEC_TRAIN_TRAINABLE_EXECUTOR_H_ diff --git a/runtime/onert/core/src/exec/train/TrainableExecutors.cc b/runtime/onert/core/src/exec/train/TrainableExecutors.cc new file mode 100644 index 000000000..73217c836 --- /dev/null +++ b/runtime/onert/core/src/exec/train/TrainableExecutors.cc @@ -0,0 +1,142 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "TrainableExecutors.h" + +#include "../../backend/builtin/IOTensor.h" + +#include <misc/polymorphic_downcast.h> + +namespace onert +{ +namespace exec +{ +namespace train +{ + +void TrainableExecutors::emplace(const ir::ModelIndex &, const ir::SubgraphIndex &subg_index, + std::unique_ptr<IExecutor> exec) +{ + std::unique_ptr<TrainableExecutor> t_exec{ + nnfw::misc::polymorphic_downcast<TrainableExecutor *>(exec.release())}; + _executors.emplace(subg_index, std::move(t_exec)); +} + +TrainableExecutor *TrainableExecutors::at(const ir::ModelIndex &, + const ir::SubgraphIndex &subg_index) const +{ + return _executors.at(subg_index).get(); +} + +uint32_t TrainableExecutors::inputSize() const { return entryExecutor()->inputSize(); } + +uint32_t TrainableExecutors::outputSize() const { return entryExecutor()->outputSize(); } + +const ir::OperandInfo &TrainableExecutors::inputInfo(const ir::IOIndex &index) const +{ + return entryExecutor()->inputInfo(index.value()); +} + +const ir::OperandInfo &TrainableExecutors::outputInfo(const ir::IOIndex &index) const +{ + return entryExecutor()->outputInfo(index.value()); +} + +void TrainableExecutors::execute(const ExecutionContext &ctx) +{ + if (_executors.size() > 1) + throw std::runtime_error("TrainableExecutors does not support multiple executors yet"); + + // UserTensor for Input/Output + std::vector<std::unique_ptr<backend::builtin::UserTensor>> tensorpool; + + // Allocate UserTensor and call executor forward + forward(ctx, tensorpool, false); + + // TODO Support multple executors +} + +void TrainableExecutors::train(const ExecutionContext &ctx, uint32_t training_step) +{ + if (_executors.size() > 1) + throw std::runtime_error("TrainableExecutors does not support multiple executors yet"); + + // UserTensor for Input/Output + std::vector<std::unique_ptr<backend::builtin::UserTensor>> tensorpool; + + // Allocate UserTensor and call executor forward and backward + forward(ctx, tensorpool, true); + 
entryExecutor()->backward(ctx.options, training_step); + + // TODO Support multple executors +} + +void TrainableExecutors::forward( + const ExecutionContext &ctx, + std::vector<std::unique_ptr<backend::builtin::UserTensor>> &tensorpool, bool training) +{ + // Input/Output Tensor vector for executor + std::vector<backend::IPortableTensor *> inputs(ctx.desc.inputs.size()); + std::vector<backend::IPortableTensor *> outputs(ctx.desc.outputs.size()); + + // Prepare UserTensor for input + for (uint32_t i = 0; i < inputs.size(); i++) + { + auto &desc = ctx.desc.inputs[i]; + + // Input is optional if buffer is nullptr, and optional input's size is 0 + if (desc->buffer == nullptr && (desc->size != 0 || desc->info.total_size() != 0)) + throw std::runtime_error{"Input " + std::to_string(i) + "'s buffer is not set."}; + + tensorpool.emplace_back(std::make_unique<backend::builtin::UserTensor>( + desc->info, desc->layout, const_cast<uint8_t *>(static_cast<const uint8_t *>(desc->buffer)), + desc->size)); + inputs[i] = tensorpool.back().get(); + } + + // Prepare UserTensor for output + for (uint32_t i = 0; i < outputs.size(); i++) + { + auto &desc = ctx.desc.outputs[i]; + + // If training, output buffer may not be used + // So don't check optional + tensorpool.emplace_back(std::make_unique<backend::builtin::UserTensor>( + desc->info, desc->layout, static_cast<uint8_t *>(desc->buffer), desc->size)); + outputs[i] = tensorpool.back().get(); + } + + // Call forward + entryExecutor()->forward(inputs, outputs, ctx.options, training); +} + +float TrainableExecutors::getLoss(const ir::IOIndex &index) const +{ + if (_executors.size() > 1) + throw std::runtime_error("TrainableExecutors does not support multiple executors yet"); + return entryExecutor()->getLoss(index); +} + +void TrainableExecutors::iterateTrainableTensors( + const std::function<void(const ir::OperandIndex &, const backend::train::ITrainableTensor *)> &fn) + const +{ + return entryExecutor()->iterateTrainableTensors(fn); 
+} + +} // namespace train +} // namespace exec +} // namespace onert diff --git a/runtime/onert/core/src/exec/train/TrainableExecutors.h b/runtime/onert/core/src/exec/train/TrainableExecutors.h new file mode 100644 index 000000000..ae120f6f0 --- /dev/null +++ b/runtime/onert/core/src/exec/train/TrainableExecutors.h @@ -0,0 +1,104 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __ONERT_EXEC_TRAIN_TRAINABLE_EXECUTORS_H__ +#define __ONERT_EXEC_TRAIN_TRAINABLE_EXECUTORS_H__ + +#include "TrainableExecutor.h" +#include "exec/IExecutors.h" +#include "ir/NNPkg.h" + +namespace onert +{ +namespace exec +{ +namespace train +{ + +/** + * @brief Class to gather executor set for trainable model NN package + */ +class TrainableExecutors : public IExecutors +{ +public: + /** + * @brief Construct a new TrainableExecutors object + */ + TrainableExecutors(void) = default; + TrainableExecutors(const TrainableExecutors &) = delete; + TrainableExecutors(TrainableExecutors &&) = default; + + /** + * @brief Destroy the TrainableExecutors object + */ + ~TrainableExecutors() = default; + +public: + TrainableExecutors &operator=(const TrainableExecutors &) = delete; + TrainableExecutors &operator=(TrainableExecutors &&) = default; + +public: + void emplace(const ir::ModelIndex &model_index, const ir::SubgraphIndex &subg_index, + std::unique_ptr<IExecutor> exec) override; + + 
TrainableExecutor *at(const ir::ModelIndex &model_index, + const ir::SubgraphIndex &subg_index) const override; + + TrainableExecutor *entryExecutor() const { return at(ir::ModelIndex{0}, ir::SubgraphIndex{0}); } + + uint32_t inputSize() const override; + + uint32_t outputSize() const override; + + const ir::OperandInfo &inputInfo(const ir::IOIndex &index) const override; + + const ir::OperandInfo &outputInfo(const ir::IOIndex &index) const override; + + void execute(const ExecutionContext &ctx) override; + + /** + * @brief Train + * + * @param ctx Execution context + * @param training_step The number of iterations of an training process. + * In other words, the number of gradient update. + */ + void train(const ExecutionContext &ctx, uint32_t training_step); + + float getLoss(const ir::IOIndex &index) const; + + void iterateTrainableTensors( + const std::function<void(const ir::OperandIndex &, const backend::train::ITrainableTensor *)> + &fn) const; + +private: + // If you want to use I/O buffer on step, tensorpool should be alive until one step is finished + // So this method get tensorpool from outside. + // tensorpool is not defined as a member variable to avoid memory access conflict between threads. + void forward(const ExecutionContext &ctx, + std::vector<std::unique_ptr<backend::builtin::UserTensor>> &tensorpool, + bool training); + +private: + // TODO Append model index to ModelIndex + std::unordered_map<ir::SubgraphIndex, std::unique_ptr<TrainableExecutor>> _executors; +}; + +} // namespace train +} // namespace exec +} // namespace onert + +#endif // __ONERT_EXEC_TRAIN_TRAINABLE_EXECUTORS_H__ diff --git a/runtime/onert/core/src/exec/train/TrainableFnSequence.cc b/runtime/onert/core/src/exec/train/TrainableFnSequence.cc new file mode 100644 index 000000000..36e4c3171 --- /dev/null +++ b/runtime/onert/core/src/exec/train/TrainableFnSequence.cc @@ -0,0 +1,69 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. 
All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "exec/train/TrainableFnSequence.h" + +namespace onert +{ +namespace exec +{ +namespace train +{ + +void TrainableFnSequence::forward(bool training) +{ + for (const auto &function : _functions) + { + function->forward(training); + } +} + +void TrainableFnSequence::backward(uint32_t training_step, bool weight_update_enabled) +{ + for (auto it = _functions.rbegin(); it != _functions.rend(); ++it) + { + (*it)->backward(); + } + if (weight_update_enabled) + { + for (const auto &applier : _appliers) + { + applier->applyGradient(training_step); + } + } +} + +void TrainableFnSequence::append(std::unique_ptr<ITrainableFunction> &&function) +{ + _functions.push_back(std::move(function)); +} + +void TrainableFnSequence::append(std::unique_ptr<IGradientApplier> &&applier) +{ + _appliers.push_back(std::move(applier)); +} + +void TrainableFnSequence::iterate(const std::function<void(ITrainableFunction &)> &fn) +{ + for (const auto &func : _functions) + { + fn(*func); + } +} + +} // namespace train +} // namespace exec +} // namespace onert diff --git a/runtime/onert/core/src/exporter/CircleExporter.cc b/runtime/onert/core/src/exporter/CircleExporter.cc new file mode 100644 index 000000000..b9ac8d5bb --- /dev/null +++ b/runtime/onert/core/src/exporter/CircleExporter.cc @@ -0,0 +1,153 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. 
All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "exporter/CircleExporter.h" + +#include "exec/Execution.h" +#include "ir/train/TrainingInfo.h" +#include "circle_schema_generated.h" +#include "TrainInfoBuilder.h" + +#include <fstream> +#include <iostream> + +namespace onert +{ +namespace exporter +{ + +CircleExporter::CircleExporter(const std::string &source, const std::string &path) + : _path{path}, _data{}, _model{nullptr} +{ + // make sure the architecture is little endian before direct access to flatbuffers + assert(FLATBUFFERS_LITTLEENDIAN); + + std::ifstream src(source.c_str(), std::ios::binary); + if (src.is_open()) + { + src.seekg(0, std::ios::end); + _data.resize(src.tellg()); + src.seekg(0, std::ios::beg); + src.read(&_data[0], static_cast<std::streamsize>(_data.size())); + src.close(); + } + + if (_data.size() == 0) + throw std::runtime_error("Invalid source file"); + + const auto model = ::circle::GetModel(_data.data()); + if (!model) + throw std::runtime_error("Failed to load original circle file"); + _model.reset(model->UnPack()); +} + +CircleExporter::~CircleExporter() { finish(); } + +void CircleExporter::updateWeight(const std::unique_ptr<exec::Execution> &exec) +{ + exec->iterateTrainableTensors( + [&](const ir::OperandIndex &idx, const backend::train::ITrainableTensor *tensor) { + std::lock_guard<std::mutex> guard(_mutex); + const auto &subgs = _model->subgraphs; + if (subgs.size() != 1) + throw 
std::runtime_error("Circle does not has valid subgraph or has multiple subgraphs"); + + if (!idx.valid()) + throw std::runtime_error("Trainable tensor is invalid"); + + uint32_t buf_idx = -1; + const auto &subg = subgs.at(0); // Get 1st subgraph + if (idx.value() >= subg->tensors.size()) + { + auto buffer = std::make_unique<::circle::BufferT>(); + buffer->size = tensor->total_size(); + buffer->data.resize(buffer->size); + + buf_idx = _model->buffers.size(); + _model->buffers.push_back(std::move(buffer)); + } + else + { + buf_idx = subg->tensors.at(idx.value())->buffer; + if (buf_idx >= _model->buffers.size()) + throw std::runtime_error("Buffer for trainable tensors is invalid"); + } + + const auto &buffer = _model->buffers.at(buf_idx); + + auto org_buf_sz = buffer->data.size(); + if (org_buf_sz != tensor->total_size()) + throw std::runtime_error("Trained tensor buffer size does not match original tensor's one"); + + memcpy(buffer->data.data(), tensor->buffer(), org_buf_sz); + }); +} + +void CircleExporter::updateMetadata(const std::unique_ptr<ir::train::TrainingInfo> &training_info) +{ + const char *const TRAININFO_METADATA_NAME = "CIRCLE_TRAINING"; + + TrainInfoBuilder tbuilder(training_info); + bool found = false; + for (const auto &meta : _model->metadata) + { + if (meta->name == std::string{TRAININFO_METADATA_NAME}) + { + std::lock_guard<std::mutex> guard(_mutex); + const uint32_t buf_idx = meta->buffer; + auto &buffer = _model->buffers.at(buf_idx); + + if (tbuilder.size() != buffer->data.size()) + { + buffer->data.resize(tbuilder.size()); + buffer->size = tbuilder.size(); + } + + memcpy(buffer->data.data(), tbuilder.get(), tbuilder.size()); + found = true; + break; + } + } + + if (!found) + { + std::lock_guard<std::mutex> guard(_mutex); + auto buffer = std::make_unique<::circle::BufferT>(); + buffer->size = tbuilder.size(); + buffer->data.resize(buffer->size); + memcpy(buffer->data.data(), tbuilder.get(), buffer->size); + + auto meta = 
std::make_unique<::circle::MetadataT>(); + meta->name = std::string{TRAININFO_METADATA_NAME}; + meta->buffer = _model->buffers.size(); + + _model->buffers.push_back(std::move(buffer)); + _model->metadata.push_back(std::move(meta)); + } +} + +void CircleExporter::finish() +{ + flatbuffers::FlatBufferBuilder builder(1024); + builder.Finish(::circle::Model::Pack(builder, _model.get()), ::circle::ModelIdentifier()); + + std::ofstream dst(_path.c_str(), std::ios::binary); + dst.write(reinterpret_cast<const char *>(builder.GetBufferPointer()), + static_cast<std::streamsize>(builder.GetSize())); + dst.close(); +} +} // namespace exporter +} // namespace onert diff --git a/runtime/onert/core/src/exporter/TrainInfoBuilder.h b/runtime/onert/core/src/exporter/TrainInfoBuilder.h new file mode 100644 index 000000000..c3084b462 --- /dev/null +++ b/runtime/onert/core/src/exporter/TrainInfoBuilder.h @@ -0,0 +1,116 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef __ONERT_EXPORTER_TRAININFO_BUILDER_H__ +#define __ONERT_EXPORTER_TRAININFO_BUILDER_H__ + +#include "ir/train/TrainingInfo.h" +#include "circle_schema_generated.h" +#include "circle_traininfo_generated.h" + +namespace onert +{ +namespace exporter +{ + +class TrainInfoBuilder +{ +public: + TrainInfoBuilder(const std::unique_ptr<ir::train::TrainingInfo> &training_info) : _builder(1024) + { + const auto &optimizerInfo = training_info->optimizerInfo(); + const auto &lossInfo = training_info->lossInfo(); + + ::circle::Optimizer optimizer; + ::circle::OptimizerOptions optimizer_opt_type; + ::flatbuffers::Offset<void> optimizer_opt; + switch (optimizerInfo.optim_code) + { + case ir::train::OptimizerCode::SGD: + optimizer = ::circle::Optimizer_SGD; + optimizer_opt_type = ::circle::OptimizerOptions_SGDOptions; + optimizer_opt = ::circle::CreateSGDOptions(_builder, optimizerInfo.learning_rate).Union(); + break; + case ir::train::OptimizerCode::Adam: + optimizer = ::circle::Optimizer_ADAM; + optimizer_opt_type = ::circle::OptimizerOptions_AdamOptions; + optimizer_opt = ::circle::CreateAdamOptions(_builder, optimizerInfo.learning_rate).Union(); + break; + default: + throw std::runtime_error("Not supported optimizer code"); + } + + ::circle::LossFn lossfn; + ::circle::LossFnOptions lossfn_opt_type; + ::flatbuffers::Offset<void> lossfn_opt; + switch (lossInfo.loss_code) + { + case ir::train::LossCode::MeanSquaredError: + lossfn = ::circle::LossFn_MEAN_SQUARED_ERROR; + lossfn_opt_type = ::circle::LossFnOptions_MeanSquaredErrorOptions; + lossfn_opt = ::circle::CreateMeanSquaredErrorOptions(_builder).Union(); + break; + case ir::train::LossCode::CategoricalCrossentropy: + lossfn = ::circle::LossFn_CATEGORICAL_CROSSENTROPY; + lossfn_opt_type = ::circle::LossFnOptions_CategoricalCrossentropyOptions; + lossfn_opt = ::circle::CreateCategoricalCrossentropyOptions(_builder).Union(); + break; + default: + throw std::runtime_error("Not supported loss code"); + } + + 
::circle::LossReductionType loss_reduction_type; + switch (lossInfo.reduction_type) + { + case ir::train::LossReductionType::SumOverBatchSize: + loss_reduction_type = ::circle::LossReductionType_SumOverBatchSize; + break; + case ir::train::LossReductionType::Sum: + loss_reduction_type = ::circle::LossReductionType_Sum; + break; + default: + throw std::runtime_error("Not supported loss reduction type"); + } + + std::vector<int32_t> trainable_ops; + for (const auto &op : training_info->getTrainableOps()) + { + trainable_ops.push_back(op.value()); + } + + const auto end = ::circle::CreateModelTrainingDirect( + _builder, training_info->version(), optimizer, optimizer_opt_type, optimizer_opt, lossfn, + lossfn_opt_type, lossfn_opt, 0, training_info->batchSize(), loss_reduction_type, + &trainable_ops); + _builder.Finish(end, ::circle::ModelTrainingIdentifier()); + + ::flatbuffers::Verifier v(_builder.GetBufferPointer(), _builder.GetSize()); + bool verified = ::circle::VerifyModelTrainingBuffer(v); + if (not verified) + throw std::runtime_error{"TrainingInfo buffer is not accessible"}; + } + + uint8_t *get() const { return _builder.GetBufferPointer(); } + uint32_t size() const { return _builder.GetSize(); } + +private: + ::flatbuffers::FlatBufferBuilder _builder; +}; + +} // namespace exporter +} // namespace onert + +#endif // __ONERT_EXPORTER_TRAININFO_BUILDER_H__ diff --git a/runtime/onert/core/src/interp/Buffer.h b/runtime/onert/core/src/interp/Buffer.h deleted file mode 100644 index 24938f74f..000000000 --- a/runtime/onert/core/src/interp/Buffer.h +++ /dev/null @@ -1,91 +0,0 @@ -/* - * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/** - * @file Buffer.h - * @brief This file contains Buffer interface and InternalBuffer, ExternalBuffer class - */ -#ifndef __ONERT_INTERP_BUFFER_H__ -#define __ONERT_INTERP_BUFFER_H__ - -#include <memory> - -#include "ir/Data.h" - -namespace onert -{ -namespace interp -{ - -/** - * @brief Interface for writable data area - */ -class Buffer : public ir::Data -{ -public: - /** - * @brief Return writable pointer for data area - * @return Writable pointer - */ - virtual uint8_t *baseWritable(void) const = 0; -}; - -/** - * @brief Class for internally allocated data area - */ -class InternalBuffer final : public Buffer -{ -public: - InternalBuffer(size_t size) : _base{std::make_unique<uint8_t[]>(size)}, _size{size} - { - // DO NOTHING - } - -public: - size_t size(void) const override { return _size; } - const uint8_t *base(void) const override { return _base.get(); } - uint8_t *baseWritable(void) const override { return _base.get(); } - -private: - std::unique_ptr<uint8_t[]> _base; - size_t _size; -}; - -/** - * @brief Class for data area from outside - */ -class ExternalBuffer final : public Buffer -{ -public: - ExternalBuffer(uint8_t *base, size_t size) : _base{base}, _size{size} - { - // DO NOTHING - } - -public: - size_t size(void) const override { return _size; } - const uint8_t *base(void) const override { return _base; } - uint8_t *baseWritable(void) const override { return _base; } - -private: - uint8_t *_base; - size_t _size; -}; - -} // namespace interp -} // namespace onert - -#endif // __ONERT_INTERP_BUFFER_H__ diff --git 
a/runtime/onert/core/src/interp/ExecEnv.h b/runtime/onert/core/src/interp/ExecEnv.h deleted file mode 100644 index 7f577ea6e..000000000 --- a/runtime/onert/core/src/interp/ExecEnv.h +++ /dev/null @@ -1,212 +0,0 @@ -/* - * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/** - * @file ExecEnv.h - * @brief This file contains ExecEnv to access interpreter tensor and execution status - */ -#ifndef __ONERT_INTERP_EXEC_ENV_H_ -#define __ONERT_INTERP_EXEC_ENV_H_ - -#include <unordered_set> - -#include "ir/Graph.h" -#include "Tensor.h" - -namespace onert -{ -namespace interp -{ - -/** - * @brief Class to gather interpreter execution environment - * Each interpreter instance own execution environment - */ -class ExecEnv -{ -public: - /** - * @brief Construct a new Exec Env object (deleted) - */ - ExecEnv(void) = delete; - /** - * @brief Construct a new ExecEnv object - * @param[in] graph Graph to execute by interpreter - */ - explicit ExecEnv(const ir::Graph &graph) : _graph(graph) - { - // DO NOTHING - } - -public: - /** - * @brief Return graph to execute - * @return Graph - */ - const ir::Graph &graph(void) const { return _graph; } - /** - * @brief Assign tensor to environment which have allocated or assigned buffer - * @param[in] index Tensor index - * @param[in] tensor Tensor - */ - void assignTensor(const ir::OperandIndex index, std::shared_ptr<ITensor> tensor) - { - 
assert(tensor->bufferRO() != nullptr); - _tensors.emplace(index, tensor); - } - - /** - * @brief Return tensor pointer in environment - * @param[in] index Tensor index - * can_optional @c True if tensor can be optional input, otherwise @c false - * @return Tensor pointer - */ - const ITensor *tensorAt(const ir::OperandIndex index, bool can_optional = false) const - { - if (_tensors.find(index) == _tensors.end()) - { - // It may optional input, - // otherwise input is not set by runtime user - if (can_optional) - { - return nullptr; - } - - throw std::runtime_error{"ExecEnv: Input is not set"}; - } - - return _tensors.at(index).get(); - } - - /** - * @brief Check environment contains tensor - * @param[in] index Tensor index - * @return @c true if environment contain tensor, otherwise @c false - */ - bool contains(const ir::OperandIndex index) const - { - return (_tensors.find(index) != _tensors.end()); - } - - /** - * @brief Allocate tensor using operand info - * @param[in] index Tensor index - * @param[in] info Operand info - * @note If already allocated, just return - * @TODO More smart allocation policy - */ - void allocateIfNeeded(const ir::OperandIndex index, const ir::OperandInfo &info) - { - // already allocated, or constant - if (contains(index)) - { - return; - } - - // Buffer from external (ex. 
model output) - auto tensor = std::make_shared<Tensor>(info); - if (isExtBuffer(index)) - { - tensor->setBuffer(_external_buffers.at(index)); - assignTensor(index, tensor); - - return; - } - - tensor->setBuffer(std::make_shared<InternalBuffer>(tensor->total_size())); - assignTensor(index, tensor); - _buffers.insert(index); - } - - /** - * @brief Allocate read-only tensor and share data with other tensor - * @param[in] index Tensor index - * @param[in] info Operand info - * @param[in] index_to_share Tensor index that have data to share - */ - void allocateAndShareIfNeeded(const ir::OperandIndex index, const ir::OperandInfo &info, - const ir::OperandIndex index_to_share) - { - if (!contains(index_to_share)) - { - throw std::runtime_error{"Cannot find tensor to share data"}; - } - - // already allocated - if (contains(index)) - { - return; - } - - if (isExtBuffer(index)) - { - auto tensor = std::make_shared<Tensor>(info); - tensor->setBuffer(_external_buffers.at(index)); - assignTensor(index, tensor); - } - else - { - auto tensor = std::make_shared<ROTensor>(info); - tensor->setData(tensorAt(index_to_share)->shareData()); - assignTensor(index, tensor); - _buffers.insert(index); - } - } - - /** - * @brief Free buffer if allocated by allocateIfNeed - * @param[in] index Tensor index - * @note If allocated by outside, just return - */ - void freeIfAllocated(const ir::OperandIndex index) - { - if (_buffers.find(index) != _buffers.end()) - { - _tensors.at(index)->releaseData(); - } - } - - /** - * @brief Assign ExternalBuffer into external buffer map - * @param[in] index Tensor index - * @param[in] buffer External buffer - */ - void assignExternalBuffer(const ir::OperandIndex index, std::shared_ptr<ExternalBuffer> buffer) - { - _external_buffers.emplace(index, buffer); - } - -private: - bool isExtBuffer(const ir::OperandIndex index) - { - return (_external_buffers.find(index) != _external_buffers.end()); - } - -private: - const ir::Graph &_graph; - // Tensor map to use in 
interpreter - // It should map tensors that have allocated or assigned buffer pointer - std::unordered_map<ir::OperandIndex, std::shared_ptr<ITensor>> _tensors; - // Tensors allocated by allocateIfNeed (buffer) - std::unordered_set<ir::OperandIndex> _buffers; - // Tensor buffer from external - std::unordered_map<ir::OperandIndex, std::shared_ptr<ExternalBuffer>> _external_buffers; -}; - -} // namespace interp -} // namespace onert - -#endif // __ONERT_INTERP_EXEC_ENV_H_ diff --git a/runtime/onert/core/src/interp/InterpExecutor.cc b/runtime/onert/core/src/interp/InterpExecutor.cc deleted file mode 100644 index cd31a4dca..000000000 --- a/runtime/onert/core/src/interp/InterpExecutor.cc +++ /dev/null @@ -1,126 +0,0 @@ -/* - * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "interp/InterpExecutor.h" -#include "interp/ExecEnv.h" -#include "interp/Interpreter.h" - -#include "util/logging.h" - -#include <memory> - -namespace onert -{ -namespace interp -{ - -void InterpExecutor::execute(const exec::IODescription &desc) -{ - /************************************************************************ - * Prepare execution model (submodel) - It may execute divided model - but now consider model inference is done at interpreter - ***********************************************************************/ - ir::OperandIndexMap<std::shared_ptr<ITensor>> tensor_map; - - for (uint32_t n = 0; n < _graph.getInputs().size(); n++) - { - ir::IOIndex index{n}; - const auto input_index = _graph.getInputs().at(index); - - const auto input = desc.inputs.at(n).get(); - if (input == nullptr) - { - // Optional input - continue; - } - - auto input_tensor = std::make_shared<ROTensor>(input->info); - input_tensor->setData(std::make_shared<const ir::ExternalData>( - reinterpret_cast<const uint8_t *>(input->buffer), input->size)); - tensor_map[input_index] = input_tensor; - } - - /************************************************************************ - * Prepare execution environment - Execution environment will be assigned to invoked interpreter instance - ***********************************************************************/ - - std::unique_ptr<ExecEnv> interp_env = std::make_unique<ExecEnv>(_graph); - - // Assign input/output tensor into interpreter execution environment - for (auto index : _graph.getInputs()) - { - if (tensor_map.find(index) != tensor_map.end()) - { - VERBOSE(INTERPRETER) << "Assign input tensor. 
operand index:" << index.value() << std::endl; - interp_env->assignTensor(index, tensor_map.at(index)); - } - } - - for (uint32_t n = 0; n < _graph.getOutputs().size(); n++) - { - ir::IOIndex index{n}; - const auto output_index = _graph.getOutputs().at(index); - const auto output = desc.outputs.at(n).get(); - if (output == nullptr) - { - // Optional output - continue; - } - - VERBOSE(INTERPRETER) << "Set out buffer to ExecEnv. operand index:" << output_index.value() - << std::endl; - - interp_env->assignExternalBuffer( - output_index, std::make_shared<ExternalBuffer>(reinterpret_cast<uint8_t *>(output->buffer), - output->size)); - } - - // Allocate constant tensor - _graph.operands().iterate([&](const ir::OperandIndex &ind, const ir::Operand &obj) { - if (obj.isConstant()) - { - VERBOSE(INTERPRETER) << "Allocate and assign constant tensor. operand index:" << ind.value() - << std::endl; - - assert(obj.data()); - auto const_tensor = std::make_shared<ROTensor>(obj.info()); - // Assume that interpreter's tensor layout is same with model (NHWC) - const_tensor->setData( - std::make_shared<ir::ExternalData>(obj.data()->base(), obj.info().total_size())); - interp_env->assignTensor(ind, const_tensor); - } - }); - - /***************************************************************************** - * Invoke interpreter - ****************************************************************************/ - - interp::Interpreter interp(std::move(interp_env)); - interp.run(); - - /***************************************************************************** - * Invoked interpreter run is finished - ****************************************************************************/ - - // If interpreter execute submodel - // 1. Get tensor output of submodel into tensor_map to save result - // 2. 
Generate new ExecEnv for next interpretation -} - -} // namespace interp -} // namespace onert diff --git a/runtime/onert/core/src/interp/InterpExecutor.h b/runtime/onert/core/src/interp/InterpExecutor.h deleted file mode 100644 index 2e3f3ca54..000000000 --- a/runtime/onert/core/src/interp/InterpExecutor.h +++ /dev/null @@ -1,70 +0,0 @@ -/* - * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/** - * @file InterpExecutor.h - * @brief This file contains InterpExecutor class\n - * to manage interpreter execution and environment - */ -#ifndef __ONERT_INTERP_INTERP_EXECUTOR_H__ -#define __ONERT_INTERP_INTERP_EXECUTOR_H__ - -#include "ir/OperandIndexMap.h" -#include "ir/Graph.h" -#include "exec/IExecutor.h" - -namespace onert -{ -namespace interp -{ - -class ITensor; - -/** - * @brief Class to execute model using interpreter - */ -class InterpExecutor final : public exec::IExecutor -{ -public: - explicit InterpExecutor(const ir::Graph &graph) : _graph(graph) - { - // DO NOTHING - } - -public: - /** - * @brief Return graph object - * @return Graph object - */ - const ir::Graph &graph() final { return _graph; } - void setIndexedRanks(std::shared_ptr<ir::OperationIndexMap<int64_t>>) override{ - // Not implemented - }; - /** - * @brief Start execution - * @note It should be called after setting input and output buffer - */ - void execute(const exec::IODescription &desc) final; - -private: - 
const ir::Graph &_graph; - ir::OperandIndexMap<std::shared_ptr<ITensor>> _tensor_map; -}; - -} // namespace interp -} // namespace onert - -#endif // __ONERT_INTERP_INTERP_EXECUTOR_H__ diff --git a/runtime/onert/core/src/interp/InterpOps.lst b/runtime/onert/core/src/interp/InterpOps.lst deleted file mode 100644 index 0714df38a..000000000 --- a/runtime/onert/core/src/interp/InterpOps.lst +++ /dev/null @@ -1,73 +0,0 @@ -/* - * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef INTERP_OP -#error Define INTERP_OP before including this file -#endif - -// Supported operation name in interpreter -// -// Same list with Operations.lst -// Make comment out if operation is not supported in interpreter -INTERP_OP(BinaryArithmetic) -//INTERP_OP(BatchToSpaceND) -//INTERP_OP(Cast) -INTERP_OP(Conv2D) -INTERP_OP(DepthwiseConv2D) -INTERP_OP(Pool2D) -INTERP_OP(Concat) -INTERP_OP(FullyConnected) -//INTERP_OP(Reduce) -INTERP_OP(Reshape) -INTERP_OP(Softmax) -//INTERP_OP(Squeeze) -//INTERP_OP(Slice) -//INTERP_OP(StridedSlice) -INTERP_OP(ElementwiseActivation) -//INTERP_OP(Transpose) -//INTERP_OP(Exp) -//INTERP_OP(Comparison) -//INTERP_OP(LogicalNot) -//INTERP_OP(LSTM) -//INTERP_OP(RSQRT) -//INTERP_OP(ResizeBilinear) -//INTERP_OP(RNN) -//INTERP_OP(Floor) -//INTERP_OP(SpaceToBatchND) -//INTERP_OP(SpaceToDepth) -//INTERP_OP(EmbeddingLookup) -//INTERP_OP(L2Normalization) -//INTERP_OP(HashtableLookup) -INTERP_OP(InstanceNorm) -//INTERP_OP(PReLU) -INTERP_OP(TransposeConv) -//INTERP_OP(SQRT) -//INTERP_OP(SquaredDifference) -//INTERP_OP(TopKV2) -INTERP_OP(Gather) -//INTERP_OP(Neg) -//INTERP_OP(Abs) -//INTERP_OP(ArgMax) -//INTERP_OP(Dequantize) -//INTERP_OP(LocalResponseNormalization) -//INTERP_OP(DepthToSpace) -//INTERP_OP(Pack) -//INTERP_OP(Split) -//INTERP_OP(Unpack) -INTERP_OP(Pad) -//INTERP_OP(Custom) -//INTERP_OP(Permute) -//INTERP_OP(OneHot) diff --git a/runtime/onert/core/src/interp/Interpreter.cc b/runtime/onert/core/src/interp/Interpreter.cc deleted file mode 100644 index b92afbe73..000000000 --- a/runtime/onert/core/src/interp/Interpreter.cc +++ /dev/null @@ -1,184 +0,0 @@ -/* - * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "Interpreter.h" - -#include <stack> -#include <unordered_set> - -#include "Registration.h" - -#include "ir/OperandIndexMap.h" -#include "util/logging.h" -#include "ir/OperationVisitor.h" - -namespace onert -{ -namespace interp -{ - -// TODO more structured execution kernel implementation -// TODO use cker for execution -// TODO divide tensor prepare and execution -// TODO introduce memory manager (buffer allocate and free) -class OperationExecutor -{ -public: - OperationExecutor(ExecEnv *env) : _env{env} - { -#define INTERP_OP(InternalName) _kernels[ir::OpCode::InternalName] = get##InternalName(); -#include "InterpOps.lst" -#undef INTERP_OP - } - - void execute(const ir::OperationIndex &idx) - { - const ir::Operation &node = _env->graph().operations().at(idx); - const auto nodeName = node.name(); - VERBOSE(INTERPRETER) << "Prepare output operands and execute " << nodeName - << " operation (id: " << idx.value() << ")" << std::endl; - - const auto nodeOpCode = node.opcode(); - if (_kernels.find(nodeOpCode) == _kernels.end()) - { - throw std::runtime_error{"Interpreter: Operation " + nodeName + " is not yet implemented"}; - } - - if (_kernels[nodeOpCode]->prepare != nullptr) - { - _kernels[nodeOpCode]->prepare(_env, node); - } - _kernels[nodeOpCode]->invoke(_env, node); - } - -private: - ExecEnv *_env; - std::unordered_map<ir::OpCode, OpKernel *> _kernels; -}; - -void Interpreter::run() -{ - VERBOSE(INTERPRETER) << "Interpreter is invoked " << std::endl; - - // operand_stack: save operands prepared to use - 
std::stack<ir::OperandIndex> operand_stack; - - // Note: We should push input first, then constant. - // We use use-def for find operators ready to execution, - // but Use-Def cannot handle parameters (maybe constant, but not always) - // Note: If all model inputs are constant, it may not work (depend on tensors' order). - // But that scenario may not exist - for (auto ind : _env->graph().getInputs()) - { - VERBOSE(INTERPRETER) << "Input: Push to operand stack " << ind.value() << std::endl; - - operand_stack.push(ind); - } - - _env->graph().operands().iterate([&](const ir::OperandIndex &ind, const ir::Operand &obj) { - if (obj.isConstant()) - { - VERBOSE(INTERPRETER) << "Constant: Push to operand stack " << ind.value() << std::endl; - - operand_stack.push(ind); - } - }); - - // Execution - std::unordered_set<ir::OperandIndex> ready_check; - std::unordered_set<ir::OperationIndex> executed; - OperationExecutor executor{_env.get()}; - while (!operand_stack.empty()) - { - const auto current_operand_index = operand_stack.top(); - operand_stack.pop(); - VERBOSE(INTERPRETER) << "Poped operand " << current_operand_index.value() - << " is checked ready to use" << std::endl; - - assert(ready_check.find(current_operand_index) == ready_check.end()); - ready_check.insert(current_operand_index); - - // Find prepared operations by scan use of current operand - std::stack<ir::OperationIndex> operation_stack; - const auto use_operators = _env->graph().operands().at(current_operand_index).getUses(); - for (const auto &use_operator : use_operators) - { - // Assumption: all parameters are ready to use - bool operator_ready = true; - for (auto input_index : _env->graph().operations().at(use_operator).getInputs()) - { - if (ready_check.find(input_index) == ready_check.end()) - { - operator_ready = false; - break; - } - } - - if (operator_ready) - { - VERBOSE(INTERPRETER) << "Ready to execute operation " << use_operator.value() << std::endl; - operation_stack.push(use_operator); - } - } 
- - while (!operation_stack.empty()) - { - const auto current_operation_index = operation_stack.top(); - operation_stack.pop(); - VERBOSE(INTERPRETER) << "Poped operation: " << current_operation_index.value() << "(" - << _env->graph().operations().at(current_operation_index).name() << ")" - << std::endl; - - // execution - // 1. Prepare output tensor - // 2. Call operation kernel - executor.execute(current_operation_index); - executed.insert(current_operation_index); - - // 3. Push each output into operand stack - const auto def_operands = _env->graph().operations().at(current_operation_index).getOutputs(); - for (auto def_operand : def_operands) - { - VERBOSE(INTERPRETER) << "Buffer: Push to operand stack " << def_operand.value() - << std::endl; - operand_stack.push(def_operand); - } - - // 4. Free if lifetime of buffer operands used by input is finished - for (auto input_index : _env->graph().operations().at(current_operation_index).getInputs()) - { - const auto use_operators = _env->graph().operands().at(input_index).getUses(); - bool dead_buffer = true; - for (const auto &use_operator : use_operators) - { - if (executed.find(use_operator) == executed.end()) - { - dead_buffer = false; - break; - } - } - - if (dead_buffer) - { - _env->freeIfAllocated(input_index); - } - } - } - } -} - -} // namespace interp -} // namespace onert diff --git a/runtime/onert/core/src/interp/Interpreter.h b/runtime/onert/core/src/interp/Interpreter.h deleted file mode 100644 index d2165f538..000000000 --- a/runtime/onert/core/src/interp/Interpreter.h +++ /dev/null @@ -1,64 +0,0 @@ -/* - * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/** - * @file Interpreter.h - * @brief This file contains Interpreter class for interpretation - */ -#ifndef __ONERT_INTERP_INTERPRETER_H__ -#define __ONERT_INTERP_INTERPRETER_H__ - -#include "ExecEnv.h" - -namespace onert -{ -namespace interp -{ - -/** - * @brief Class for interpretation - */ -class Interpreter -{ - -public: - /** - * @brief Construct a new Interpreter object (deleted) - */ - Interpreter() = delete; - /** - * @brief Construct a new Interpreter object - * @param[in] env Execution environment variable for interpreter object - */ - Interpreter(std::unique_ptr<ExecEnv> env) : _env{std::move(env)} - { - // DO NOTHING - } - -public: - /** - * @brief Run interpreter until there is no operation to execute - */ - void run(); - -private: - std::unique_ptr<ExecEnv> _env; -}; - -} // namespace interp -} // namespace onert - -#endif // __ONERT_INTERP_INTERPRETER_H__ diff --git a/runtime/onert/core/src/interp/Registration.h b/runtime/onert/core/src/interp/Registration.h deleted file mode 100644 index 956b92a53..000000000 --- a/runtime/onert/core/src/interp/Registration.h +++ /dev/null @@ -1,43 +0,0 @@ -/* - * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef __ONERT_INTERP_REGISTRATION_H__ -#define __ONERT_INTERP_REGISTRATION_H__ - -#include "ExecEnv.h" - -#include "ir/Operation.h" - -namespace onert -{ -namespace interp -{ - -struct OpKernel -{ - std::function<void(ExecEnv *, const ir::Operation &)> prepare; - std::function<void(const ExecEnv *, const ir::Operation &)> invoke; -}; - -// Defined in operations/ directory -#define INTERP_OP(InternalName) OpKernel *get##InternalName(); -#include "InterpOps.lst" -#undef INTERP_OP - -} // namespace interp -} // namespace onert - -#endif // __ONERT_INTERP_REGISTRATION_H__ diff --git a/runtime/onert/core/src/interp/Tensor.cc b/runtime/onert/core/src/interp/Tensor.cc deleted file mode 100644 index 07f8b75dc..000000000 --- a/runtime/onert/core/src/interp/Tensor.cc +++ /dev/null @@ -1,53 +0,0 @@ -/* - * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "Tensor.h" - -#define NO_USE(a) (void)(a) - -namespace onert -{ -namespace interp -{ - -void ITensor::access(const std::function<void(backend::ITensor &tensor)> &fn) { fn(*this); } - -size_t ROTensor::calcOffset(const ir::Coordinates &coords) const -{ - NO_USE(coords); - throw std::runtime_error("offset_element_in_bytes is not supported for cpu::Tensor now."); -} - -size_t Tensor::calcOffset(const ir::Coordinates &coords) const -{ - NO_USE(coords); - throw std::runtime_error("offset_element_in_bytes is not supported for cpu::Tensor now."); -} - -ir::Layout ROTensor::layout() const -{ - // TODO Changes to return frontend layout - return ir::Layout::NHWC; -} - -ir::Layout Tensor::layout() const -{ - // TODO Changes to return frontend layout - return ir::Layout::NHWC; -} - -} // namespace interp -} // namespace onert diff --git a/runtime/onert/core/src/interp/Tensor.h b/runtime/onert/core/src/interp/Tensor.h deleted file mode 100644 index 008a4b9d4..000000000 --- a/runtime/onert/core/src/interp/Tensor.h +++ /dev/null @@ -1,184 +0,0 @@ -/* - * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -/** - * @file Tensor.h - * @brief This file contains ITensor interface, ROTensor class, and Tensor class - */ -#ifndef __ONERT_INTERP_TENSOR_H__ -#define __ONERT_INTERP_TENSOR_H__ - -#include "Buffer.h" - -#include "ir/OperandInfo.h" -#include "backend/ITensor.h" -#include "ir/Layout.h" - -namespace onert -{ -namespace interp -{ - -/** - * @brief Interface to handle Tensor in interpreter - */ -class ITensor : public backend::ITensor -{ -public: - virtual ~ITensor() = default; - -public: - virtual uint8_t *buffer() const = 0; - /** - * @brief Return shared pointer for buffer - * @return Buffer shared pointer - */ - virtual std::shared_ptr<const Buffer> shareBuffer() const = 0; - /** - * @brief Return read-only buffer pointer - * @return Read-only buffer pointer - */ - virtual const uint8_t *bufferRO() const = 0; - /** - * @brief Return shared pointer for data - * @return Data shared pointer - */ - virtual std::shared_ptr<const ir::Data> shareData() const = 0; - /** - * @brief Set internal/external buffer - * @param[in] buffer Buffer pointer - */ - virtual void setBuffer(std::shared_ptr<const Buffer> buffer) = 0; - /** - * @brief Set data reference (including constant, input) - * @param[in] data Data pointer - */ - virtual void setData(std::shared_ptr<const ir::Data> data) = 0; - virtual void releaseData() = 0; - - virtual size_t total_size() const = 0; - virtual size_t dimension(size_t index) const = 0; - virtual size_t num_dimensions() const = 0; - virtual size_t calcOffset(const ir::Coordinates &coords) const = 0; - - virtual bool has_padding() const = 0; - /** - * @brief Return data type of tensor - * @return Data type of tensor - */ - virtual ir::DataType data_type() const = 0; - /** - * @brief Return TensorInfo - * @return TensorInfo - */ - virtual const ir::OperandInfo &tensorInfo() const = 0; - /** - * @brief Return number of elements - * @return Number of elements - */ - virtual uint64_t num_elements() const = 0; - void access(const 
std::function<void(backend::ITensor &tensor)> &fn) final; -}; - -/** - * @brief Class to handle tensor in interpreter as read-only - */ -class ROTensor final : public ITensor -{ -public: - ROTensor() = delete; - ROTensor(const ir::OperandInfo &info) : _info(info) - { - // DO NOTHING - } - -public: - uint8_t *buffer() const override { throw std::runtime_error{"Read only tensor"}; } - std::shared_ptr<const Buffer> shareBuffer() const override - { - throw std::runtime_error{"Read only tensor"}; - } - const uint8_t *bufferRO() const override { return _data->base(); } - std::shared_ptr<const ir::Data> shareData() const override { return _data; } - void setBuffer(std::shared_ptr<const Buffer> buffer) override { _data = buffer; } - void setData(std::shared_ptr<const ir::Data> data) override { _data = data; } - void releaseData() override { _data = nullptr; } - - size_t total_size() const override { return _info.total_size(); } - size_t dimension(size_t index) const override { return _info.shape().dim(index); } - size_t num_dimensions() const override { return _info.shape().rank(); } - size_t calcOffset(const ir::Coordinates &coords) const override; - ir::Layout layout() const override; - bool is_dynamic() const override { return false; } - bool has_padding() const override { return false; } - ir::DataType data_type() const override { return _info.typeInfo().type(); } - float data_scale() const override { return _info.typeInfo().scale(); } - int32_t data_offset() const override { return _info.typeInfo().offset(); } - const ir::OperandInfo &tensorInfo() const override { return _info; } - uint64_t num_elements() const override { return _info.shape().num_elements(); }; - -private: - const ir::OperandInfo _info; - std::shared_ptr<const ir::Data> _data{nullptr}; -}; - -/** - * @brief Class to handle tensor in interpreter as writable - */ -class Tensor final : public ITensor -{ -public: - Tensor() = delete; - Tensor(const ir::OperandInfo &info) : _info(info) - { - // DO NOTHING 
- } - -public: - uint8_t *buffer() const override { return _buffer->baseWritable(); } - std::shared_ptr<const Buffer> shareBuffer() const override { return _buffer; }; - const uint8_t *bufferRO() const override { return _buffer->base(); } - std::shared_ptr<const ir::Data> shareData() const override { return _buffer; } - void setBuffer(std::shared_ptr<const Buffer> buffer) override { _buffer = buffer; } - void setData(std::shared_ptr<const ir::Data>) override - { - throw std::runtime_error{"Passed data may read-only"}; - } - void releaseData() override { _buffer = nullptr; } - - size_t total_size() const override { return _info.total_size(); } - size_t dimension(size_t index) const override { return _info.shape().dim(index); } - size_t num_dimensions() const override { return _info.shape().rank(); } - size_t calcOffset(const ir::Coordinates &coords) const override; - ir::Layout layout() const override; - bool is_dynamic() const override { return false; } - bool has_padding() const override { return false; } - ir::DataType data_type() const override { return _info.typeInfo().type(); } - float data_scale() const override { return _info.typeInfo().scale(); } - int32_t data_offset() const override { return _info.typeInfo().offset(); } - const ir::OperandInfo &tensorInfo() const override { return _info; } - uint64_t num_elements() const override { return _info.shape().num_elements(); }; - backend::IDynamicTensorManager *dynamic_tensor_manager() override { return nullptr; } - -private: - const ir::OperandInfo _info; - std::shared_ptr<const Buffer> _buffer{nullptr}; -}; - -} // namespace interp -} // namespace onert - -#endif // __ONERT_INTERP_TENSOR_H__ diff --git a/runtime/onert/core/src/interp/operations/BinaryArithmeticOps.cc b/runtime/onert/core/src/interp/operations/BinaryArithmeticOps.cc deleted file mode 100644 index 86e883524..000000000 --- a/runtime/onert/core/src/interp/operations/BinaryArithmeticOps.cc +++ /dev/null @@ -1,205 +0,0 @@ -/* - * Copyright (c) 2019 
Samsung Electronics Co., Ltd. All Rights Reserved - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include <cker/operation/BinaryArithmeticOps.h> - -#include "OperationUtil.h" - -#include "interp/Registration.h" -#include "ir/operation/BinaryArithmetic.h" -#include "misc/polymorphic_downcast.h" -#include "cker/Types.h" - -namespace onert -{ -namespace interp -{ -namespace -{ - -enum class OpType -{ - ADD, - SUB, - MUL -}; - -void prepare(ExecEnv *env, const ir::Operation &node) -{ - const auto &arithmetic_node = - nnfw::misc::polymorphic_downcast<const ir::operation::BinaryArithmetic &>(node); - - const auto lhs_index = node.getInputs().at(arithmetic_node.LHS); - const auto rhs_index = node.getInputs().at(arithmetic_node.RHS); - const auto out_index = node.getOutputs().at(0); - - const auto lhs_tensor = env->tensorAt(lhs_index); - const auto rhs_tensor = env->tensorAt(rhs_index); - - // Check shape and type lhs is same with rhs - // TODO Util function to compare TensorInfo - if (lhs_tensor->data_type() != rhs_tensor->data_type()) - { - throw std::runtime_error{"Interp(" + arithmetic_node.name() + "): Different input types"}; - } - - bool try_broadcast = (lhs_tensor->tensorInfo().shape() != rhs_tensor->tensorInfo().shape()); - if (try_broadcast) - { - bool success = true; - auto out_shape = calcBroadcastShape(lhs_tensor->tensorInfo().shape(), - rhs_tensor->tensorInfo().shape(), success); - if (!success) - { - throw std::runtime_error{"Interp(" + 
arithmetic_node.name() + "): Fail to brodcasting"}; - } - - auto output_info = - ir::OperandInfo::createStaticInfo(out_shape, lhs_tensor->tensorInfo().typeInfo()); - // We can handle already allocated (ex. model output) - env->allocateIfNeeded(out_index, output_info); - } - else - { - // Output's shape and type is same with input - auto output_info = lhs_tensor->tensorInfo(); - // We can handle already allocated (ex. model output) - env->allocateIfNeeded(out_index, output_info); - } - - auto out_tensor = env->tensorAt(out_index); - // Check shape and type lhs is same with output - // TODO Util function to compare TensorInfo - if (lhs_tensor->data_type() != out_tensor->data_type()) - { - throw std::runtime_error{"Interp(" + arithmetic_node.name() + "): Invalid output type"}; - } -} - -inline void setActivationParams(float min, float max, nnfw::cker::BinaryArithmeticOpParam *params) -{ - params->float_activation_min = min; - params->float_activation_max = max; -} - -inline void setActivationParams(int32_t min, int32_t max, - nnfw::cker::BinaryArithmeticOpParam *params) -{ - params->quantized_activation_min = min; - params->quantized_activation_max = max; -} - -template <typename raw_type, OpType op_type> -void invoke(const ITensor *lhs_tensor, const ITensor *rhs_tensor, const ITensor *out_tensor, - const ir::operation::BinaryArithmetic::Param ¶m) -{ - const auto lhs_buffer = lhs_tensor->bufferRO(); - const auto rhs_buffer = rhs_tensor->bufferRO(); - auto out_buffer = out_tensor->buffer(); - - nnfw::cker::BinaryArithmeticOpParam cker_param; - raw_type activation_min, activation_max; - calculateActivationRange(param.activation, &activation_min, &activation_max); - setActivationParams(activation_min, activation_max, &cker_param); - const raw_type *lhs_ptr = reinterpret_cast<const raw_type *>(lhs_buffer); - const raw_type *rhs_ptr = reinterpret_cast<const raw_type *>(rhs_buffer); - raw_type *out_ptr = reinterpret_cast<raw_type *>(out_buffer); - - const auto cker_op_type 
= - (op_type == OpType::ADD) - ? nnfw::cker::BinaryArithmeticOpType::ADD - : ((op_type == OpType::SUB) ? nnfw::cker::BinaryArithmeticOpType::SUB - : nnfw::cker::BinaryArithmeticOpType::MUL); - - const bool need_broadcast = nnfw::cker::ProcessBroadcastShapes( - convertShape(lhs_tensor->tensorInfo().shape()), - convertShape(rhs_tensor->tensorInfo().shape()), &cker_param); - - if (need_broadcast) - { - const auto lhs_shape = convertShape(lhs_tensor->tensorInfo().shape()); - const auto rhs_shape = convertShape(rhs_tensor->tensorInfo().shape()); - const auto out_shape = convertShape(out_tensor->tensorInfo().shape()); - nnfw::cker::BroadcastBinaryArithmeticOp<cker_op_type>(cker_param, lhs_shape, lhs_ptr, rhs_shape, - rhs_ptr, out_shape, out_ptr); - return; - } - - const auto lhs_shape = convertShape(lhs_tensor->tensorInfo().shape()); - const auto rhs_shape = convertShape(rhs_tensor->tensorInfo().shape()); - const auto out_shape = convertShape(out_tensor->tensorInfo().shape()); - nnfw::cker::BinaryArithmeticOp<cker_op_type>(cker_param, lhs_shape, lhs_ptr, rhs_shape, rhs_ptr, - out_shape, out_ptr); -} - -template <OpType op_type> -void invokeBinaryArithmetic(const ExecEnv *env, const ir::operation::BinaryArithmetic &node) -{ - const auto lhs_index = node.getInputs().at(node.LHS); - const auto rhs_index = node.getInputs().at(node.RHS); - const auto out_index = node.getOutputs().at(0); - const auto lhs_tensor = env->tensorAt(lhs_index); - const auto rhs_tensor = env->tensorAt(rhs_index); - const auto out_tensor = env->tensorAt(out_index); - const auto data_type = lhs_tensor->data_type(); - - if (data_type == ir::DataType::INT32) - { - invoke<int32_t, op_type>(lhs_tensor, rhs_tensor, out_tensor, node.param()); - } - else if (data_type == ir::DataType::FLOAT32) - { - invoke<float, op_type>(lhs_tensor, rhs_tensor, out_tensor, node.param()); - } - else - { - throw std::runtime_error{"NYI: Unsupported data type"}; - } -} - -void invokeBinaryArithmeticOps(const ExecEnv *env, const 
ir::Operation &node) -{ - const auto &arithmetic_node = - nnfw::misc::polymorphic_downcast<const ir::operation::BinaryArithmetic &>(node); - - switch (arithmetic_node.param().arithmetic_type) - { - case ir::operation::BinaryArithmetic::ArithmeticType::ADD: - invokeBinaryArithmetic<OpType::ADD>(env, arithmetic_node); - break; - case ir::operation::BinaryArithmetic::ArithmeticType::SUB: - invokeBinaryArithmetic<OpType::SUB>(env, arithmetic_node); - break; - case ir::operation::BinaryArithmetic::ArithmeticType::MUL: - invokeBinaryArithmetic<OpType::MUL>(env, arithmetic_node); - break; - default: - throw std::runtime_error{"Interp(BinaryArithmetic): NYI unsupported operation " + - arithmetic_node.name()}; - break; - } -} - -} // namespace - -OpKernel *getBinaryArithmetic() -{ - static OpKernel kernel = {prepare, invokeBinaryArithmeticOps}; - return &kernel; -} - -} // namespace interp -} // namespace onert diff --git a/runtime/onert/core/src/interp/operations/Concat.cc b/runtime/onert/core/src/interp/operations/Concat.cc deleted file mode 100644 index efc46c66b..000000000 --- a/runtime/onert/core/src/interp/operations/Concat.cc +++ /dev/null @@ -1,147 +0,0 @@ -/* - * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include <cker/operation/Concatenation.h> - -#include "OperationUtil.h" - -#include "interp/Registration.h" -#include "ir/operation/Concat.h" -#include "misc/polymorphic_downcast.h" - -namespace onert -{ -namespace interp -{ -namespace concat -{ - -void prepareConcat(ExecEnv *env, const ir::Operation &node) -{ - const auto &concat_node = nnfw::misc::polymorphic_downcast<const ir::operation::Concat &>(node); - - const auto first_index = node.getInputs().at(0); - const auto out_index = node.getOutputs().at(0); - - const auto first_tensor = env->tensorAt(first_index); - uint32_t out_axis_dimension = 0; - const int32_t axis_raw = concat_node.param().axis; - const uint32_t axis = (axis_raw < 0) ? (axis_raw + first_tensor->num_dimensions()) : axis_raw; - - // All inputs shape should be same except axis dimension - // All inputs type should be same - for (auto input : node.getInputs()) - { - assert(first_tensor->num_dimensions() == env->tensorAt(input)->num_dimensions()); - assert(first_tensor->data_type() == env->tensorAt(input)->data_type()); - for (uint32_t i = 0; i < first_tensor->num_dimensions(); i++) - { - if (i == axis) - { - out_axis_dimension += env->tensorAt(input)->dimension(i); - continue; - } - assert(first_tensor->dimension(i) == env->tensorAt(input)->dimension(i)); - } - } - - // Make output tensor info using first input tensor info, and accumulated axis dimension value - auto out_shape = first_tensor->tensorInfo().shape(); - out_shape.dim(axis) = out_axis_dimension; - env->allocateIfNeeded(out_index, ir::OperandInfo::createStaticInfo( - out_shape, first_tensor->tensorInfo().typeInfo())); - - auto out_tensor = env->tensorAt(out_index); - UNUSED_RELEASE(out_tensor); - - // Output shape should be same with input except axis dimension - // Output type should be same with input - assert(first_tensor->data_type() == out_tensor->data_type()); - for (uint32_t i = 0; i < first_tensor->num_dimensions(); i++) - { - if (i == axis) - { - continue; - } - 
assert(first_tensor->dimension(i) == out_tensor->dimension(i)); - } -} - -void invoke(const std::vector<const ITensor *> in_tensors, const ITensor *out_tensor, uint32_t axis) -{ - const uint32_t count = in_tensors.size(); - - // Calculate - nnfw::cker::ConcatenationParams cker_param; - cker_param.axis = (int8_t)axis; - cker_param.inputs_count = count; - - const auto out_shape = convertShape(out_tensor->tensorInfo().shape()); - - std::vector<nnfw::cker::Shape> in_shapes; - std::vector<const nnfw::cker::Shape *> in_shape_ptrs; - in_shapes.reserve(count); - in_shape_ptrs.reserve(count); - std::vector<const float *> in_ptrs; - for (uint32_t i = 0; i < count; i++) - { - in_shapes.push_back(convertShape(in_tensors[i]->tensorInfo().shape())); - in_shape_ptrs.push_back(&in_shapes[i]); - in_ptrs.push_back(reinterpret_cast<const float *>(in_tensors[i]->bufferRO())); - } - - auto out_buffer = out_tensor->buffer(); - float *out_ptr = reinterpret_cast<float *>(out_buffer); - - nnfw::cker::Concatenation<float>(cker_param, in_shape_ptrs.data(), in_ptrs.data(), out_shape, - out_ptr); -} - -void invokeConcat(const ExecEnv *env, const ir::Operation &node) -{ - const auto &concat_node = nnfw::misc::polymorphic_downcast<const ir::operation::Concat &>(node); - const int32_t axis_raw = concat_node.param().axis; - - std::vector<const ITensor *> in_tensors; - for (const auto &e : concat_node.getInputs()) - { - in_tensors.emplace_back(env->tensorAt(e)); - } - - const auto out_index = node.getOutputs().at(0); - const auto out_tensor = env->tensorAt(out_index); - const uint32_t axis = (axis_raw < 0) ? 
(axis_raw + out_tensor->num_dimensions()) : axis_raw; - - const auto data_type = in_tensors[0]->data_type(); - if (data_type == ir::DataType::FLOAT32) - { - invoke(in_tensors, out_tensor, axis); - } - else - { - throw std::runtime_error{"NYI: Support float32 only"}; - } -} -} // namespace concat - -OpKernel *getConcat() -{ - static OpKernel kernel = {concat::prepareConcat, concat::invokeConcat}; - return &kernel; -} - -} // namespace interp -} // namespace onert diff --git a/runtime/onert/core/src/interp/operations/Conv2D.cc b/runtime/onert/core/src/interp/operations/Conv2D.cc deleted file mode 100644 index bb00b828c..000000000 --- a/runtime/onert/core/src/interp/operations/Conv2D.cc +++ /dev/null @@ -1,151 +0,0 @@ -/* - * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include <cker/operation/Conv.h> - -#include "OperationUtil.h" - -#include "interp/Registration.h" -#include "ir/operation/Conv2D.h" -#include "util/Utils.h" -#include "util/ShapeInference.h" -#include "misc/polymorphic_downcast.h" - -namespace onert -{ -namespace interp -{ -namespace conv2d -{ - -void prepareConv2D(ExecEnv *env, const ir::Operation &node) -{ - const auto in_index = node.getInputs().at(ir::operation::Conv2D::INPUT); - const auto kernel_index = node.getInputs().at(ir::operation::Conv2D::KERNEL); - const auto bias_index = node.getInputs().at(ir::operation::Conv2D::BIAS); - const auto out_index = node.getOutputs().at(0); - - const auto in_tensor = env->tensorAt(in_index); - const auto kernel_tensor = env->tensorAt(kernel_index); - const auto bias_tensor = env->tensorAt(bias_index); - - assert(in_tensor->num_dimensions() == 4); - assert(kernel_tensor->num_dimensions() == 4); - assert(bias_tensor->num_dimensions() == 1); - - UNUSED_RELEASE(in_tensor); - UNUSED_RELEASE(kernel_tensor); - UNUSED_RELEASE(bias_tensor); - - const auto output_info = env->graph().operands().at(out_index).info(); - if (output_info.total_size() == 0) - { - // Handle unspecified output shape - const auto &conv_node = nnfw::misc::polymorphic_downcast<const ir::operation::Conv2D &>(node); - const auto infered_output_shape = shape_inference::inferConv2DShape( - in_tensor->tensorInfo().shape(), kernel_tensor->tensorInfo().shape(), conv_node.param()); - env->allocateIfNeeded( - out_index, ir::OperandInfo::createStaticInfo(infered_output_shape, output_info.typeInfo())); - } - else - { - env->allocateIfNeeded(out_index, output_info); - } - - auto out_tensor = env->tensorAt(out_index); - UNUSED_RELEASE(out_tensor); - - // Handle same ifm & ofm data type only - assert(in_tensor->data_type() == out_tensor->data_type()); - assert(out_tensor->num_dimensions() == 4); -} - -void invoke(const ITensor *ifm_tensor, const ITensor *ker_tensor, const ITensor *bias_tensor, - const ITensor 
*ofm_tensor, const ir::operation::Conv2D::Param ¶m) -{ - // TODO Support NCHW frontned - const auto ifm_shape = ifm_tensor->tensorInfo().shape().asFeature(ir::Layout::NHWC); - const auto ofm_shape = ofm_tensor->tensorInfo().shape().asFeature(ir::Layout::NHWC); - // Kernel format is [depth_out, kernel_height, kernel_width, depth_in]. - const auto &ker_shape = ker_tensor->tensorInfo().shape(); - const auto ker_height = ker_shape.dim(1); - const auto ker_width = ker_shape.dim(2); - const auto padding = ir::calculatePadding(param.padding, ifm_shape, ofm_shape, param.stride, - ker_width, ker_height); - - // Calculate - float activation_min, activation_max; - calculateActivationRange(param.activation, &activation_min, &activation_max); - - nnfw::cker::ConvParams cker_param; - cker_param.padding_type = convertPaddingType(param.padding.type); - cker_param.padding_values.width = padding.left; - cker_param.padding_values.height = padding.top; - cker_param.stride_width = param.stride.horizontal; - cker_param.stride_height = param.stride.vertical; - cker_param.dilation_width_factor = 1; - cker_param.dilation_height_factor = 1; - cker_param.float_activation_min = activation_min; - cker_param.float_activation_max = activation_max; - - const auto cker_ifm_shape = convertShape(ifm_tensor->tensorInfo().shape()); - const auto cker_ker_shape = convertShape(ker_tensor->tensorInfo().shape()); - const auto cker_bias_shape = convertShape(bias_tensor->tensorInfo().shape()); - const auto cker_ofm_shape = convertShape(ofm_tensor->tensorInfo().shape()); - const float *ifm_ptr = reinterpret_cast<const float *>(ifm_tensor->bufferRO()); - const float *ker_ptr = reinterpret_cast<const float *>(ker_tensor->bufferRO()); - const float *bias_ptr = reinterpret_cast<const float *>(bias_tensor->bufferRO()); - float *ofm_ptr = reinterpret_cast<float *>(ofm_tensor->buffer()); - - nnfw::cker::Conv conv_kernel; - conv_kernel(cker_param, cker_ifm_shape, ifm_ptr, cker_ker_shape, ker_ptr, cker_bias_shape, - 
bias_ptr, cker_ofm_shape, ofm_ptr); -} - -void invokeConv2D(const ExecEnv *env, const ir::Operation &node) -{ - const auto &conv_node = nnfw::misc::polymorphic_downcast<const ir::operation::Conv2D &>(node); - - const auto ifm_index = node.getInputs().at(ir::operation::Conv2D::INPUT); - const auto ker_index = node.getInputs().at(ir::operation::Conv2D::KERNEL); - const auto bias_index = node.getInputs().at(ir::operation::Conv2D::BIAS); - const auto ofm_index = node.getOutputs().at(0); - - const auto ifm_tensor = env->tensorAt(ifm_index); - const auto ker_tensor = env->tensorAt(ker_index); - const auto bias_tensor = env->tensorAt(bias_index); - const auto ofm_tensor = env->tensorAt(ofm_index); - - const auto data_type = ifm_tensor->data_type(); - if (data_type == ir::DataType::FLOAT32) - { - invoke(ifm_tensor, ker_tensor, bias_tensor, ofm_tensor, conv_node.param()); - } - else - { - throw std::runtime_error{"NYI: Support float32 only"}; - } -} -} // namespace conv2d - -OpKernel *getConv2D() -{ - static OpKernel kernel = {conv2d::prepareConv2D, conv2d::invokeConv2D}; - return &kernel; -} - -} // namespace interp -} // namespace onert diff --git a/runtime/onert/core/src/interp/operations/DepthwiseConv2D.cc b/runtime/onert/core/src/interp/operations/DepthwiseConv2D.cc deleted file mode 100644 index 0473855d9..000000000 --- a/runtime/onert/core/src/interp/operations/DepthwiseConv2D.cc +++ /dev/null @@ -1,156 +0,0 @@ -/* - * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include <cker/operation/DepthwiseConv.h> -#include <misc/polymorphic_downcast.h> - -#include "OperationUtil.h" - -#include "interp/Registration.h" -#include "ir/operation/DepthwiseConv2D.h" -#include "util/Utils.h" -#include "util/ShapeInference.h" - -namespace onert -{ -namespace interp -{ - -namespace -{ - -void prepareDepthwiseConv(ExecEnv *env, const ir::Operation &node) -{ - const auto in_index = node.getInputs().at(ir::operation::DepthwiseConv2D::INPUT); - const auto kernel_index = node.getInputs().at(ir::operation::DepthwiseConv2D::KERNEL); - const auto bias_index = node.getInputs().at(ir::operation::DepthwiseConv2D::BIAS); - const auto out_index = node.getOutputs().at(0); - - const auto in_tensor = env->tensorAt(in_index); - const auto kernel_tensor = env->tensorAt(kernel_index); - const auto bias_tensor = env->tensorAt(bias_index); - - assert(in_tensor->num_dimensions() == 4); - assert(kernel_tensor->num_dimensions() == 4); - assert(bias_tensor->num_dimensions() == 1); - - UNUSED_RELEASE(in_tensor); - UNUSED_RELEASE(kernel_tensor); - UNUSED_RELEASE(bias_tensor); - - // TODO handle unspecified output shape: - // calculate output shape using ifm shape, kernel shape, padding, stride - const auto output_info = env->graph().operands().at(out_index).info(); - if (output_info.total_size() == 0) - { - // Handle unspecified output shape - const auto &depth_conv_node = - nnfw::misc::polymorphic_downcast<const ir::operation::DepthwiseConv2D &>(node); - const auto infered_output_shape = shape_inference::inferDepthwiseConv2DShape( - in_tensor->tensorInfo().shape(), kernel_tensor->tensorInfo().shape(), - depth_conv_node.param()); - env->allocateIfNeeded( - out_index, ir::OperandInfo::createStaticInfo(infered_output_shape, output_info.typeInfo())); - } - else - { - env->allocateIfNeeded(out_index, output_info); - } - - auto out_tensor = 
env->tensorAt(out_index); - UNUSED_RELEASE(out_tensor); - - // Handle same ifm & ofm data type only - assert(in_tensor->data_type() == out_tensor->data_type()); - assert(out_tensor->num_dimensions() == 4); -} - -void invoke(const ITensor *ifm_tensor, const ITensor *ker_tensor, const ITensor *bias_tensor, - const ITensor *ofm_tensor, const ir::operation::DepthwiseConv2D::Param ¶m) -{ - // TODO Support NCHW frontend - const auto ifm_shape = ifm_tensor->tensorInfo().shape().asFeature(ir::Layout::NHWC); - const auto ofm_shape = ofm_tensor->tensorInfo().shape().asFeature(ir::Layout::NHWC); - // Kernel format is [1, kernel_height, kernel_width, depth_out]. - const auto &ker_shape = ker_tensor->tensorInfo().shape(); - const auto ker_height = ker_shape.dim(1); - const auto ker_width = ker_shape.dim(2); - const auto padding = ir::calculatePadding(param.padding, ifm_shape, ofm_shape, param.stride, - ker_width, ker_height); - - // Calculate - float activation_min, activation_max; - calculateActivationRange(param.activation, &activation_min, &activation_max); - - nnfw::cker::DepthwiseConvParams cker_param; - cker_param.padding_values.width = padding.left; - cker_param.padding_values.height = padding.top; - cker_param.depth_multiplier = param.multiplier; - cker_param.stride_width = param.stride.horizontal; - cker_param.stride_height = param.stride.vertical; - cker_param.dilation_width_factor = 1; - cker_param.dilation_height_factor = 1; - cker_param.float_activation_min = activation_min; - cker_param.float_activation_max = activation_max; - - const auto cker_ifm_shape = convertShape(ifm_tensor->tensorInfo().shape()); - const auto cker_ker_shape = convertShape(ker_tensor->tensorInfo().shape()); - const auto cker_bias_shape = convertShape(bias_tensor->tensorInfo().shape()); - const auto cker_ofm_shape = convertShape(ofm_tensor->tensorInfo().shape()); - const float *ifm_ptr = reinterpret_cast<const float *>(ifm_tensor->bufferRO()); - const float *ker_ptr = reinterpret_cast<const 
float *>(ker_tensor->bufferRO()); - const float *bias_ptr = reinterpret_cast<const float *>(bias_tensor->bufferRO()); - float *ofm_ptr = reinterpret_cast<float *>(ofm_tensor->buffer()); - - nnfw::cker::DepthwiseConv(cker_param, cker_ifm_shape, ifm_ptr, cker_ker_shape, ker_ptr, - cker_bias_shape, bias_ptr, cker_ofm_shape, ofm_ptr); -} - -void invokeDepthwiseConv(const ExecEnv *env, const ir::Operation &node) -{ - const auto &conv_node = static_cast<const ir::operation::DepthwiseConv2D &>(node); - - const auto ifm_index = node.getInputs().at(ir::operation::DepthwiseConv2D::INPUT); - const auto ker_index = node.getInputs().at(ir::operation::DepthwiseConv2D::KERNEL); - const auto bias_index = node.getInputs().at(ir::operation::DepthwiseConv2D::BIAS); - const auto ofm_index = node.getOutputs().at(0); - - const auto ifm_tensor = env->tensorAt(ifm_index); - const auto ker_tensor = env->tensorAt(ker_index); - const auto bias_tensor = env->tensorAt(bias_index); - const auto ofm_tensor = env->tensorAt(ofm_index); - - const auto data_type = ifm_tensor->data_type(); - if (data_type == ir::DataType::FLOAT32) - { - invoke(ifm_tensor, ker_tensor, bias_tensor, ofm_tensor, conv_node.param()); - } - else - { - throw std::runtime_error{"NYI: Support float32 only"}; - } -} - -} // namespace - -OpKernel *getDepthwiseConv2D() -{ - static OpKernel kernel = {prepareDepthwiseConv, invokeDepthwiseConv}; - return &kernel; -} - -} // namespace interp -} // namespace onert diff --git a/runtime/onert/core/src/interp/operations/ElementwiseActivations.cc b/runtime/onert/core/src/interp/operations/ElementwiseActivations.cc deleted file mode 100644 index c8773bef4..000000000 --- a/runtime/onert/core/src/interp/operations/ElementwiseActivations.cc +++ /dev/null @@ -1,161 +0,0 @@ -/* - * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include <cmath> - -#include "OperationUtil.h" - -#include "interp/Registration.h" - -#include "ir/operation/ElementwiseActivation.h" - -#include <misc/polymorphic_downcast.h> -#include <cker/operation/Logistic.h> -#include <cker/operation/Tanh.h> - -namespace onert -{ -namespace interp -{ -namespace -{ - -enum class ActivationType -{ - Logistic, - ReLU, - Tanh -}; - -void prepare(ExecEnv *env, const ir::Operation &node) -{ - const auto input_index = node.getInputs().at(0); - const auto output_index = node.getOutputs().at(0); - - const auto input_tensor = env->tensorAt(input_index); - - const auto output_info = env->graph().operands().at(output_index).info(); - if (output_info.total_size() == 0) - { - // Output's shape and type is same with input - auto input_info = input_tensor->tensorInfo(); - // We can handle already allocated (ex. 
model output) - env->allocateIfNeeded(output_index, input_info); - } - else - { - env->allocateIfNeeded(output_index, output_info); - } - - const auto output_tensor = env->tensorAt(output_index); - // Check shape and type lhs is same with output - // TODO Util function to compare TensorInfo - if (input_tensor->data_type() != output_tensor->data_type()) - { - throw std::runtime_error{"Interp(ElementwiseActivation): Invalid output type"}; - } -} - -template <ActivationType act_type> -void evalFloat(const float *input_ptr, float *output_ptr, uint64_t num_elements, float alpha, - float beta) -{ - std::function<float(const float &)> fn = [](const float &) { return std::nanf(""); }; - switch (act_type) - { - case ActivationType::ReLU: - fn = [alpha, beta](const float &in) { return std::min(std::max(beta, in), alpha); }; - break; - case ActivationType::Tanh: - fn = [](const float &in) { return std::tanh(in); }; - break; - default: - throw std::runtime_error{"Interp(ElementwiseActivation): NYI - Unsupported activation"}; - break; - } - - const float *input_end = input_ptr + num_elements; - for (; input_ptr < input_end; input_ptr++, output_ptr++) - { - *output_ptr = fn(*input_ptr); - } -} - -template <ActivationType act_type> void invoke(const ExecEnv *env, const ir::Operation &node) -{ - const auto input_index = node.getInputs().at(0); - const auto output_index = node.getOutputs().at(0); - - // Check lhs shape is same with rhs (with broadcast) - const auto input_tensor = env->tensorAt(input_index); - const auto output_tensor = env->tensorAt(output_index); - - const auto data_type = input_tensor->data_type(); - if (data_type == ir::DataType::FLOAT32) - { - uint64_t elements = input_tensor->num_elements(); - const float *input_start = reinterpret_cast<const float *>(input_tensor->bufferRO()); - float *out = reinterpret_cast<float *>(output_tensor->buffer()); - if (act_type == ActivationType::Logistic) - { - const auto cker_input_shape = 
convertShape(input_tensor->tensorInfo().shape()); - const auto cker_output_shape = convertShape(output_tensor->tensorInfo().shape()); - nnfw::cker::Logistic(cker_input_shape, input_start, cker_output_shape, out); - } - else - { - const auto &act_node = - nnfw::misc::polymorphic_downcast<const ir::operation::ElementwiseActivation &>(node); - evalFloat<act_type>(input_start, out, elements, act_node.param().alpha, - act_node.param().beta); - } - } - else - { - throw std::runtime_error{"Interp(" + node.name() + "): NYI - Support float only"}; - } -} - -void invokeElementwiseActivation(const ExecEnv *env, const ir::Operation &node) -{ - const auto &act_node = - nnfw::misc::polymorphic_downcast<const ir::operation::ElementwiseActivation &>(node); - switch (act_node.param().op_type) - { - case ir::operation::ElementwiseActivation::Type::LOGISTIC: - invoke<ActivationType::Logistic>(env, node); - break; - case ir::operation::ElementwiseActivation::Type::RELU: - invoke<ActivationType::ReLU>(env, node); - break; - case ir::operation::ElementwiseActivation::Type::TANH: - invoke<ActivationType::Tanh>(env, node); - break; - default: - throw std::runtime_error("Interp(" + node.name() + "): NYI - Unsupported activation"); - } -} - -} // namespace - -OpKernel *getElementwiseActivation() -{ - static OpKernel kernel = {prepare, invokeElementwiseActivation}; - return &kernel; -} - -} // namespace interp -} // namespace onert diff --git a/runtime/onert/core/src/interp/operations/FullyConnected.cc b/runtime/onert/core/src/interp/operations/FullyConnected.cc deleted file mode 100644 index 12f529dab..000000000 --- a/runtime/onert/core/src/interp/operations/FullyConnected.cc +++ /dev/null @@ -1,136 +0,0 @@ -/* - * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include <cker/operation/FullyConnected.h> - -#include "OperationUtil.h" - -#include "interp/Registration.h" -#include "ir/operation/FullyConnected.h" -#include "misc/polymorphic_downcast.h" - -namespace onert -{ -namespace interp -{ -namespace fc -{ - -void prepareFC(ExecEnv *env, const ir::Operation &node) -{ - const auto in_index = node.getInputs().at(ir::operation::FullyConnected::INPUT); - const auto kernel_index = node.getInputs().at(ir::operation::FullyConnected::WEIGHT); - const auto bias_index = node.getInputs().at(ir::operation::FullyConnected::BIAS); - const auto out_index = node.getOutputs().at(0); - - const auto in_tensor = env->tensorAt(in_index); - const auto kernel_tensor = env->tensorAt(kernel_index); - const auto bias_tensor = env->tensorAt(bias_index); - - UNUSED_RELEASE(in_tensor); - UNUSED_RELEASE(kernel_tensor); - UNUSED_RELEASE(bias_tensor); - - assert(in_tensor->num_dimensions() >= 2); - assert(kernel_tensor->num_dimensions() == 2); - assert(bias_tensor->num_dimensions() == 1); - - const auto input_size_with_batch = in_tensor->num_elements(); - const auto num_units = kernel_tensor->dimension(0); - const auto input_size = kernel_tensor->dimension(1); - const auto batch_size = input_size_with_batch / input_size; - assert(input_size_with_batch % input_size == 0); - assert(num_units == bias_tensor->dimension(0)); - - // Make output tensor info - ir::Shape output_shape(2); - output_shape.dim(0) = batch_size; - output_shape.dim(1) = num_units; - const auto out_info = - ir::OperandInfo::createStaticInfo(output_shape, 
in_tensor->tensorInfo().typeInfo()); - env->allocateIfNeeded(out_index, out_info); - - auto out_tensor = env->tensorAt(out_index); - UNUSED_RELEASE(out_tensor); - - // Handle same ifm & ofm data type only - assert(in_tensor->data_type() == out_tensor->data_type()); - assert(out_tensor->num_dimensions() == 2); - assert(out_tensor->dimension(0) == batch_size); - assert(out_tensor->dimension(1) == num_units); -} - -void invoke(const ITensor *ifm_tensor, const ITensor *ker_tensor, const ITensor *bias_tensor, - const ITensor *ofm_tensor, const ir::operation::FullyConnected::Param ¶m) -{ - const auto ifm_buffer = ifm_tensor->bufferRO(); - const auto ker_buffer = ker_tensor->bufferRO(); - const auto bias_buffer = bias_tensor->bufferRO(); - auto ofm_buffer = ofm_tensor->buffer(); - - // Calculate - nnfw::cker::FullyConnectedParams cker_param; - cker_param.activation = convertActivationType(param.activation); - calculateActivationRange(param.activation, &cker_param.float_activation_min, - &cker_param.float_activation_max); - const auto cker_ifm_shape = convertShape(ifm_tensor->tensorInfo().shape()); - const auto cker_ker_shape = convertShape(ker_tensor->tensorInfo().shape()); - const auto cker_bias_shape = convertShape(bias_tensor->tensorInfo().shape()); - const auto cker_ofm_shape = convertShape(ofm_tensor->tensorInfo().shape()); - const float *ifm_ptr = reinterpret_cast<const float *>(ifm_buffer); - const float *ker_ptr = reinterpret_cast<const float *>(ker_buffer); - const float *bias_ptr = reinterpret_cast<const float *>(bias_buffer); - float *ofm_ptr = reinterpret_cast<float *>(ofm_buffer); - - nnfw::cker::FullyConnected(cker_param, cker_ifm_shape, ifm_ptr, cker_ker_shape, ker_ptr, - cker_bias_shape, bias_ptr, cker_ofm_shape, ofm_ptr); -} - -void invokeFC(const ExecEnv *env, const ir::Operation &node) -{ - const auto &conv_node = - nnfw::misc::polymorphic_downcast<const ir::operation::FullyConnected &>(node); - - const auto ifm_index = 
node.getInputs().at(ir::operation::FullyConnected::INPUT); - const auto ker_index = node.getInputs().at(ir::operation::FullyConnected::WEIGHT); - const auto bias_index = node.getInputs().at(ir::operation::FullyConnected::BIAS); - const auto ofm_index = node.getOutputs().at(0); - - const auto ifm_tensor = env->tensorAt(ifm_index); - const auto ker_tensor = env->tensorAt(ker_index); - const auto bias_tensor = env->tensorAt(bias_index); - const auto ofm_tensor = env->tensorAt(ofm_index); - - const auto data_type = ifm_tensor->data_type(); - if (data_type == ir::DataType::FLOAT32) - { - invoke(ifm_tensor, ker_tensor, bias_tensor, ofm_tensor, conv_node.param()); - } - else - { - throw std::runtime_error{"NYI: Support float only"}; - } -} -} // namespace fc - -OpKernel *getFullyConnected() -{ - static OpKernel kernel = {fc::prepareFC, fc::invokeFC}; - return &kernel; -} - -} // namespace interp -} // namespace onert diff --git a/runtime/onert/core/src/interp/operations/Gather.cc b/runtime/onert/core/src/interp/operations/Gather.cc deleted file mode 100644 index 9e82def5f..000000000 --- a/runtime/onert/core/src/interp/operations/Gather.cc +++ /dev/null @@ -1,138 +0,0 @@ -/* - * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include <cker/operation/Gather.h> - -#include "OperationUtil.h" - -#include "interp/Registration.h" -#include "ir/operation/Gather.h" -#include "misc/polymorphic_downcast.h" - -namespace onert -{ -namespace interp -{ -namespace -{ - -void prepareGather(ExecEnv *env, const ir::Operation &node) -{ - const auto input_index = node.getInputs().at(ir::operation::Gather::INPUT); - const auto indices_index = node.getInputs().at(ir::operation::Gather::INDICES); - const auto output_index = node.getOutputs().at(0); - - const auto input_tensor = env->tensorAt(input_index); - const auto indices_tensor = env->tensorAt(indices_index); - - // TODO handle unspecified output shape: - // calculate output shape using ifm shape, kernel shape, padding, stride - const auto output_info = env->graph().operands().at(output_index).info(); - if (output_info.total_size() == 0) - { - throw std::runtime_error{"Interp(Gather): NYI for unspecified output shape"}; - } - else - { - env->allocateIfNeeded(output_index, output_info); - } - - if (indices_tensor->data_type() != ir::DataType::INT32) - { - throw std::runtime_error{"Interp(Gather): Invalid indices data type"}; - } - - auto output_tensor = env->tensorAt(output_index); - auto output_rank = input_tensor->num_dimensions() + indices_tensor->num_dimensions() - 1; - - if (output_rank != output_tensor->num_dimensions()) - { - throw std::runtime_error{"Interp(Gather): Invalid output rank"}; - } - if (output_tensor->data_type() != input_tensor->data_type()) - { - throw std::runtime_error{"Interp(Gather): Invalid output data type"}; - } - - if (input_tensor->data_type() == ir::DataType::QUANT_UINT8_ASYMM && - input_tensor->tensorInfo().typeInfo() != output_tensor->tensorInfo().typeInfo()) - { - throw std::runtime_error{ - "Interp(Gather): Cannot handle different I/O QUANT_UINT8_ASYMM scale/offset"}; - } -} - -template <typename raw_type> -void invoke(const ITensor *input_tensors, const ITensor *indices_tensors, - const ITensor *output_tensor, 
uint32_t axis) -{ - // Calculate - nnfw::cker::GatherParams cker_param; - cker_param.axis = (int8_t)axis; - - const auto cker_input_shapes = convertShape(input_tensors->tensorInfo().shape()); - const auto cker_indices_shape = convertShape(indices_tensors->tensorInfo().shape()); - const auto cker_output_shape = convertShape(output_tensor->tensorInfo().shape()); - const raw_type *input_ptr = reinterpret_cast<const raw_type *>(input_tensors->bufferRO()); - const int32_t *indices_ptr = reinterpret_cast<const int32_t *>(indices_tensors->bufferRO()); - raw_type *output_ptr = reinterpret_cast<raw_type *>(output_tensor->buffer()); - - nnfw::cker::Gather<raw_type>(cker_param, cker_input_shapes, input_ptr, cker_indices_shape, - indices_ptr, cker_output_shape, output_ptr); -} - -void invokeGather(const ExecEnv *env, const ir::Operation &node) -{ - const auto &gather_node = nnfw::misc::polymorphic_downcast<const ir::operation::Gather &>(node); - const int32_t axis_raw = gather_node.param().axis; - - const auto input_index = node.getInputs().at(ir::operation::Gather::INPUT); - const auto indices_index = node.getInputs().at(ir::operation::Gather::INDICES); - const auto output_index = node.getOutputs().at(0); - - const auto input_tensor = env->tensorAt(input_index); - const auto indices_tensor = env->tensorAt(indices_index); - const auto output_tensor = env->tensorAt(output_index); - const uint32_t axis = (axis_raw < 0) ? 
(axis_raw + input_tensor->num_dimensions()) : axis_raw; - - const auto data_type = input_tensor->data_type(); - - switch (data_type) - { - case ir::DataType::FLOAT32: - invoke<float>(input_tensor, indices_tensor, output_tensor, axis); - break; - case ir::DataType::INT32: - invoke<int32_t>(input_tensor, indices_tensor, output_tensor, axis); - break; - case ir::DataType::QUANT_UINT8_ASYMM: - invoke<uint8_t>(input_tensor, indices_tensor, output_tensor, axis); - break; - default: - throw std::runtime_error{"Interp(Gather): NYI - Not supported type"}; - } -} - -} // namespace - -OpKernel *getGather() -{ - static OpKernel kernel = {prepareGather, invokeGather}; - return &kernel; -} - -} // namespace interp -} // namespace onert diff --git a/runtime/onert/core/src/interp/operations/InstanceNorm.cc b/runtime/onert/core/src/interp/operations/InstanceNorm.cc deleted file mode 100644 index 2538bcc39..000000000 --- a/runtime/onert/core/src/interp/operations/InstanceNorm.cc +++ /dev/null @@ -1,121 +0,0 @@ -/* - * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include <cker/operation/InstanceNorm.h> - -#include "OperationUtil.h" - -#include "interp/Registration.h" -#include "ir/operation/InstanceNorm.h" -#include "misc/polymorphic_downcast.h" - -namespace onert -{ -namespace interp -{ -namespace instancenorm -{ - -void prepareInstanceNorm(ExecEnv *env, const ir::Operation &node) -{ - const auto &instancenorm_node = - nnfw::misc::polymorphic_downcast<const ir::operation::InstanceNorm &>(node); - - const auto input_index = node.getInputs().at(instancenorm_node.INPUT); - const auto output_index = node.getOutputs().at(0); - const auto input_tensor = env->tensorAt(input_index); - - if (input_tensor->num_dimensions() != 4) - { - throw std::runtime_error{"Interp(InstanceNorm): Input should be 4D-tensor"}; - } - - // Output shape should be same with input - env->allocateIfNeeded(output_index, input_tensor->tensorInfo()); - - auto output_tensor = env->tensorAt(output_index); - UNUSED_RELEASE(output_tensor); - - // Handle same ifm & ofm data type only - assert(input_tensor->data_type() == output_tensor->data_type()); - assert(input_tensor->tensorInfo().shape() == output_tensor->tensorInfo().shape()); -} - -inline void setActivationParams(float min, float max, nnfw::cker::InstanceNormParams *params) -{ - params->float_activation_min = min; - params->float_activation_max = max; -} - -void invoke(const ITensor *input_tensor, const ITensor *gamma_tensor, const ITensor *beta_tensor, - const ITensor *output_tensor, const ir::operation::InstanceNorm::Param ¶m) -{ - // Calculate - float activation_min, activation_max; - calculateActivationRange(param.activation, &activation_min, &activation_max); - - nnfw::cker::InstanceNormParams cker_param; - cker_param.epsilon = param.epsilon; - cker_param.float_activation_min = activation_min; - cker_param.float_activation_max = activation_max; - - const auto cker_input_shape = convertShape(input_tensor->tensorInfo().shape()); - const auto cker_gamma_shape = 
convertShape(gamma_tensor->tensorInfo().shape()); - const auto cker_beta_shape = convertShape(beta_tensor->tensorInfo().shape()); - const auto cker_output_shape = convertShape(output_tensor->tensorInfo().shape()); - const float *input_ptr = reinterpret_cast<const float *>(input_tensor->bufferRO()); - const float *gamma_ptr = reinterpret_cast<const float *>(gamma_tensor->bufferRO()); - const float *beta_ptr = reinterpret_cast<const float *>(beta_tensor->bufferRO()); - float *output_ptr = reinterpret_cast<float *>(output_tensor->buffer()); - - nnfw::cker::InstanceNorm(cker_param, cker_input_shape, input_ptr, cker_gamma_shape, gamma_ptr, - cker_beta_shape, beta_ptr, cker_output_shape, output_ptr); -} - -void invokeInstanceNorm(const ExecEnv *env, const ir::Operation &node) -{ - const auto &instancenorm_node = - nnfw::misc::polymorphic_downcast<const ir::operation::InstanceNorm &>(node); - - const auto input_index = node.getInputs().at(instancenorm_node.INPUT); - const auto gamma_index = node.getInputs().at(instancenorm_node.GAMMA); - const auto beta_index = node.getInputs().at(instancenorm_node.BETA); - const auto out_index = node.getOutputs().at(0); - const auto input_tensor = env->tensorAt(input_index); - const auto gamma_tensor = env->tensorAt(gamma_index); - const auto beta_tensor = env->tensorAt(beta_index); - const auto out_tensor = env->tensorAt(out_index); - const auto data_type = input_tensor->data_type(); - - if (data_type == ir::DataType::FLOAT32) - { - invoke(input_tensor, gamma_tensor, beta_tensor, out_tensor, instancenorm_node.param()); - } - else - { - throw std::runtime_error{"NYI: Unsupported data type"}; - } -} -} // namespace instancenorm - -OpKernel *getInstanceNorm() -{ - static OpKernel kernel = {instancenorm::prepareInstanceNorm, instancenorm::invokeInstanceNorm}; - return &kernel; -} - -} // namespace interp -} // namespace onert diff --git a/runtime/onert/core/src/interp/operations/OperationUtil.h 
b/runtime/onert/core/src/interp/operations/OperationUtil.h deleted file mode 100644 index 2fdf098f0..000000000 --- a/runtime/onert/core/src/interp/operations/OperationUtil.h +++ /dev/null @@ -1,203 +0,0 @@ -/* - * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef __ONERT_INTERP_OPERATIONS_OPERATION_UTILS_H_ -#define __ONERT_INTERP_OPERATIONS_OPERATION_UTILS_H_ - -#include "ir/Shape.h" -#include "ir/InternalType.h" -#include "ir/Padding.h" - -#include <cker/Shape.h> -#include <cker/Types.h> - -namespace onert -{ -namespace interp -{ - -inline nnfw::cker::Shape convertShape(const ir::Shape &shape) -{ - auto dimensions = std::vector<uint32_t>(shape.dims().begin(), shape.dims().end()); - - std::vector<int32_t> raw_shape; - raw_shape.resize(dimensions.size()); - - for (uint32_t i = 0; i < dimensions.size(); ++i) - { - raw_shape[i] = dimensions[i]; - } - - return nnfw::cker::GetShape(raw_shape); -} - -inline nnfw::cker::Shape convertExtendShape(const ir::Shape &shape) -{ - auto dimensions = std::vector<uint32_t>(shape.dims().begin(), shape.dims().end()); - - const int32_t extended_rank = 4; - int32_t raw_shape[extended_rank]; - uint32_t start = extended_rank - dimensions.size(); - - for (uint32_t i = 0; i < extended_rank; ++i) - { - if (i < start) - { - raw_shape[i] = 1; - } - else - { - raw_shape[i] = dimensions[i - start]; - } - } - - return nnfw::cker::Shape(extended_rank, 
raw_shape); -} - -inline nnfw::cker::FusedActivationFunctionType -convertActivationType(const ir::Activation activation) -{ - switch (activation) - { - case ir::Activation::NONE: - return nnfw::cker::FusedActivationFunctionType::kNone; - case ir::Activation::RELU: - return nnfw::cker::FusedActivationFunctionType::kRelu; - case ir::Activation::RELU1: - return nnfw::cker::FusedActivationFunctionType::kRelu1; - case ir::Activation::RELU6: - return nnfw::cker::FusedActivationFunctionType::kRelu6; - default: - throw std::runtime_error{"CPU backend: Cannot convert activation type"}; - } -} - -template <typename T> -void calculateActivationRange(ir::Activation activation, T *activation_min, T *activation_max) -{ - if (activation == ir::Activation::RELU) - { - *activation_min = 0; - *activation_max = std::numeric_limits<T>::max(); - } - else if (activation == ir::Activation::RELU6) - { - *activation_min = 0; - *activation_max = 6; - } - else if (activation == ir::Activation::RELU1) - { - *activation_min = -1; - *activation_max = 1; - } - else if (activation == ir::Activation::NONE) - { - *activation_min = std::numeric_limits<T>::lowest(); - *activation_max = std::numeric_limits<T>::max(); - } - else - { - throw std::runtime_error{"Unsupported activation type"}; - } -} - -inline ir::Shape calcBroadcastShape(const ir::Shape &lhs, const ir::Shape &rhs, bool &success) -{ - int lhs_rank = lhs.rank(); - int rhs_rank = rhs.rank(); - - int out_rank = (lhs_rank > rhs_rank ? 
lhs_rank : rhs_rank); - ir::Shape out_shape(out_rank); - - int lhs_idim = lhs_rank - 1; - int rhs_idim = rhs_rank - 1; - success = true; - for (int out_idim = out_rank - 1; out_idim >= 0; out_idim--) - { - if (lhs_idim == -1 && rhs_idim == -1) - { - // invalid result - success = false; - break; - } - - if (lhs_idim == -1) - { - out_shape.dim(out_idim) = rhs.dim(rhs_idim); - rhs_idim--; - } - else if (rhs_idim == -1) - { - out_shape.dim(out_idim) = lhs.dim(lhs_idim); - lhs_idim--; - } - else - { - if (lhs.dim(lhs_idim) == rhs.dim(rhs_idim)) - { - out_shape.dim(out_idim) = lhs.dim(lhs_idim); - lhs_idim--; - rhs_idim--; - } - else if (lhs.dim(lhs_idim) == 1) - { - out_shape.dim(out_idim) = rhs.dim(rhs_idim); - lhs_idim--; - rhs_idim--; - } - else if (rhs.dim(rhs_idim) == 1) - { - out_shape.dim(out_idim) = lhs.dim(lhs_idim); - lhs_idim--; - rhs_idim--; - } - else - { - // invalid result - success = false; - break; - } - } - } - - if (lhs_idim != -1 || rhs_idim != -1) - { - // invalid result - success = false; - } - return out_shape; -} - -inline nnfw::cker::PaddingType convertPaddingType(ir::PaddingType ir_padding_type) -{ - switch (ir_padding_type) - { - case ir::PaddingType::EXPLICIT: - return nnfw::cker::PaddingType::kNone; - case ir::PaddingType::SAME: - return nnfw::cker::PaddingType::kSame; - case ir::PaddingType::VALID: - return nnfw::cker::PaddingType::kValid; - default: - throw std::runtime_error("Wrong padding type."); - break; - } -} - -} // namespace interp -} // namespace onert - -#endif // __ONERT_INTERP_OPERATIONS_OPERATION_UTILS_H_ diff --git a/runtime/onert/core/src/interp/operations/Pad.cc b/runtime/onert/core/src/interp/operations/Pad.cc deleted file mode 100644 index c8dce698d..000000000 --- a/runtime/onert/core/src/interp/operations/Pad.cc +++ /dev/null @@ -1,106 +0,0 @@ -/* - * Copyright (c) 2019 Samsung Electronics Co., Ltd. 
All Rights Reserved - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include <cker/operation/Pad.h> - -#include "OperationUtil.h" - -#include "interp/Registration.h" -#include "ir/operation/Pad.h" - -namespace onert -{ -namespace interp -{ -namespace -{ - -void preparePad(ExecEnv *env, const ir::Operation &node) -{ - const auto input_index = node.getInputs().at(ir::operation::Pad::INPUT); - const auto output_index = node.getOutputs().at(0); - - const auto input_tensor = env->tensorAt(input_index); - - const auto output_info = env->graph().operands().at(output_index).info(); - - // Check shape and type lhs is same with rhs - // TODO Util function to compare TensorInfo - if (output_info.total_size() == 0) - { - throw std::runtime_error{"Interp(Pad): NYI unspecified output shape"}; - } - else - { - env->allocateIfNeeded(output_index, output_info); - } - - const auto output_tensor = env->tensorAt(output_index); - if (input_tensor->data_type() != output_tensor->data_type()) - { - throw std::runtime_error{"Interp(Pad): Invalid output type"}; - } -} - -void invoke(const ITensor *input_tensor, const ITensor *pad_tensor, const ITensor *output_tensor) -{ - const auto input_buffer = input_tensor->bufferRO(); - const auto pad_buffer = pad_tensor->bufferRO(); - auto output_buffer = output_tensor->buffer(); - - int32_t pad_rank = pad_tensor->dimension(0); - - const auto cker_input_shape = convertShape(input_tensor->tensorInfo().shape()); - const auto 
cker_output_shape = convertShape(output_tensor->tensorInfo().shape()); - const float *input_ptr = reinterpret_cast<const float *>(input_buffer); - const int32_t *pad_ptr = reinterpret_cast<const int32_t *>(pad_buffer); - float *output_ptr = reinterpret_cast<float *>(output_buffer); - - nnfw::cker::Pad<float>(pad_ptr, pad_rank, cker_input_shape, input_ptr, cker_output_shape, - output_ptr, nullptr); -} - -void invokePad(const ExecEnv *env, const ir::Operation &node) -{ - const auto input_index = node.getInputs().at(ir::operation::Pad::INPUT); - const auto pad_index = node.getInputs().at(ir::operation::Pad::PAD); - const auto output_index = node.getOutputs().at(0); - - const auto input_tensor = env->tensorAt(input_index); - const auto pad_tensor = env->tensorAt(pad_index); - const auto output_tensor = env->tensorAt(output_index); - - const auto data_type = input_tensor->data_type(); - - if (data_type == ir::DataType::FLOAT32) - { - invoke(input_tensor, pad_tensor, output_tensor); - } - else - { - throw std::runtime_error{"Interp(Pad): NYI - Unsupported data type"}; - } -} -} // namespace - -OpKernel *getPad() -{ - static OpKernel kernel = {preparePad, invokePad}; - return &kernel; -} - -} // namespace interp -} // namespace onert diff --git a/runtime/onert/core/src/interp/operations/Pool2D.cc b/runtime/onert/core/src/interp/operations/Pool2D.cc deleted file mode 100644 index 92f9d70b2..000000000 --- a/runtime/onert/core/src/interp/operations/Pool2D.cc +++ /dev/null @@ -1,140 +0,0 @@ -/* - * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include <cker/operation/AveragePool.h> -#include <cker/operation/MaxPool.h> - -#include "OperationUtil.h" - -#include "interp/Registration.h" -#include "ir/operation/Pool2D.h" -#include "util/Utils.h" -#include "util/ShapeInference.h" -#include "misc/polymorphic_downcast.h" - -namespace onert -{ -namespace interp -{ -namespace pool2d -{ - -void preparePool2D(ExecEnv *env, const ir::Operation &node) -{ - const auto &pool_node = nnfw::misc::polymorphic_downcast<const ir::operation::Pool2D &>(node); - const auto in_index = node.getInputs().at(pool_node.INPUT); - const auto out_index = node.getOutputs().at(0); - - const auto in_tensor = env->tensorAt(in_index); - UNUSED_RELEASE(in_tensor); - - assert(in_tensor->num_dimensions() == 4); - - const auto output_info = env->graph().operands().at(out_index).info(); - if (output_info.total_size() == 0) - { - // Handle unspecified output shape - const auto infered_output_shape = - shape_inference::inferPoolShape(in_tensor->tensorInfo().shape(), pool_node.param()); - env->allocateIfNeeded( - out_index, ir::OperandInfo::createStaticInfo(infered_output_shape, output_info.typeInfo())); - } - else - { - env->allocateIfNeeded(out_index, output_info); - } - - auto out_tensor = env->tensorAt(out_index); - UNUSED_RELEASE(out_tensor); - - // Handle same ifm & ofm data type only - assert(in_tensor->data_type() == out_tensor->data_type()); - assert(out_tensor->num_dimensions() == 4); -} - -template <typename T> -void invoke(const nnfw::cker::PoolParams ¶ms, const nnfw::cker::Shape &in_shape, - const T *in_ptr, 
const nnfw::cker::Shape &out_shape, T *out_ptr, - ir::operation::Pool2D::PoolType op_type) -{ - switch (op_type) - { - case ir::operation::Pool2D::PoolType::AVG: - nnfw::cker::AveragePool<T>(params, in_shape, in_ptr, out_shape, out_ptr); - break; - case ir::operation::Pool2D::PoolType::MAX: - nnfw::cker::MaxPool<T>(params, in_shape, in_ptr, out_shape, out_ptr); - break; - default: - throw std::runtime_error{"Interp(Pool2D): NYI unsupported operation"}; - break; - } -} - -void invokePool2DOps(const ExecEnv *env, const ir::Operation &node) -{ - const auto &pool_node = nnfw::misc::polymorphic_downcast<const ir::operation::Pool2D &>(node); - - const auto in_index = node.getInputs().at(0); - const auto out_index = node.getOutputs().at(0); - - // Check lhs shape is same with rhs (with broadcast) - const auto in_tensor = env->tensorAt(in_index); - const auto out_tensor = env->tensorAt(out_index); - - // TODO support NCHW frontend - const auto ifm_shape = in_tensor->tensorInfo().shape().asFeature(ir::Layout::NHWC); - const auto ofm_shape = out_tensor->tensorInfo().shape().asFeature(ir::Layout::NHWC); - const auto param = pool_node.param(); - const auto padding = - ir::calculatePadding(param.padding, ifm_shape, ofm_shape, param.stride, param.kw, param.kh); - // Calculate - nnfw::cker::PoolParams cker_param; - cker_param.filter_width = param.kw; - cker_param.filter_height = param.kh; - cker_param.padding_values.width = padding.left; - cker_param.padding_values.height = padding.top; - cker_param.stride_width = param.stride.horizontal; - cker_param.stride_height = param.stride.vertical; - - const auto data_type = in_tensor->data_type(); - if (data_type == ir::DataType::FLOAT32) - { - calculateActivationRange(param.activation, &cker_param.float_activation_min, - &cker_param.float_activation_max); - - const auto in_shape = convertShape(in_tensor->tensorInfo().shape()); - const auto out_shape = convertShape(out_tensor->tensorInfo().shape()); - const float *in_ptr = 
reinterpret_cast<const float *>(in_tensor->bufferRO()); - float *out_ptr = reinterpret_cast<float *>(out_tensor->buffer()); - // Now, invoke() supports only Pool2D in float - invoke<float>(cker_param, in_shape, in_ptr, out_shape, out_ptr, param.op_type); - } - else - { - throw std::runtime_error{"NYI: Support float only"}; - } -} -} // namespace pool2d - -OpKernel *getPool2D() -{ - static OpKernel kernel = {pool2d::preparePool2D, pool2d::invokePool2DOps}; - return &kernel; -} - -} // namespace interp -} // namespace onert diff --git a/runtime/onert/core/src/interp/operations/Reshape.cc b/runtime/onert/core/src/interp/operations/Reshape.cc deleted file mode 100644 index 3a118456b..000000000 --- a/runtime/onert/core/src/interp/operations/Reshape.cc +++ /dev/null @@ -1,63 +0,0 @@ -/* - * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "interp/Registration.h" - -namespace onert -{ -namespace interp -{ -namespace -{ - -void prepare(ExecEnv *env, const ir::Operation &node) -{ - const auto in_index = node.getInputs().at(0); - const auto out_index = node.getOutputs().at(0); - - // Unspecified shape is not supported in operation node spec now - const auto output_info = env->graph().operands().at(out_index).info(); - env->allocateAndShareIfNeeded(out_index, output_info, in_index); - - assert(output_info.total_size() == env->graph().operands().at(in_index).info().total_size()); -} - -void invoke(const ExecEnv *env, const ir::Operation &node) -{ - const auto in_index = node.getInputs().at(0); - const auto out_index = node.getOutputs().at(0); - - if (env->tensorAt(in_index)->bufferRO() == env->tensorAt(out_index)->bufferRO()) - { - // Same data - return; - } - - const auto output_info = env->graph().operands().at(out_index).info(); - memcpy(env->tensorAt(out_index)->buffer(), env->tensorAt(in_index)->bufferRO(), - output_info.total_size()); -} - -} // namespace - -OpKernel *getReshape() -{ - static OpKernel kernel = {prepare, invoke}; - return &kernel; -} - -} // namespace interp -} // namespace onert diff --git a/runtime/onert/core/src/interp/operations/Softmax.cc b/runtime/onert/core/src/interp/operations/Softmax.cc deleted file mode 100644 index d30f78deb..000000000 --- a/runtime/onert/core/src/interp/operations/Softmax.cc +++ /dev/null @@ -1,123 +0,0 @@ -/* - * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include <cker/operation/SoftMax.h> - -#include "OperationUtil.h" - -#include "interp/Registration.h" -#include "ir/operation/Softmax.h" -#include "misc/polymorphic_downcast.h" - -namespace onert -{ -namespace interp -{ -namespace -{ - -void prepareSoftMax(ExecEnv *env, const ir::Operation &node) -{ - const auto in_index = node.getInputs().at(0); - const auto out_index = node.getOutputs().at(0); - - const auto in_tensor = env->tensorAt(in_index); - UNUSED_RELEASE(in_tensor); - - assert((in_tensor->num_dimensions() == 4) || (in_tensor->num_dimensions() == 2)); - - // Output shape should be same with input - // Output type is pre-defined in model - const auto output_shape = env->graph().operands().at(in_index).info().shape(); - const auto output_type = env->graph().operands().at(out_index).info().typeInfo(); - - const auto output_info = ir::OperandInfo::createStaticInfo(output_shape, output_type); - env->allocateIfNeeded(out_index, output_info); - - auto out_tensor = env->tensorAt(out_index); - UNUSED_RELEASE(out_tensor); - - // Check output shape is same with input - assert(out_tensor->num_dimensions() == out_tensor->num_dimensions()); - for (uint32_t i = 0; i < in_tensor->num_dimensions(); i++) - { - assert(in_tensor->dimension(i) == out_tensor->dimension(i)); - } -} - -void invoke(const ITensor *in_tensor, const ITensor *out_tensor, - const ir::operation::Softmax::Param ¶m) -{ - const float *in_ptr = reinterpret_cast<const float *>(in_tensor->bufferRO()); - float *out_ptr = reinterpret_cast<float *>(out_tensor->buffer()); - - float beta = param.beta; - - if (in_tensor->num_dimensions() == 2) - { - uint32_t batch_size = in_tensor->dimension(0); - uint32_t input_size = in_tensor->dimension(1); - - nnfw::cker::Softmax(in_ptr, input_size, batch_size, beta, out_ptr); - } - else if (in_tensor->num_dimensions() == 4) - { - const auto in_shape = 
convertShape(in_tensor->tensorInfo().shape()); - const auto out_shape = convertShape(out_tensor->tensorInfo().shape()); - - nnfw::cker::SoftmaxParams cker_param; - cker_param.beta = beta; - - nnfw::cker::Softmax(cker_param, in_shape, in_ptr, out_shape, out_ptr); - } - else - { - throw std::runtime_error{"Unsuported input dimension: support 2D or 4D"}; - } -} - -void invokeSoftMax(const ExecEnv *env, const ir::Operation &node) -{ - const auto &softmax_node = nnfw::misc::polymorphic_downcast<const ir::operation::Softmax &>(node); - - const auto in_index = node.getInputs().at(0); - const auto out_index = node.getOutputs().at(0); - - const auto in_tensor = env->tensorAt(in_index); - const auto out_tensor = env->tensorAt(out_index); - - const auto in_data_type = in_tensor->data_type(); - const auto out_data_type = out_tensor->data_type(); - if ((in_data_type == ir::DataType::FLOAT32) && (out_data_type == ir::DataType::FLOAT32)) - { - invoke(in_tensor, out_tensor, softmax_node.param()); - } - else - { - throw std::runtime_error{"NYI: Support float32 only"}; - } -} - -} // namespace - -OpKernel *getSoftmax() -{ - static OpKernel kernel = {prepareSoftMax, invokeSoftMax}; - return &kernel; -} - -} // namespace interp -} // namespace onert diff --git a/runtime/onert/core/src/interp/operations/TransposeConv.cc b/runtime/onert/core/src/interp/operations/TransposeConv.cc deleted file mode 100644 index cc2ced26b..000000000 --- a/runtime/onert/core/src/interp/operations/TransposeConv.cc +++ /dev/null @@ -1,141 +0,0 @@ -/* - * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include <cker/operation/TransposeConv.h> -#include <misc/polymorphic_downcast.h> - -#include "OperationUtil.h" - -#include "interp/Registration.h" -#include "ir/operation/TransposeConv.h" - -namespace onert -{ -namespace interp -{ -namespace -{ - -void prepareTransposeConv(ExecEnv *env, const ir::Operation &node) -{ - const auto ifm_index = node.getInputs().at(ir::operation::TransposeConv::INPUT); - const auto ker_index = node.getInputs().at(ir::operation::TransposeConv::KERNEL); - const auto ofm_shape_index = node.getInputs().at(ir::operation::TransposeConv::OUTPUT_SHAPE); - const auto ofm_index = node.getOutputs().at(0); - - const auto ifm_tensor = env->tensorAt(ifm_index); - const auto ker_tensor = env->tensorAt(ker_index); - const auto ofm_shape_tensor = env->tensorAt(ofm_shape_index); - - assert(ifm_tensor->num_dimensions() == 4); - assert(ker_tensor->num_dimensions() == 4); - assert(ofm_shape_tensor->num_dimensions() == 1); - - UNUSED_RELEASE(ifm_tensor); - UNUSED_RELEASE(ker_tensor); - UNUSED_RELEASE(ofm_shape_tensor); - - const auto output_info = env->graph().operands().at(ofm_index).info(); - if (output_info.total_size() == 0) - { - // TODO: Handle unspecified output shape - throw std::runtime_error{"Interp(TConv): NYI unspecified output shape"}; - } - else - { - env->allocateIfNeeded(ofm_index, output_info); - } - - auto ofm_tensor = env->tensorAt(ofm_index); - UNUSED_RELEASE(ofm_tensor); - - // Handle same ifm & ofm data type only - if (ifm_tensor->data_type() != ofm_tensor->data_type()) - { - throw 
std::runtime_error{"Interp(TConv): Different I/O data dype"}; - } - - if (ofm_tensor->num_dimensions() != 4) - { - throw std::runtime_error{"Interp(TConv): Invalid output rank"}; - } -} - -void invoke(const ITensor *ifm_tensor, const ITensor *ker_tensor, const ITensor *ofm_tensor, - const ir::operation::TransposeConv::Param ¶m) -{ - const auto ifm_shape = ifm_tensor->tensorInfo().shape().asFeature(ir::Layout::NHWC); - const auto ofm_shape = ofm_tensor->tensorInfo().shape().asFeature(ir::Layout::NHWC); - // Kernel format is [depth_out, kernel_height, kernel_width, depth_in]. - const auto ker_shape = ker_tensor->tensorInfo().shape(); - const auto ker_height = ker_shape.dim(1); - const auto ker_width = ker_shape.dim(2); - const auto padding = ir::calculatePadding(param.padding, ofm_shape, ifm_shape, param.stride, - ker_width, ker_height); - - nnfw::cker::TransposeConvParams cker_param; - cker_param.padding_values.width = padding.left; - cker_param.padding_values.height = padding.top; - cker_param.stride_width = param.stride.horizontal; - cker_param.stride_height = param.stride.vertical; - cker_param.dilation_width_factor = 1; - cker_param.dilation_height_factor = 1; - - const auto cker_ifm_shape = convertShape(ifm_tensor->tensorInfo().shape()); - const auto cker_ker_shape = convertShape(ker_tensor->tensorInfo().shape()); - const auto cker_ofm_shape = convertShape(ofm_tensor->tensorInfo().shape()); - const float *ifm_ptr = reinterpret_cast<const float *>(ifm_tensor->bufferRO()); - const float *ker_ptr = reinterpret_cast<const float *>(ker_tensor->bufferRO()); - float *ofm_ptr = reinterpret_cast<float *>(ofm_tensor->buffer()); - - nnfw::cker::TransposeConv(cker_param, cker_ifm_shape, ifm_ptr, cker_ker_shape, ker_ptr, - cker_ofm_shape, ofm_ptr); -} - -void invokeTransposeConv(const ExecEnv *env, const ir::Operation &node) -{ - const auto &tconv_node = - nnfw::misc::polymorphic_downcast<const ir::operation::TransposeConv &>(node); - - const auto ifm_index = 
node.getInputs().at(ir::operation::TransposeConv::INPUT); - const auto ker_index = node.getInputs().at(ir::operation::TransposeConv::KERNEL); - const auto ofm_index = node.getOutputs().at(0); - - const auto ifm_tensor = env->tensorAt(ifm_index); - const auto ker_tensor = env->tensorAt(ker_index); - const auto ofm_tensor = env->tensorAt(ofm_index); - - const auto data_type = ifm_tensor->data_type(); - if (data_type == ir::DataType::FLOAT32) - { - invoke(ifm_tensor, ker_tensor, ofm_tensor, tconv_node.param()); - } - else - { - throw std::runtime_error{"Interp(TConv): Support float32 only"}; - } -} - -} // namespace - -OpKernel *getTransposeConv() -{ - static OpKernel kernel = {prepareTransposeConv, invokeTransposeConv}; - return &kernel; -} - -} // namespace interp -} // namespace onert diff --git a/runtime/onert/core/src/ir/DataType.cc b/runtime/onert/core/src/ir/DataType.cc index 80c659b3a..07670c720 100644 --- a/runtime/onert/core/src/ir/DataType.cc +++ b/runtime/onert/core/src/ir/DataType.cc @@ -41,11 +41,17 @@ size_t sizeOfDataType(DataType data_type) case DataType::UINT8: return sizeof(uint8_t); case DataType::QUANT_INT8_SYMM: + case DataType::QUANT_INT8_ASYMM: + case DataType::QUANT_INT8_SYMM_PER_CHANNEL: return sizeof(int8_t); case DataType::FLOAT16: return sizeof(float16); case DataType::INT64: return sizeof(int64_t); + case DataType::QUANT_INT16_ASYMM: + return sizeof(int16_t); + case DataType::QUANT_INT16_SYMM: + return sizeof(int16_t); default: throw std::runtime_error{"Unsupported type size"}; } diff --git a/runtime/onert/core/src/ir/Graph.cc b/runtime/onert/core/src/ir/Graph.cc index fe8b1b443..306572c99 100644 --- a/runtime/onert/core/src/ir/Graph.cc +++ b/runtime/onert/core/src/ir/Graph.cc @@ -16,18 +16,10 @@ #include "ir/Graph.h" -#include <algorithm> -#include <bitset> -#include <sstream> - -#include "util/logging.h" +#include "OperationValidator.h" #include "verifier/Verifier.h" -#include "ir/operation/LowerInfo.h" -#include 
"ir/operand/LowerInfo.h" -#include "ir/operand/PermuteFactor.h" -#include "ir/OperandIndexMap.h" -#include "ir/GraphIterator.h" -#include "backend/IConfig.h" + +#include "util/Set.h" namespace onert { @@ -36,6 +28,8 @@ namespace ir Graph::Graph() = default; +Graph::Graph(const Graph &) = default; + Graph::~Graph(void) = default; OperandIndex Graph::addOperand(const Shape &shape, const TypeInfo &type) @@ -43,22 +37,91 @@ OperandIndex Graph::addOperand(const Shape &shape, const TypeInfo &type) return _operands.emplace(shape, type); } -OperationIndex Graph::addOperation(std::unique_ptr<Operation> &&node) +OperandIndex Graph::addOperand(OperandIndex index, std::unique_ptr<Operand> &&operand) +{ + return _operands.push(std::move(operand), index); +} + +bool Graph::checkOperandsForOperation(const IOperation &operation) { - assert(isBuildingPhase()); - return _operations.push(std::move(node)); + auto inputs = operation.getInputs() | ir::Remove::UNDEFINED | ir::Remove::DUPLICATED; + auto outputs = operation.getOutputs() | ir::Remove::UNDEFINED | ir::Remove::DUPLICATED; + for (auto &&input : inputs) + if (!operands().exist(input)) + return false; + for (auto &&input : outputs) + if (!operands().exist(input)) + return false; + return true; +} + +void Graph::linkOperandToOperation(OperationIndex index, const IOperation &operation) +{ + auto inputs = operation.getInputs() | ir::Remove::UNDEFINED | ir::Remove::DUPLICATED; + auto outputs = operation.getOutputs() | ir::Remove::UNDEFINED | ir::Remove::DUPLICATED; + + for (auto &&input : inputs) + operands().at(input).insertUse(index); + for (auto &&output : outputs) + operands().at(output).setDef(index); +} + +OperationIndex Graph::addOperation(std::unique_ptr<IOperation> &&operation) +{ + const IOperation &op_ref = *operation; + if (!checkOperandsForOperation(op_ref)) + return OperationIndex{}; + auto ind = _operations.push(std::move(operation)); + if (ind.valid()) + linkOperandToOperation(ind, op_ref); + return ind; +} + 
+OperationIndex Graph::addOperation(OperationIndex index, std::unique_ptr<IOperation> &&operation) +{ + const IOperation &op_ref = *operation; + if (!checkOperandsForOperation(op_ref)) + return OperationIndex{}; + auto ind_gen = _operations.push(std::move(operation), index); + if (ind_gen.valid()) + { + assert(ind_gen == index); + linkOperandToOperation(index, op_ref); + } + return index; +} + +OperationIndex Graph::replaceOperation(OperationIndex index, + std::unique_ptr<IOperation> &&operation) +{ + const IOperation &op_ref = *operation; + if (!checkOperandsForOperation(op_ref) || !_operations.exist(index)) + return OperationIndex{}; + + // Check the new operation has the same inputs/outputs as the existing operation + const auto &old_op = _operations.at(index); + if (!(old_op.getInputs() == op_ref.getInputs() && old_op.getOutputs() == op_ref.getOutputs())) + { + return OperationIndex{}; + } + + return _operations.set(index, std::move(operation)); } void Graph::setOperandValue(const OperandIndex &ind, std::shared_ptr<Data> data) { - assert(isBuildingPhase()); assert(_operands.exist(ind)); _operands.at(ind).data(std::move(data)); } +void Graph::changeShape(const OperandIndex &ind, const ir::Shape &new_shape) +{ + assert(_operands.exist(ind)); + _operands.at(ind).info().shape(new_shape); +} + void Graph::addInput(const OperandIndex &ind, const std::string &name) { - assert(isBuildingPhase()); if (!name.empty()) _name_to_input.emplace(name, IOIndex{_inputs.size()}); _inputs.append(ind); @@ -66,7 +129,6 @@ void Graph::addInput(const OperandIndex &ind, const std::string &name) void Graph::addOutput(const OperandIndex &ind, const std::string &name) { - assert(isBuildingPhase()); if (!name.empty()) _name_to_output.emplace(name, IOIndex{_outputs.size()}); _outputs.append(ind); @@ -84,62 +146,70 @@ IOIndex Graph::getOutputIndex(const std::string &name) const return (itr == _name_to_output.end()) ? 
IOIndex{} : itr->second; } -void Graph::finishBuilding(void) +void Graph::verify(void) const { - assert(isBuildingPhase()); - _phase = Phase::MODEL; - - initializeUseDef(); - sweepGarbageOperands(); - // Call graph verifications for the MODEL phase { - assert(verifier::DAGChecker().verify(*this)); - assert(verifier::EdgeConsistencyChecker().verify(*this)); + // Except for edge consistency, the user might have been given a bad model + // so here it throws an execption rather than assertion. + if (!verifier::InputOutputChecker().verify(*this)) + throw std::runtime_error{"One of model input and output operands does not exist."}; + if (!verifier::DAGChecker().verify(*this)) + throw std::runtime_error{"The graph is cyclic."}; + assert(verifier::EdgeChecker().verify(*this)); } + + // Check shape independent operation feature + // - Operand type + // - Shape independent parameter + OperationValidator{*this}(); } void Graph::initializeUseDef() { - operations().iterate([&](const OperationIndex &index, const Operation &node) -> void { - auto outputs = node.getOutputs(); - for (auto output : outputs) + operations().iterate([&](const OperationIndex &index, const IOperation &node) -> void { + const auto &outputs = node.getOutputs(); + for (auto &&output : outputs | ir::Remove::UNDEFINED) { operands().at(output).setDef(index); } - for (auto input : node.getInputs() | ir::Remove::UNDEFINED) + for (auto &&input : node.getInputs() | ir::Remove::UNDEFINED) { operands().at(input).insertUse(index); } }); } -void Graph::sweepGarbageOperands() +std::vector<ir::OperationIndex> Graph::topolSortOperations() const { - // Remove operands that are not used by any operations, except Graph inputs/outputs - ir::OperandIndexMap<bool> visited; - - operations().iterate([&](const OperationIndex &, const Operation &node) { - for (auto ind : node.getInputs() + node.getOutputs()) - { - visited[ind] = true; - } - }); - - // Graph's inputs/outputs are always reachable - for (auto ind : getInputs() + 
getOutputs()) - { - visited[ind] = true; - } - - operands().iterate([&](const OperandIndex &ind, const Operand &) { - if (!visited[ind]) + std::vector<ir::OperationIndex> ret; + util::Set<ir::OperationIndex> unvisited; + operations().iterate( + [&](const ir::OperationIndex &index, const ir::IOperation &) { unvisited.add(index); }); + + std::function<void(const ir::OperationIndex &, const ir::IOperation &)> dfs = + [&](const ir::OperationIndex &index, const ir::IOperation &op) -> void { + if (!unvisited.contains(index)) + return; + unvisited.remove(index); + + for (const auto &output : op.getOutputs() | ir::Remove::DUPLICATED | ir::Remove::UNDEFINED) { - VERBOSE(Graph::sweepGarbageOperands) << "Sweep garbage operand " << ind.value() << std::endl; - operands().remove(ind); + const auto &operand = operands().at(output); + for (const auto &use : operand.getUses()) + { + dfs(use, operations().at(use)); + } } - }); + ret.push_back(index); + }; + operations().iterate(dfs); + + assert(unvisited.empty()); // All of the nodes must have been visited + // Reversing Postorder DFS result to make it sorted in topoligical order + std::reverse(ret.begin(), ret.end()); + return ret; } } // namespace ir diff --git a/runtime/onert/core/src/ir/Graph.test.cc b/runtime/onert/core/src/ir/Graph.test.cc new file mode 100644 index 000000000..144500745 --- /dev/null +++ b/runtime/onert/core/src/ir/Graph.test.cc @@ -0,0 +1,147 @@ +/* + * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ir/Graph.h" +#include "ir/operation/BinaryArithmetic.h" + +#include <gtest/gtest.h> + +TEST(Graph, neg_inputs_and_outputs) +{ + onert::ir::Graph graph; + + onert::ir::OperandIndex index0{0u}; + onert::ir::OperandIndex index1{1u}; + + graph.addInput({index0}); + graph.addInput({index1}); + + onert::ir::OperandIndex index10{10u}; + onert::ir::OperandIndex index11{11u}; + onert::ir::OperandIndex index12{12u}; + + graph.addOutput({index10}); + graph.addOutput({index11}); + graph.addOutput({index12}); + + ASSERT_EQ(graph.getInputs().size(), 2); + ASSERT_EQ(graph.getOutputs().size(), 3); + + onert::ir::IOIndex io_index0{0}; + onert::ir::IOIndex io_index1{1}; + onert::ir::IOIndex io_index2{2}; + + ASSERT_EQ(graph.getInputs().at(io_index0), 0); + ASSERT_EQ(graph.getInputs().at(io_index1), 1); + + ASSERT_EQ(graph.getOutputs().at(io_index0), 10); + ASSERT_EQ(graph.getOutputs().at(io_index1), 11); + ASSERT_EQ(graph.getOutputs().at(io_index2), 12); + + EXPECT_THROW(graph.getOutputs().at(onert::ir::IOIndex{3}), std::out_of_range); +} + +using namespace onert::ir; + +OperationIndex addAddOperation(Graph &graph, const OperandIndexSequence inputs, + const OperandIndexSequence outputs) +{ + // Add "ADD" operation + operation::BinaryArithmetic::Param param; + param.arithmetic_type = operation::BinaryArithmetic::ArithmeticType::ADD; + param.activation = Activation::NONE; + return graph.addOperation(std::make_unique<operation::BinaryArithmetic>(inputs, outputs, param)); +} + +TEST(Graph, OneOpGraphSimpleValid) +{ + // Simple Graph with just one Add operation + + Graph graph; + + // Add tensors + Shape shape{1, 2, 2, 1}; + TypeInfo type{DataType::FLOAT32}; + auto lhs = graph.addOperand(shape, type); + auto rhs = graph.addOperand(shape, type); + auto res = graph.addOperand(shape, type); + + addAddOperation(graph, {lhs, rhs}, {res}); + + // Set model inputs/outputs 
+ graph.addInput(lhs); + graph.addInput(rhs); + graph.addOutput(res); + + graph.verify(); + + SUCCEED(); +} + +TEST(Graph, neg_InvalidGraph_BadInput) +{ + Graph graph; + + // Add tensors + Shape shape{1, 2, 2, 1}; + TypeInfo type{DataType::FLOAT32}; + auto in = graph.addOperand(shape, type); + auto out = graph.addOperand(shape, type); + + // Set model inputs/outputs + graph.addInput(in); + graph.addOutput(out); + graph.addInput(OperandIndex{89}); // Non-exisiting operand! + + EXPECT_ANY_THROW(graph.verify()); +} + +TEST(Graph, neg_InvalidGraph_BadOutput) +{ + Graph graph; + + // Add tensors + Shape shape{1, 2, 2, 1}; + TypeInfo type{DataType::FLOAT32}; + auto in = graph.addOperand(shape, type); + auto out = graph.addOperand(shape, type); + + // Set model inputs/outputs + graph.addInput(in); + graph.addOutput(out); + graph.addOutput(OperandIndex{12}); // Non-exisiting operand! + + EXPECT_ANY_THROW(graph.verify()); +} + +TEST(Graph, neg_InvalidAddOperation_BadInputIndex) +{ + Graph graph; + + // Add tensors + Shape shape{1, 2, 2, 1}; + TypeInfo type{DataType::FLOAT32}; + auto lhs = graph.addOperand(shape, type); + auto rhs = graph.addOperand(shape, type); + auto res = graph.addOperand(shape, type); + + // Set model inputs/outputs + graph.addInput(lhs); + graph.addInput(rhs); + graph.addOutput(res); + + ASSERT_FALSE(addAddOperation(graph, {lhs, OperandIndex{99}}, {res}).valid()); +} diff --git a/runtime/onert/core/src/ir/GraphIterator.cc b/runtime/onert/core/src/ir/GraphIterator.cc deleted file mode 100644 index 4bea1a55d..000000000 --- a/runtime/onert/core/src/ir/GraphIterator.cc +++ /dev/null @@ -1,121 +0,0 @@ -/* - * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "GraphIterator.h" - -#include "ir/OperationIndexMap.h" -#include "compiler/LoweredGraph.h" - -namespace onert -{ -namespace ir -{ - -// -// Graph::DefaultIterator -// - -template <bool is_const> -void DefaultIterator<is_const>::iterate(GraphRef graph, const IterFn &fn) const -{ - graph.operations().iterate( - [&](const OperationIndex &index, NodeRef node) -> void { fn(index, node); }); -} - -// -// Graph::PostDfsIterator -// - -template <bool is_const> -void PostDfsIterator<is_const>::iterate(GraphRef graph, const IterFn &fn) const -{ - assert(!graph.isBuildingPhase()); // Restrict iteration condition - - OperationIndexMap<bool> visited; - graph.operations().iterate([&](const OperationIndex &index, NodeRef) { visited[index] = false; }); - - std::function<void(const OperationIndex &, NodeRef)> dfs_recursive = - [&](const OperationIndex &index, NodeRef node) -> void { - if (visited[index]) - return; - visited[index] = true; - - for (const auto output : node.getOutputs() | Remove::DUPLICATED) - { - const auto &operand = graph.operands().at(output); - for (const auto &use : operand.getUses()) - { - dfs_recursive(use, graph.operations().at(use)); - } - } - - fn(index, node); - }; - - graph.operations().iterate(dfs_recursive); - - // All of the operations(nodes) must have been visited. 
- assert(std::all_of(visited.begin(), visited.end(), - [](const std::pair<const OperationIndex, bool> &v) { return v.second; })); -} - -template <bool is_const> -void PostDfsIterator<is_const>::iterateOpSeqs(LoweredGraphRef lowered_graph, - const OpSeqIterFn &fn) const -{ - std::unordered_map<OpSequenceIndex, bool> visited; - lowered_graph.op_seqs().iterate( - [&](const OpSequenceIndex &index, OpSequenceRef) { visited[index] = false; }); - - std::function<void(const OpSequenceIndex &, OpSequenceRef)> dfs_recursive = - [&](const OpSequenceIndex &index, OpSequenceRef op_seq) -> void { - if (visited[index]) - return; - visited[index] = true; - - for (const auto output : op_seq.getOutputs() | Remove::DUPLICATED) - { - const auto &operand = lowered_graph.graph().operands().at(output); - for (const auto &use : operand.getUses()) - { - const auto use_op_seq_index = lowered_graph.op_seqs().getOperation(use); - dfs_recursive(use_op_seq_index, lowered_graph.op_seqs().at(use_op_seq_index)); - } - } - - fn(index, op_seq); - }; - - lowered_graph.op_seqs().iterate(dfs_recursive); - - // All of the operations(nodes) must have been visited. - assert(std::all_of(visited.begin(), visited.end(), - [](const std::pair<const OpSequenceIndex, bool> &v) { return v.second; })); -} - -// Explicit instantiations to have implementation in the source file. -// NOTE If these instatiations were in the top of this file, `iterate` is compiled and saved in -// `GraphIterator.cc.o` but `iterateOpSeqs`. This happens only when cross-building for Android. -// (Maybe a bug of NDK toolchain(clang)?) 
- -template class DefaultIterator<true>; -template class DefaultIterator<false>; - -template class PostDfsIterator<true>; -template class PostDfsIterator<false>; - -} // namespace ir -} // namespace onert diff --git a/runtime/onert/core/src/ir/GraphIterator.h b/runtime/onert/core/src/ir/GraphIterator.h deleted file mode 100644 index b54314e0e..000000000 --- a/runtime/onert/core/src/ir/GraphIterator.h +++ /dev/null @@ -1,90 +0,0 @@ -/* - * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef __ONERT_IR_GRAPH_ITERATOR_H__ -#define __ONERT_IR_GRAPH_ITERATOR_H__ - -#include <type_traits> - -#include "ir/Index.h" - -namespace onert -{ -namespace compiler -{ -class LoweredGraph; -} // namespace compiler -} // namespace onert - -namespace onert -{ -namespace ir -{ - -class Graph; -class Operation; -class OpSequence; - -template <bool is_const> class Iterator -{ -public: - using GraphRef = typename std::conditional<is_const, const Graph &, Graph &>::type; - using IndexRef = const OperationIndex &; - using NodeRef = typename std::conditional<is_const, const Operation &, Operation &>::type; - using IterFn = std::function<void(IndexRef, NodeRef)>; - -public: - virtual ~Iterator() = default; - virtual void iterate(GraphRef graph, const IterFn &fn) const = 0; -}; - -template <bool is_const = false> class DefaultIterator final : public Iterator<is_const> -{ -public: - using GraphRef = typename Iterator<is_const>::GraphRef; - using IndexRef = typename Iterator<is_const>::IndexRef; - using NodeRef = typename Iterator<is_const>::NodeRef; - using IterFn = typename Iterator<is_const>::IterFn; - -public: - void iterate(GraphRef graph, const IterFn &fn) const; -}; -using DefaultConstIterator = DefaultIterator<true>; - -template <bool is_const = false> class PostDfsIterator final : public Iterator<is_const> -{ -public: - using GraphRef = typename Iterator<is_const>::GraphRef; - using IndexRef = typename Iterator<is_const>::IndexRef; - using NodeRef = typename Iterator<is_const>::NodeRef; - using IterFn = typename Iterator<is_const>::IterFn; - using LoweredGraphRef = - typename std::conditional<is_const, const typename compiler::LoweredGraph &, - typename compiler::LoweredGraph &>::type; - using OpSequenceRef = typename std::conditional<is_const, const OpSequence &, OpSequence &>::type; - using OpSeqIndexRef = const OpSequenceIndex &; - using OpSeqIterFn = std::function<void(OpSeqIndexRef, OpSequenceRef)>; - -public: - void iterate(GraphRef graph, const 
IterFn &fn) const; - void iterateOpSeqs(LoweredGraphRef lowered_graph, const OpSeqIterFn &f) const; -}; -using PostDfsConstIterator = PostDfsIterator<true>; - -} // namespace ir -} // namespace onert - -#endif // __ONERT_IR_GRAPH_ITERATOR_H__ diff --git a/runtime/onert/core/src/ir/LayoutSet.cc b/runtime/onert/core/src/ir/LayoutSet.cc index bd3f438ad..732460aa2 100644 --- a/runtime/onert/core/src/ir/LayoutSet.cc +++ b/runtime/onert/core/src/ir/LayoutSet.cc @@ -23,7 +23,7 @@ namespace ir LayoutSet::LayoutSet(std::initializer_list<Layout> layouts) { - for (auto layout : layouts) + for (auto &&layout : layouts) { _set.insert(layout); } @@ -32,7 +32,7 @@ LayoutSet::LayoutSet(std::initializer_list<Layout> layouts) LayoutSet LayoutSet::operator|(const LayoutSet &other) const { auto ret = *this; - for (auto layout : other) + for (auto &&layout : other) { ret.add(layout); } @@ -42,7 +42,7 @@ LayoutSet LayoutSet::operator|(const LayoutSet &other) const LayoutSet LayoutSet::operator&(const LayoutSet &other) const { LayoutSet ret; - for (auto layout : other) + for (auto &&layout : other) { if (contains(layout)) { @@ -55,7 +55,7 @@ LayoutSet LayoutSet::operator&(const LayoutSet &other) const LayoutSet LayoutSet::operator-(const LayoutSet &other) const { auto ret = *this; - for (auto layout : other) + for (auto &&layout : other) { ret.remove(layout); } diff --git a/runtime/onert/core/src/ir/LayoutSet.h b/runtime/onert/core/src/ir/LayoutSet.h index 6ce4e38c6..be077f2f0 100644 --- a/runtime/onert/core/src/ir/LayoutSet.h +++ b/runtime/onert/core/src/ir/LayoutSet.h @@ -17,6 +17,7 @@ #ifndef __ONERT_IR_LAYOUT_SET_H__ #define __ONERT_IR_LAYOUT_SET_H__ +#include <cstdint> #include <initializer_list> #include <unordered_set> diff --git a/runtime/onert/core/src/ir/LayoutSet.test.cc b/runtime/onert/core/src/ir/LayoutSet.test.cc new file mode 100644 index 000000000..fc956abe8 --- /dev/null +++ b/runtime/onert/core/src/ir/LayoutSet.test.cc @@ -0,0 +1,67 @@ +/* + * Copyright (c) 2018 Samsung 
Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "LayoutSet.h" + +#include <gtest/gtest.h> + +using onert::ir::Layout; +using onert::ir::LayoutSet; + +TEST(ir_LayoutSet, neg_add_remove) +{ + LayoutSet set{Layout::NCHW}; + set.remove(Layout::NHWC); + ASSERT_EQ(set.size(), 1); + set.add(Layout::NHWC); + ASSERT_EQ(set.size(), 2); + set.remove(Layout::NHWC); + ASSERT_EQ(set.size(), 1); + set.remove(Layout::NCHW); + ASSERT_EQ(set.size(), 0); + set.remove(Layout::NCHW); + ASSERT_EQ(set.size(), 0); +} + +TEST(ir_LayoutSet, neg_add_twice) +{ + LayoutSet set; + set.add(Layout::NHWC); + ASSERT_EQ(set.size(), 1); + set.add(Layout::NHWC); + ASSERT_EQ(set.size(), 1); +} + +TEST(ir_LayoutSet, set_operators) +{ + LayoutSet set1{Layout::NCHW}; + LayoutSet set2{Layout::NHWC}; + LayoutSet set3 = set1 | set2; + + ASSERT_EQ(set3.size(), 2); + + ASSERT_EQ((set3 - set1).size(), 1); + ASSERT_EQ((set3 - set1).contains(Layout::NHWC), true); + ASSERT_EQ((set3 - set2).size(), 1); + ASSERT_EQ((set3 - set2).contains(Layout::NCHW), true); + ASSERT_EQ((set3 - set3).size(), 0); + + ASSERT_EQ((set3 & set1).size(), 1); + ASSERT_EQ((set3 & set1).contains(Layout::NCHW), true); + ASSERT_EQ((set3 & set2).size(), 1); + ASSERT_EQ((set3 & set2).contains(Layout::NHWC), true); + ASSERT_EQ((set1 & set2).size(), 0); +} diff --git a/runtime/onert/core/src/ir/MockNode.h b/runtime/onert/core/src/ir/MockNode.h new file mode 100644 index 
000000000..0e7ed977b --- /dev/null +++ b/runtime/onert/core/src/ir/MockNode.h @@ -0,0 +1,47 @@ +/* + * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __ONERT_TEST_GRAPH_MOCK_NODE_H__ +#define __ONERT_TEST_GRAPH_MOCK_NODE_H__ + +#include "ir/Operation.h" +#include "ir/OperandIndexSequence.h" + +namespace onert_test +{ +namespace ir +{ + +class SimpleMock : public onert::ir::Operation +{ +public: + SimpleMock(const onert::ir::OperandIndexSequence &inputs, + const onert::ir::OperandIndexSequence &outputs) + : Operation{onert::ir::OperandConstraint::createAny()} + { + setInputs(inputs); + setOutputs(outputs); + } + +public: + void accept(onert::ir::OperationVisitor &) const override {} + onert::ir::OpCode opcode() const final { return onert::ir::OpCode::Invalid; } +}; + +} // namespace ir +} // namespace onert_test + +#endif // __ONERT_TEST_GRAPH_MOCK_NODE_H__ diff --git a/runtime/onert/core/src/ir/OpSequence.cc b/runtime/onert/core/src/ir/OpSequence.cc deleted file mode 100644 index e2b989d8c..000000000 --- a/runtime/onert/core/src/ir/OpSequence.cc +++ /dev/null @@ -1,95 +0,0 @@ -/* - * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "ir/OpSequence.h" - -#include "ir/Operations.h" -#include "ir/OperationVisitor.h" -#include <sstream> - -namespace -{ - -std::string getStrFromIndice(const onert::ir::OperandIndexSequence &indice) -{ - std::string str; - for (const auto &ind : indice) - { - str += std::to_string(ind.value()); - str.push_back(','); - } - if (str.back() == ',') - str.pop_back(); - - return str; -} -} - -namespace onert -{ -namespace ir -{ - -OpSequence::OpSequence(Layout layout) : _layout{layout}, _has_dynamic_tensor{false} -{ - // DO NOTHING -} - -void OpSequence::accept(OperationVisitor &v) const { v.visit(*this); } - -// TODO: Impl Dumper instead of this method -std::string getStrFromOpSeq(const OpSequence &op_seq, const Operations &operations) -{ - // " OpSequence IN(0,1,2) -> { op0(0,1,2:3), op1(3:4), op2(4:5) } -> OUT(5)" - std::stringstream ss; - ss << " OpSequence IN(" << getStrFromIndice(op_seq.getInputs()) << ") -> {"; - for (const auto &op_idx : op_seq) - { - ss << " " << op_idx.value() << "(" << operations.at(op_idx).name() << ":" - << getStrFromIndice(operations.at(op_idx).getInputs()) << ":" - << getStrFromIndice(operations.at(op_idx).getOutputs()) << ")"; - } - ss << " } -> OUT(" << getStrFromIndice(op_seq.getOutputs()) << ")"; - return ss.str(); -} - -void OpSequence::remove(const OperationIndex &index) -{ - assert(exist(index)); - for (auto it = _operations.cbegin(); it != _operations.cend(); ++it) - { - if (*it == index) - { - _operations.erase(it); - break; - } - } -} - -bool OpSequence::exist(const OperationIndex &index) const -{ 
- for (const auto &inner_op_idx : _operations) - { - if (inner_op_idx == index) - { - return true; - } - } - return false; -} - -} // namespace ir -} // namespace onert diff --git a/runtime/onert/core/src/ir/OpSequences.cc b/runtime/onert/core/src/ir/OpSequences.cc deleted file mode 100644 index 68884783e..000000000 --- a/runtime/onert/core/src/ir/OpSequences.cc +++ /dev/null @@ -1,124 +0,0 @@ -/* - * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "ir/OpSequences.h" -#include "util/logging.h" -#include <memory> - -#include <cassert> -#include <string> - -namespace onert -{ -namespace ir -{ - -OpSequenceIndex OpSequences::emplace(const OperationIndex &index, Layout layout) -{ - std::unique_ptr<OpSequence> op_seq = std::make_unique<OpSequence>(layout); - op_seq->appendOperation(index); - const OpSequenceIndex &seq_index = push(std::move(op_seq)); - cacheSequenceIndex(seq_index, index); - return seq_index; -} - -OpSequenceIndex OpSequences::emplace(std::unique_ptr<OpSequence> &&op_seq) -{ - auto &operations = op_seq->operations(); - const OpSequenceIndex &seq_index = push(std::move(op_seq)); - for (const auto &op_idx : operations) - { - cacheSequenceIndex(seq_index, op_idx); - } - return seq_index; -} - -void OpSequences::cacheSequenceIndex(const OpSequenceIndex &seq_index, - const OperationIndex &op_index) const -{ - _seq_indexes.emplace(op_index, seq_index); -} - -OpSequenceIndex *OpSequences::findSequenceIndex(const OperationIndex &operation_index) const -{ - // If opration_index is cached, return sequence_index from cache - if (_seq_indexes.count(operation_index)) - { - auto &op_seq_index = _seq_indexes.at(operation_index); - if (_objects.count(op_seq_index) && _objects.at(op_seq_index)->exist(operation_index)) - { - return &op_seq_index; - } - else - { - _seq_indexes.erase(operation_index); - return nullptr; - } - } - return nullptr; -} - -bool OpSequences::containsOperation(const OperationIndex &operation_index) const -{ - return findOperation(operation_index).valid(); -} - -OpSequenceIndex OpSequences::getOperation(const OperationIndex &operation_index) const -{ - OpSequenceIndex ret = findOperation(operation_index); - assert(ret.valid()); - return ret; -} - -void OpSequences::removeFromOpSequence(const OperationIndex &operation_index) -{ - const auto op_seq_index = findOperation(operation_index); - auto &op_seq = at(op_seq_index); - _seq_indexes.erase(operation_index); - 
op_seq.remove(operation_index); - if (op_seq.size() == 0) - { - remove(op_seq_index); - } -} - -OpSequenceIndex OpSequences::findOperation(const OperationIndex &operation_index) const -{ - if (OpSequenceIndex *op_seq_index = findSequenceIndex(operation_index)) - return *op_seq_index; - - for (auto &e : _objects) - { - OpSequence &object = *e.second; - auto it = find(object.operations().begin(), object.operations().end(), operation_index); - if (it != object.operations().end()) - { - cacheSequenceIndex(e.first, operation_index); - return e.first; - } - } - throw std::runtime_error("Operation not found"); -} - -void dumpOpSequences(const OpSequences &op_seqs, const Operations &operations) -{ - op_seqs.iterate([&](const OpSequenceIndex &idx, const OpSequence &op_seq) { - VERBOSE(OpSequences) << idx.value() << "] " << getStrFromOpSeq(op_seq, operations) << std::endl; - }); -} - -} // namespace ir -} // namespace onert diff --git a/runtime/onert/core/src/ir/Operand.cc b/runtime/onert/core/src/ir/Operand.cc index e29c7a6ec..18981dbf1 100644 --- a/runtime/onert/core/src/ir/Operand.cc +++ b/runtime/onert/core/src/ir/Operand.cc @@ -46,5 +46,11 @@ void Operand::setDef(const OperationIndex &idx) { _def = idx; } void Operand::unsetDef() { _def = OperationIndex{}; } +void Operand::clearDefUse() +{ + unsetDef(); + _uses.clear(); +} + } // namespace ir } // namespace onert diff --git a/runtime/onert/core/src/ir/Operand.test.cc b/runtime/onert/core/src/ir/Operand.test.cc new file mode 100644 index 000000000..0b858792a --- /dev/null +++ b/runtime/onert/core/src/ir/Operand.test.cc @@ -0,0 +1,86 @@ +/* + * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ir/Graph.h" + +#include "MockNode.h" +#include "verifier/Verifier.h" + +#include <gtest/gtest.h> + +#include <memory> +#include <typeindex> + +namespace +{ + +using IndexSet = onert::ir::OperandIndexSequence; +using Mock = onert_test::ir::SimpleMock; + +} // namespace + +TEST(ir_Operand, neg_usedef) +{ + onert::ir::Graph graph; + onert::ir::verifier::DAGChecker verifier; + + onert::ir::Shape shape(3); + onert::ir::TypeInfo type{onert::ir::DataType::INT32}; + + // Model Input/Output + auto input_operand = graph.addOperand(shape, type); + auto output_operand = graph.addOperand(shape, type); + + graph.addInput(input_operand); + graph.addOutput(output_operand); + + // MockNode1 + auto operand_index1 = graph.addOperand(shape, type); + auto mocknode_index1 = + graph.addOperation(std::make_unique<Mock>(IndexSet{input_operand}, IndexSet{operand_index1})); + + // MockNode2 + auto operand_index2 = graph.addOperand(shape, type); + auto mocknode_index2 = + graph.addOperation(std::make_unique<Mock>(IndexSet{input_operand}, IndexSet{operand_index2})); + + // MockNode3(two input) + auto multiinput_index = graph.addOperation( + std::make_unique<Mock>(IndexSet{operand_index1, operand_index2}, IndexSet{output_operand})); + + graph.verify(); + + ASSERT_TRUE(verifier.verify(graph)); + + // Check def + ASSERT_EQ(graph.operands().at(operand_index1).getDef(), mocknode_index1); + ASSERT_EQ(graph.operands().at(operand_index2).getDef(), mocknode_index2); + ASSERT_EQ(graph.operands().at(output_operand).getDef(), multiinput_index); + + 
ASSERT_NE(graph.operands().at(operand_index1).getDef(), mocknode_index2); + ASSERT_NE(graph.operands().at(operand_index1).getDef(), multiinput_index); + + // Check use + ASSERT_EQ(graph.operands().at(input_operand).getUses().contains(mocknode_index1), true); + ASSERT_EQ(graph.operands().at(input_operand).getUses().contains(mocknode_index2), true); + ASSERT_EQ(graph.operands().at(input_operand).getUses().contains(multiinput_index), false); + ASSERT_EQ(graph.operands().at(operand_index1).getUses().contains(multiinput_index), true); + ASSERT_EQ(graph.operands().at(operand_index2).getUses().contains(multiinput_index), true); + + ASSERT_EQ(graph.operands().at(input_operand).getUses().size(), 2); + ASSERT_EQ(graph.operands().at(operand_index1).getUses().size(), 1); + ASSERT_EQ(graph.operands().at(output_operand).getUses().size(), 0); +} diff --git a/runtime/onert/core/src/ir/OperandIndexSequence.cc b/runtime/onert/core/src/ir/OperandIndexSequence.cc index 73f928280..a15b6d0d6 100644 --- a/runtime/onert/core/src/ir/OperandIndexSequence.cc +++ b/runtime/onert/core/src/ir/OperandIndexSequence.cc @@ -31,7 +31,7 @@ OperandIndexSequence::OperandIndexSequence(std::initializer_list<OperandIndex> l OperandIndexSequence::OperandIndexSequence(std::initializer_list<int32_t> list) { - for (auto val : list) + for (auto &&val : list) { _vec.emplace_back(static_cast<uint32_t>(val)); } @@ -39,7 +39,7 @@ OperandIndexSequence::OperandIndexSequence(std::initializer_list<int32_t> list) OperandIndexSequence::OperandIndexSequence(std::initializer_list<uint32_t> list) { - for (auto val : list) + for (auto &&val : list) { _vec.emplace_back(val); } @@ -55,6 +55,11 @@ void OperandIndexSequence::replace(const OperandIndex &from, const OperandIndex std::replace(_vec.begin(), _vec.end(), from, to); } +bool OperandIndexSequence::operator==(const OperandIndexSequence &other) const +{ + return _vec == other._vec; +} + OperandIndexSequence OperandIndexSequence::operator+(const OperandIndexSequence 
&other) const { OperandIndexSequence ret = *this; @@ -62,10 +67,10 @@ OperandIndexSequence OperandIndexSequence::operator+(const OperandIndexSequence return ret; } -std::ostream &operator<<(std::ostream &o, const OperandIndexSequence &op_seq) +std::ostream &operator<<(std::ostream &o, const OperandIndexSequence &operand_seq) { std::string delimeter; - for (const auto &ind : op_seq._vec) + for (const auto &ind : operand_seq._vec) { o << delimeter << ind; delimeter = ','; diff --git a/runtime/onert/core/src/ir/OperandIndexSequence.test.cc b/runtime/onert/core/src/ir/OperandIndexSequence.test.cc new file mode 100644 index 000000000..588c4e419 --- /dev/null +++ b/runtime/onert/core/src/ir/OperandIndexSequence.test.cc @@ -0,0 +1,52 @@ +/* + * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "ir/OperandIndexSequence.h" + +#include <gtest/gtest.h> + +using onert::ir::OperandIndex; +using onert::ir::OperandIndexSequence; + +TEST(ir_OperandIndexSequence, neg_append) +{ + OperandIndexSequence iset{0, 2, 4, 8}; + + ASSERT_EQ(iset.size(), 4); + + iset.append(OperandIndex{10}); + + ASSERT_EQ(iset.size(), 5); + + onert::ir::IOIndex index1{1}; + onert::ir::IOIndex index2{4}; + + ASSERT_EQ(iset.at(index1), 2); + ASSERT_EQ(iset.at(index2), 10); + + ASSERT_TRUE(iset.contains(OperandIndex{2})); + ASSERT_TRUE(iset.contains(OperandIndex{10})); + ASSERT_FALSE(iset.contains(OperandIndex{11})); +} + +TEST(graph_OperandIndexSequence, neg_replace) +{ + OperandIndexSequence iset{0, 1, 2, 3}; + + iset.replace(OperandIndex{1}, OperandIndex{9}); + ASSERT_FALSE(iset.contains(OperandIndex{1})); + ASSERT_TRUE(iset.contains(OperandIndex{9})); +} diff --git a/runtime/onert/core/src/ir/Operands.cc b/runtime/onert/core/src/ir/Operands.cc index ab32e478a..f8cfd16ef 100644 --- a/runtime/onert/core/src/ir/Operands.cc +++ b/runtime/onert/core/src/ir/Operands.cc @@ -29,7 +29,7 @@ Operands::Operands(const Operands &obj) obj.iterate([&](const OperandIndex &index, const Operand &operand) { _objects.emplace(index, std::make_unique<Operand>(operand)); }); - _index_count = obj._index_count; + _next_index = obj._next_index; } } // namespace ir diff --git a/runtime/onert/core/src/ir/Operands.test.cc b/runtime/onert/core/src/ir/Operands.test.cc new file mode 100644 index 000000000..aff228b10 --- /dev/null +++ b/runtime/onert/core/src/ir/Operands.test.cc @@ -0,0 +1,45 @@ +/* + * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ir/Operands.h" + +#include <gtest/gtest.h> + +TEST(ir_Operands, neg_set_test) +{ + onert::ir::Operands set; + + onert::ir::Shape shape0{1, 2, 3}; + + onert::ir::Shape shape1(4); + shape1.dim(0) = 10; + shape1.dim(1) = 20; + shape1.dim(2) = 30; + shape1.dim(3) = 40; + + onert::ir::TypeInfo type{onert::ir::DataType::INT32}; + + set.emplace(shape0, type); + set.emplace(shape1, type); + + ASSERT_EQ(set.exist(onert::ir::OperandIndex{0u}), true); + ASSERT_EQ(set.exist(onert::ir::OperandIndex{1u}), true); + ASSERT_EQ(set.exist(onert::ir::OperandIndex{2u}), false); + + ASSERT_EQ(set.at(onert::ir::OperandIndex{0u}).shape().dim(0), 1); + ASSERT_EQ(set.at(onert::ir::OperandIndex{0u}).shape().dim(1), 2); + ASSERT_EQ(set.at(onert::ir::OperandIndex{0u}).shape().dim(2), 3); +} diff --git a/runtime/onert/core/src/ir/Operation.cc b/runtime/onert/core/src/ir/Operation.cc index 04be8c0d9..64792525d 100644 --- a/runtime/onert/core/src/ir/Operation.cc +++ b/runtime/onert/core/src/ir/Operation.cc @@ -24,22 +24,33 @@ namespace ir { Operation::Operation(OperandConstraint input_constr, const OperandIndexSequence &inputs, - const OperandIndexSequence &outputs) - : _input_constr{input_constr}, _inputs{inputs}, _outputs{outputs} + const OperandIndexSequence &outputs, OperandConstraint output_constr) + : _input_constr{input_constr}, _output_constr{output_constr} { + setInputs(inputs); + setOutputs(outputs); } -Operation::Operation(OperandConstraint input_constr) : _input_constr{input_constr} {} +Operation::Operation(OperandConstraint input_constr, 
OperandConstraint output_constr) + : _input_constr{input_constr}, _output_constr{output_constr} +{ +} Operation::~Operation() = default; void Operation::setInputs(const OperandIndexSequence &indexes) { - assert(_input_constr.check(indexes.size())); + if (!_input_constr.check(indexes.size())) + throw std::runtime_error{"Invalid number of input tensors for this operation."}; _inputs = indexes; } -void Operation::setOutputs(const OperandIndexSequence &indexes) { _outputs = indexes; } +void Operation::setOutputs(const OperandIndexSequence &indexes) +{ + if (!_output_constr.check(indexes.size())) + throw std::runtime_error{"Invalid number of output tensors for this operation."}; + _outputs = indexes; +} void Operation::replaceInputs(const OperandIndex &from, const OperandIndex &to) { diff --git a/runtime/onert/core/src/ir/Operation.test.cc b/runtime/onert/core/src/ir/Operation.test.cc new file mode 100644 index 000000000..b3c4e852d --- /dev/null +++ b/runtime/onert/core/src/ir/Operation.test.cc @@ -0,0 +1,98 @@ +/* + * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "ir/Graph.h" +#include "ir/Index.h" +#include "ir/OperandIndexSequence.h" +#include "ir/operation/Concat.h" +#include "ir/operation/Conv2D.h" + +#include <gtest/gtest.h> + +#include <memory> +#include <stdexcept> + +using Index = onert::ir::IOIndex; +using IndexSet = onert::ir::OperandIndexSequence; + +TEST(ir_Operation_setIO, operation_setIO_conv) +{ + onert::ir::Graph graph; + + onert::ir::Shape shape{3}; + onert::ir::TypeInfo type{onert::ir::DataType::INT32}; + + // Add Conv + using Graph = onert::ir::operation::Conv2D; + + auto input_operand = graph.addOperand(shape, type); + auto kernel_operand = graph.addOperand(shape, type); + auto bias_operand = graph.addOperand(shape, type); + IndexSet inputs{input_operand, kernel_operand, bias_operand}; + + Graph::Param conv_params; + conv_params.padding.type = onert::ir::PaddingType::SAME; + conv_params.stride.horizontal = 1; + conv_params.stride.vertical = 1; + conv_params.activation = onert::ir::Activation::NONE; + + auto output_operand = graph.addOperand(shape, type).value(); + IndexSet outputs{output_operand}; + + auto conv = std::make_unique<Graph>(inputs, outputs, conv_params); + + ASSERT_NE(conv, nullptr); + ASSERT_EQ(conv->getInputs().at(Index{0}).value(), inputs.at(0).value()); + conv->setInputs({8, 9, 10}); + ASSERT_NE(conv->getInputs().at(Index{0}).value(), inputs.at(0).value()); + ASSERT_EQ(conv->getInputs().at(Index{0}).value(), 8); +} + +TEST(ir_Operation_setIO, neg_operation_setIO_concat) +{ + onert::ir::Graph graph; + + onert::ir::Shape shape{3}; + + onert::ir::TypeInfo type{onert::ir::DataType::INT32}; + + using Graph = onert::ir::operation::Concat; + + // Add Concat + IndexSet inputs; + for (int i = 0; i < 6; ++i) + { + inputs.append(graph.addOperand(shape, type)); + } + + Graph::Param concat_params{0}; + + auto output_operand = graph.addOperand(shape, type).value(); + IndexSet outputs{output_operand}; + + auto concat = std::make_unique<Graph>(inputs, outputs, concat_params); + + 
ASSERT_NE(concat, nullptr); + ASSERT_EQ(concat->getInputs().size(), 6); + ASSERT_EQ(concat->getInputs().at(Index{0}).value(), inputs.at(0).value()); + + concat->setInputs({80, 6, 9, 11}); + ASSERT_EQ(concat->getInputs().size(), 4); + ASSERT_NE(concat->getInputs().at(Index{0}).value(), inputs.at(0).value()); + ASSERT_EQ(concat->getInputs().at(Index{0}).value(), 80); + ASSERT_EQ(concat->getInputs().at(Index{2}).value(), 9); + ASSERT_THROW(concat->getInputs().at(Index{5}), std::out_of_range); +} diff --git a/runtime/onert/core/src/ir/OperationCloner.cc b/runtime/onert/core/src/ir/OperationCloner.cc index b4e60f0bc..64e1cc807 100644 --- a/runtime/onert/core/src/ir/OperationCloner.cc +++ b/runtime/onert/core/src/ir/OperationCloner.cc @@ -23,6 +23,23 @@ namespace onert namespace ir { +namespace +{ + +class OperationCloner : public OperationVisitor +{ +public: +#define OP(Name) void visit(const operation::Name &o) override; +#include "ir/Operations.lst" +#undef OP + +public: + std::unique_ptr<Operation> releaseClone(); + +private: + std::unique_ptr<Operation> _return_op; +}; + #define OP(Name) \ void OperationCloner::visit(const operation::Name &o) \ { \ @@ -38,5 +55,14 @@ std::unique_ptr<Operation> OperationCloner::releaseClone() return std::move(_return_op); } +} // namespace + +std::unique_ptr<Operation> clone(const IOperation &operation) +{ + OperationCloner cloner; + operation.accept(cloner); + return cloner.releaseClone(); +} + } // namespace ir } // namespace onert diff --git a/runtime/onert/core/src/ir/OperationCloner.h b/runtime/onert/core/src/ir/OperationCloner.h index 0e8cda2a0..49297a05c 100644 --- a/runtime/onert/core/src/ir/OperationCloner.h +++ b/runtime/onert/core/src/ir/OperationCloner.h @@ -26,19 +26,7 @@ namespace onert namespace ir { -class OperationCloner : public OperationVisitor -{ -public: -#define OP(Name) void visit(const operation::Name &o) override; -#include "ir/Operations.lst" -#undef OP - -public: - std::unique_ptr<Operation> releaseClone(); 
- -private: - std::unique_ptr<Operation> _return_op; -}; +std::unique_ptr<Operation> clone(const IOperation &operation); } // namespace ir } // namespace onert diff --git a/runtime/onert/core/src/ir/OperationDumper.cc b/runtime/onert/core/src/ir/OperationDumper.cc index 48361f464..5aa4693ad 100644 --- a/runtime/onert/core/src/ir/OperationDumper.cc +++ b/runtime/onert/core/src/ir/OperationDumper.cc @@ -29,19 +29,21 @@ using namespace operation; namespace { -void dumpUnaryInputOp(const Operation &node, const std::string &adding_input = "") + +// Dump all input and output. +// Use this function when there is no special input or(and) output. +void dumpOpGeneric(const Operation &node, const std::string &adding_input = "") { VERBOSE(LIR) << "* " << node.name() << std::endl; - VERBOSE(LIR) << " - Inputs : Input(" << node.getInputs().at(0) << ") " << adding_input - << std::endl; - VERBOSE(LIR) << " - Output : Output(" << node.getOutputs().at(0) << ")" << std::endl; + VERBOSE(LIR) << " - Inputs : Input(" << node.getInputs() << ") " << adding_input << std::endl; + VERBOSE(LIR) << " - Output : Output(" << node.getOutputs() << ")" << std::endl; } -void dumpBinaryInputOp(const Operation &node, const std::string &adding_input = "") +void dumpUnaryInputOp(const Operation &node, const std::string &adding_input = "") { VERBOSE(LIR) << "* " << node.name() << std::endl; - VERBOSE(LIR) << " - Inputs : Input(" << node.getInputs().at(0) << ", " << node.getInputs().at(0) - << ") " << adding_input << std::endl; + VERBOSE(LIR) << " - Inputs : Input(" << node.getInputs().at(0) << ") " << adding_input + << std::endl; VERBOSE(LIR) << " - Output : Output(" << node.getOutputs().at(0) << ")" << std::endl; } @@ -53,18 +55,6 @@ void dumpConvOp(const Operation &node, const std::string &padding_type) << node.getInputs().at(Conv2D::Input::BIAS) << ")" << std::endl; VERBOSE(LIR) << " - Output : OFM(" << node.getOutputs().at(0) << ")" << std::endl; } - -void dumpPackingOp(const Operation &node) -{ - 
VERBOSE(LIR) << "* " << node.name() << std::endl; - std::string inputs; - for (auto i : node.getInputs()) - { - inputs += std::to_string(i.value()) + ","; - } - VERBOSE(LIR) << " - Inputs : Inputs(" << inputs << ")" << std::endl; - VERBOSE(LIR) << " - Output : Output(" << node.getOutputs().at(0) << ")" << std::endl; -} } // namespace OperationDumper::OperationDumper(const std::string &start_msg) @@ -72,41 +62,62 @@ OperationDumper::OperationDumper(const std::string &start_msg) VERBOSE(LIR) << start_msg << std::endl; } -void OperationDumper::visit(const ArgMax &node) { dumpUnaryInputOp(node); } +void OperationDumper::visit(const ArgMinMax &node) +{ + std::string min_max = node.param().is_arg_max ? "(Max)" : "(Min)"; + VERBOSE(LIR) << "* " << node.name() << min_max << std::endl; + VERBOSE(LIR) << " - Inputs : Input(" << node.getInputs().at(ArgMinMax::INPUT) << ") Axis(" + << node.getInputs().at(ArgMinMax::AXIS) << ") " << std::endl; + VERBOSE(LIR) << " - Output : Output(" << node.getOutputs().at(0) << ")" << std::endl; +} void OperationDumper::visit(const BatchToSpaceND &node) { std::string block_size = - "BlockSize(" + - std::to_string(node.getInputs().at(BatchToSpaceND::Input::BLOCK_SIZE).value()) + ")"; - dumpUnaryInputOp(node, block_size); + "BlockSize(" + std::to_string(node.getInputs().at(BatchToSpaceND::Input::BLOCK_SIZE).value()) + + ")"; + dumpOpGeneric(node, block_size); } -void OperationDumper::visit(const BinaryArithmetic &node) { dumpBinaryInputOp(node); } +void OperationDumper::visit(const BCQFullyConnected &node) +{ + VERBOSE(LIR) << "* " << node.name() << std::endl; + VERBOSE(LIR) << " - Inputs : IFM(" << node.getInputs().at(BCQFullyConnected::Input::INPUT) + << ") WeightsBinary(" + << node.getInputs().at(BCQFullyConnected::Input::WEIGHTS_BINARY) + << ") WeightsScales(" + << node.getInputs().at(BCQFullyConnected::Input::WEIGHTS_SCALES) + << ") WeightsClusters(" + << node.getInputs().at(BCQFullyConnected::Input::WEIGHTS_CLUSTERS) << ") Bias(" + << 
node.getInputs().at(BCQFullyConnected::Input::BIAS) << ")" << std::endl; + VERBOSE(LIR) << " - Output : OFM(" << node.getOutputs().at(0) << ")" << std::endl; +} + +void OperationDumper::visit(const BinaryArithmetic &node) { dumpOpGeneric(node); } -void OperationDumper::visit(const operation::BroadcastTo &node) { dumpBinaryInputOp(node); } +void OperationDumper::visit(const operation::BroadcastTo &node) { dumpOpGeneric(node); } -void OperationDumper::visit(const Comparison &node) { dumpBinaryInputOp(node); } +void OperationDumper::visit(const Comparison &node) { dumpOpGeneric(node); } -void OperationDumper::visit(const Concat &node) { dumpPackingOp(node); } +void OperationDumper::visit(const Concat &node) { dumpOpGeneric(node); } void OperationDumper::visit(const Conv2D &node) { std::string padding_type = - node.param().padding.type == PaddingType::EXPLICIT ? "Explicit" : "Implicit"; + node.param().padding.type == PaddingType::EXPLICIT ? "Explicit" : "Implicit"; dumpConvOp(node, padding_type); } -void OperationDumper::visit(const ConvertFp16ToFp32 &node) { dumpUnaryInputOp(node); } +void OperationDumper::visit(const ConvertFp16ToFp32 &node) { dumpOpGeneric(node); } -void OperationDumper::visit(const ConvertFp32ToFp16 &node) { dumpUnaryInputOp(node); } +void OperationDumper::visit(const ConvertFp32ToFp16 &node) { dumpOpGeneric(node); } -void OperationDumper::visit(const DepthToSpace &node) { dumpUnaryInputOp(node); } +void OperationDumper::visit(const DepthToSpace &node) { dumpOpGeneric(node); } void OperationDumper::visit(const DepthwiseConv2D &node) { std::string padding_type = - node.param().padding.type == PaddingType::EXPLICIT ? "Explicit" : "Implicit"; + node.param().padding.type == PaddingType::EXPLICIT ? 
"Explicit" : "Implicit"; dumpConvOp(node, padding_type); } @@ -122,12 +133,12 @@ void OperationDumper::visit(const ElementwiseActivation &node) { params = " alpha value(" + std::to_string(node.param().alpha) + ")"; } - dumpUnaryInputOp(node, params); + dumpOpGeneric(node, params); } -void OperationDumper::visit(const ElementwiseBinary &node) { dumpBinaryInputOp(node); } +void OperationDumper::visit(const ElementwiseBinary &node) { dumpOpGeneric(node); } -void OperationDumper::visit(const ElementwiseUnary &node) { dumpUnaryInputOp(node); } +void OperationDumper::visit(const ElementwiseUnary &node) { dumpOpGeneric(node); } void OperationDumper::visit(const EmbeddingLookup &node) { @@ -141,22 +152,31 @@ void OperationDumper::visit(const EmbeddingLookup &node) void OperationDumper::visit(const ExpandDims &node) { std::string axis = - "AXIS(" + std::to_string(node.getInputs().at(ExpandDims::Input::AXIS).value()) + ")"; + "AXIS(" + std::to_string(node.getInputs().at(ExpandDims::Input::AXIS).value()) + ")"; dumpUnaryInputOp(node, axis); } +void OperationDumper::visit(const Fill &node) +{ + VERBOSE(LIR) << "* " << node.name() << std::endl; + VERBOSE(LIR) << " - Inputs : Shape(" << node.getInputs().at(Fill::Input::SHAPE) << ") Value(" + << node.getInputs().at(Fill::Input::VALUE) << ")" << std::endl; + VERBOSE(LIR) << " - Output : Output(" << node.getOutputs().at(0) << ")" << std::endl; +} + void OperationDumper::visit(const FullyConnected &node) { - std::string inputs = - "Weight(" + std::to_string(node.getInputs().at(FullyConnected::Input::WEIGHT).value()) + - ") Bias(" + std::to_string(node.getInputs().at(FullyConnected::Input::BIAS).value()) + ")"; - dumpUnaryInputOp(node, inputs); + VERBOSE(LIR) << "* " << node.name() << std::endl; + VERBOSE(LIR) << " - Inputs : Input(" << node.getInputs().at(ArgMinMax::INPUT) << ") Weight(" + << node.getInputs().at(FullyConnected::Input::WEIGHT) << ") Bias(" + << node.getInputs().at(FullyConnected::Input::BIAS) << ")" << std::endl; + 
VERBOSE(LIR) << " - Output : Output(" << node.getOutputs().at(0) << ")" << std::endl; } void OperationDumper::visit(const Gather &node) { std::string indices = - "Indices(" + std::to_string(node.getInputs().at(Gather::Input::INDICES).value()) + ")"; + "Indices(" + std::to_string(node.getInputs().at(Gather::Input::INDICES).value()) + ")"; dumpUnaryInputOp(node, indices); } @@ -174,50 +194,70 @@ void OperationDumper::visit(const HashtableLookup &node) void OperationDumper::visit(const InstanceNorm &node) { std::string inputs = - "Gamma(" + std::to_string(node.getInputs().at(InstanceNorm::Input::GAMMA).value()) + - ") Beta(" + std::to_string(node.getInputs().at(InstanceNorm::Input::BETA).value()) + ")"; + "Gamma(" + std::to_string(node.getInputs().at(InstanceNorm::Input::GAMMA).value()) + ") Beta(" + + std::to_string(node.getInputs().at(InstanceNorm::Input::BETA).value()) + ")"; dumpUnaryInputOp(node, inputs); } -void OperationDumper::visit(const L2Normalization &node) { dumpUnaryInputOp(node); } +void OperationDumper::visit(const L2Normalization &node) { dumpOpGeneric(node); } -void OperationDumper::visit(const LocalResponseNormalization &node) { dumpUnaryInputOp(node); } +void OperationDumper::visit(const LocalResponseNormalization &node) { dumpOpGeneric(node); } + +void OperationDumper::visit(const Loss &node) +{ + VERBOSE(LIR) << "* " << node.name() << std::endl; + VERBOSE(LIR) << " - Inputs : Prediction(" << node.getInputs().at(Loss::Input::Y_PRED) << ") True(" + << node.getInputs().at(Loss::Input::Y_TRUE) << ")" << std::endl; + VERBOSE(LIR) << " - Outputs : Output(" << node.getOutputs().at(0) << ")" << std::endl; +} void OperationDumper::visit(const LSTM &node) { + VERBOSE(LIR) << "* " << node.name() << std::endl; VERBOSE(LIR) - << " - Inputs : Input(" << node.getInputs().at(LSTM::Input::INPUT) - << ") Input To Input Weights(" << node.getInputs().at(LSTM::Input::INPUT_TO_INPUT_WEIGHTS) - << ") Input To Forget Weights(" << 
node.getInputs().at(LSTM::Input::INPUT_TO_FORGET_WEIGHTS) - << ") Input To Cell Weights(" << node.getInputs().at(LSTM::Input::INPUT_TO_CELL_WEIGHTS) - << ") Input To Output Weights(" << node.getInputs().at(LSTM::Input::INPUT_TO_OUTPUT_WEIGHTS) - << ") Recurrent To Input Weights(" - << node.getInputs().at(LSTM::Input::RECURRENT_TO_INPUT_WEIGHTS) - << ") Recurrent To Forget Weights(" - << node.getInputs().at(LSTM::Input::RECURRENT_TO_FORGET_WEIGHTS) - << ") Recurrent To Cell Weights(" - << node.getInputs().at(LSTM::Input::RECURRENT_TO_CELL_WEIGHTS) - << ") Recurrent To Output Weights(" - << node.getInputs().at(LSTM::Input::RECURRENT_TO_OUTPUT_WEIGHTS) << ") Cell To Input Weights(" - << node.getInputs().at(LSTM::Input::CELL_TO_INPUT_WEIGHTS) << ") Cell To Forget Weights(" - << node.getInputs().at(LSTM::Input::CELL_TO_FORGET_WEIGHTS) << ") Cell To OUTPUT Weights(" - << node.getInputs().at(LSTM::Input::CELL_TO_OUTPUT_WEIGHTS) << ") Input Gate Bias(" - << node.getInputs().at(LSTM::Input::INPUT_GATE_BIAS) << ") Forget Gate Bias(" - << node.getInputs().at(LSTM::Input::FORGET_GATE_BIAS) << ") Cell Bias(" - << node.getInputs().at(LSTM::Input::CELL_BIAS) << ") Output Gate Bias(" - << node.getInputs().at(LSTM::Input::OUTPUT_GATE_BIAS) << ") Projection Weights(" - << node.getInputs().at(LSTM::Input::PROJECTION_WEIGHTS) << ") Projection Bias(" - << node.getInputs().at(LSTM::Input::PROJECTION_BIAS) << ") Output State In(" - << node.getInputs().at(LSTM::Input::OUTPUT_STATE_IN) << ") Cell State In(" - << node.getInputs().at(LSTM::Input::CELL_STATE_IN) << ")" << std::endl; + << " - Inputs : Input(" << node.getInputs().at(LSTM::Input::INPUT) + << ") Input To Input Weights(" << node.getInputs().at(LSTM::Input::INPUT_TO_INPUT_WEIGHTS) + << ") Input To Forget Weights(" << node.getInputs().at(LSTM::Input::INPUT_TO_FORGET_WEIGHTS) + << ") Input To Cell Weights(" << node.getInputs().at(LSTM::Input::INPUT_TO_CELL_WEIGHTS) + << ") Input To Output Weights(" << 
node.getInputs().at(LSTM::Input::INPUT_TO_OUTPUT_WEIGHTS) + << ") Recurrent To Input Weights(" + << node.getInputs().at(LSTM::Input::RECURRENT_TO_INPUT_WEIGHTS) + << ") Recurrent To Forget Weights(" + << node.getInputs().at(LSTM::Input::RECURRENT_TO_FORGET_WEIGHTS) + << ") Recurrent To Cell Weights(" << node.getInputs().at(LSTM::Input::RECURRENT_TO_CELL_WEIGHTS) + << ") Recurrent To Output Weights(" + << node.getInputs().at(LSTM::Input::RECURRENT_TO_OUTPUT_WEIGHTS) << ") Cell To Input Weights(" + << node.getInputs().at(LSTM::Input::CELL_TO_INPUT_WEIGHTS) << ") Cell To Forget Weights(" + << node.getInputs().at(LSTM::Input::CELL_TO_FORGET_WEIGHTS) << ") Cell To OUTPUT Weights(" + << node.getInputs().at(LSTM::Input::CELL_TO_OUTPUT_WEIGHTS) << ") Input Gate Bias(" + << node.getInputs().at(LSTM::Input::INPUT_GATE_BIAS) << ") Forget Gate Bias(" + << node.getInputs().at(LSTM::Input::FORGET_GATE_BIAS) << ") Cell Bias(" + << node.getInputs().at(LSTM::Input::CELL_BIAS) << ") Output Gate Bias(" + << node.getInputs().at(LSTM::Input::OUTPUT_GATE_BIAS) << ") Projection Weights(" + << node.getInputs().at(LSTM::Input::PROJECTION_WEIGHTS) << ") Projection Bias(" + << node.getInputs().at(LSTM::Input::PROJECTION_BIAS) << ") Output State In(" + << node.getInputs().at(LSTM::Input::OUTPUT_STATE_IN) << ") Cell State In(" + << node.getInputs().at(LSTM::Input::CELL_STATE_IN); + if (node.getInputs().size() == 24) + { + VERBOSE(LIR) << ") Input Layer Normalization Weights(" + << node.getInputs().at(LSTM::Input::INPUT_LAYER_NORMALIZATION_WEIGHTS) + << ") Forget Layer Normalization Weights(" + << node.getInputs().at(LSTM::Input::FORGET_LAYER_NORMALIZATION_WEIGHTS) + << ") Cell Layer Normalization Weights(" + << node.getInputs().at(LSTM::Input::CELL_LAYER_NORMALIZATION_WEIGHTS) + << ") Ouput Layer Normalization Weights(" + << node.getInputs().at(LSTM::Input::OUTPUT_LAYER_NORMALIZATION_WEIGHTS); + } + VERBOSE(LIR) << ")" << std::endl; VERBOSE(LIR) << " - Output : Scratch Buffer(" << 
node.getOutputs().at(LSTM::Output::SCRATCH_BUFFER) << ") Output State Out(" - << node.getInputs().at(LSTM::Output::OUTPUT_STATE_OUT) << ") Cell State Out(" - << node.getInputs().at(LSTM::Output::CELL_STATE_OUT) << ") Output(" - << node.getInputs().at(LSTM::Output::OUTPUT) << ")" << std::endl; + << node.getOutputs().at(LSTM::Output::OUTPUT_STATE_OUT) << ") Cell State Out(" + << node.getOutputs().at(LSTM::Output::CELL_STATE_OUT) << ") Output(" + << node.getOutputs().at(LSTM::Output::OUTPUT) << ")" << std::endl; } -void OperationDumper::visit(const Pack &node) { dumpPackingOp(node); } +void OperationDumper::visit(const Pack &node) { dumpOpGeneric(node); } void OperationDumper::visit(const Pad &node) { @@ -249,23 +289,23 @@ void OperationDumper::visit(const Permute &node) void OperationDumper::visit(const Pool2D &node) { std::string padding_type = - node.param().padding.type == PaddingType::EXPLICIT ? "Explicit" : "Implicit"; + node.param().padding.type == PaddingType::EXPLICIT ? "Explicit" : "Implicit"; VERBOSE(LIR) << "* " << node.name() << "(" << padding_type << ")" << std::endl; VERBOSE(LIR) << " - Inputs : IFM(" << node.getInputs().at(Pool2D::Input::INPUT) << ")" << std::endl; VERBOSE(LIR) << " - Output : OFM(" << node.getOutputs().at(0) << ")" << std::endl; } -void OperationDumper::visit(const Pow &node) { dumpBinaryInputOp(node); } +void OperationDumper::visit(const Pow &node) { dumpOpGeneric(node); } void OperationDumper::visit(const PReLU &node) { std::string alpha = - "Alpha(" + std::to_string(node.getInputs().at(PReLU::Input::ALPHA).value()) + ")"; - dumpUnaryInputOp(node, alpha); + "Alpha(" + std::to_string(node.getInputs().at(PReLU::Input::ALPHA).value()) + ")"; + dumpOpGeneric(node, alpha); } -void OperationDumper::visit(const Rank &node) { dumpUnaryInputOp(node); } +void OperationDumper::visit(const Rank &node) { dumpOpGeneric(node); } void OperationDumper::visit(const Reduce &node) { dumpUnaryInputOp(node); } @@ -273,18 +313,20 @@ void 
OperationDumper::visit(const Reshape &node) { // optional param std::string shape = - node.getInputs().size() == 2 - ? "Shape(" + std::to_string(node.getInputs().at(Reshape::Input::SHAPE).value()) + ")" - : "Shape(not provided)"; + node.getInputs().size() == 2 + ? "Shape(" + std::to_string(node.getInputs().at(Reshape::Input::SHAPE).value()) + ")" + : "Shape(not provided)"; dumpUnaryInputOp(node, shape); } -void OperationDumper::visit(const ResizeBilinear &node) { dumpUnaryInputOp(node); } +void OperationDumper::visit(const ResizeBilinear &node) { dumpOpGeneric(node); } + +void OperationDumper::visit(const ResizeNearestNeighbor &node) { dumpOpGeneric(node); } void OperationDumper::visit(const Reverse &node) { std::string axis = - "Axis(" + std::to_string(node.getInputs().at(Reverse::Input::AXIS).value()) + ")"; + "Axis(" + std::to_string(node.getInputs().at(Reverse::Input::AXIS).value()) + ")"; dumpUnaryInputOp(node, axis); } @@ -320,25 +362,24 @@ void OperationDumper::visit(const Select &node) VERBOSE(LIR) << " - Output : Output(" << node.getOutputs().at(0) << ")" << std::endl; } -void OperationDumper::visit(const ir::operation::Shape &node) { dumpUnaryInputOp(node); } +void OperationDumper::visit(const ir::operation::Shape &node) { dumpOpGeneric(node); } -void OperationDumper::visit(const Softmax &node) { dumpUnaryInputOp(node); } +void OperationDumper::visit(const Softmax &node) { dumpOpGeneric(node); } void OperationDumper::visit(const SpaceToBatchND &node) { std::string inputs = - "BlockSize(" + - std::to_string(node.getInputs().at(SpaceToBatchND::Input::BLOCK_SIZE).value()) + - ") Paddings(" + std::to_string(node.getInputs().at(SpaceToBatchND::Input::PADDINGS).value()) + - ")"; + "BlockSize(" + std::to_string(node.getInputs().at(SpaceToBatchND::Input::BLOCK_SIZE).value()) + + ") Paddings(" + std::to_string(node.getInputs().at(SpaceToBatchND::Input::PADDINGS).value()) + + ")"; dumpUnaryInputOp(node, inputs); } -void OperationDumper::visit(const SpaceToDepth 
&node) { dumpUnaryInputOp(node); } +void OperationDumper::visit(const SpaceToDepth &node) { dumpOpGeneric(node); } -void OperationDumper::visit(const Split &node) { dumpUnaryInputOp(node); } +void OperationDumper::visit(const Split &node) { dumpOpGeneric(node); } -void OperationDumper::visit(const SquaredDifference &node) { dumpBinaryInputOp(node); } +void OperationDumper::visit(const SquaredDifference &node) { dumpOpGeneric(node); } void OperationDumper::visit(const StatelessRandomUniform &node) { @@ -349,7 +390,7 @@ void OperationDumper::visit(const StatelessRandomUniform &node) VERBOSE(LIR) << " - Output : Output(" << node.getOutputs().at(0) << ")" << std::endl; } -void OperationDumper::visit(const Squeeze &node) { dumpUnaryInputOp(node); } +void OperationDumper::visit(const Squeeze &node) { dumpOpGeneric(node); } void OperationDumper::visit(const Slice &node) { dumpUnaryInputOp(node); } @@ -358,7 +399,7 @@ void OperationDumper::visit(const StridedSlice &node) { dumpUnaryInputOp(node); void OperationDumper::visit(const Tile &node) { std::string multiples = - "Multiples(" + std::to_string(node.getInputs().at(Tile::Input::MULTIPLES).value()) + ")"; + "Multiples(" + std::to_string(node.getInputs().at(Tile::Input::MULTIPLES).value()) + ")"; dumpUnaryInputOp(node, multiples); } @@ -375,7 +416,7 @@ void OperationDumper::visit(const TopKV2 &node) void OperationDumper::visit(const TransposeConv &node) { std::string padding_type = - node.param().padding.type == PaddingType::EXPLICIT ? "Explicit" : "Implicit"; + node.param().padding.type == PaddingType::EXPLICIT ? 
"Explicit" : "Implicit"; VERBOSE(LIR) << "* TransposeConv(" << padding_type << ")" << std::endl; VERBOSE(LIR) << " - Inputs : Output Shape(" << node.getInputs().at(TransposeConv::Input::OUTPUT_SHAPE) << ") KERNEL(" @@ -384,22 +425,14 @@ void OperationDumper::visit(const TransposeConv &node) VERBOSE(LIR) << " - Output : OFM(" << node.getOutputs().at(0) << ")" << std::endl; } -void OperationDumper::visit(const Transpose &node) { dumpUnaryInputOp(node); } +void OperationDumper::visit(const Transpose &node) { dumpOpGeneric(node); } void OperationDumper::visit(const Unpack &node) { VERBOSE(LIR) << "* " << node.name() << std::endl; VERBOSE(LIR) << " - Inputs : Input(" << node.getInputs().at(Unpack::Input::INPUT) << ")" << std::endl; - std::string outputs; - const auto &output_indices = node.getOutputs(); - for (auto it = std::begin(output_indices); it != std::end(output_indices); ++it) - { - outputs += std::to_string(it->value()); - if (std::next(it) != std::end(output_indices)) - outputs += ", "; - } - VERBOSE(LIR) << " - Outputs : Outputs(" << outputs << ")" << std::endl; + VERBOSE(LIR) << " - Output : Outputs(" << node.getOutputs() << ")" << std::endl; } void OperationDumper::visit(const OneHot &node) @@ -413,51 +446,21 @@ void OperationDumper::visit(const OneHot &node) void OperationDumper::visit(const If &node) { VERBOSE(LIR) << "* " << node.name() << std::endl; - std::string inputs; - const auto &input_indices = node.getInputs(); - for (auto it = std::begin(input_indices); it != std::end(input_indices); ++it) - { - inputs += std::to_string(it->value()); - if (std::next(it) != std::end(input_indices)) - inputs += ", "; - } VERBOSE(LIR) << " - Inputs : " << "Then subgraph (" << node.param().then_subg_index << ") Else subgraph (" - << node.param().else_subg_index << ") Inputs(" << inputs << ")" << std::endl; - std::string outputs; - const auto &output_indices = node.getOutputs(); - for (auto it = std::begin(output_indices); it != std::end(output_indices); ++it) - { - 
outputs += std::to_string(it->value()); - if (std::next(it) != std::end(output_indices)) - outputs += ", "; - } - VERBOSE(LIR) << " - Output : Outputs(" << outputs << ")" << std::endl; + << node.param().else_subg_index << ") Inputs(" << node.getInputs() << ")" + << std::endl; + VERBOSE(LIR) << " - Output : Outputs(" << node.getOutputs() << ")" << std::endl; } void OperationDumper::visit(const While &node) { VERBOSE(LIR) << "* " << node.name() << std::endl; - std::string inputs; - const auto &input_indices = node.getInputs(); - for (auto it = std::begin(input_indices); it != std::end(input_indices); ++it) - { - inputs += std::to_string(it->value()); - if (std::next(it) != std::end(input_indices)) - inputs += ", "; - } VERBOSE(LIR) << " - Inputs : " << "Cond subgraph (" << node.param().cond_subg_index << ") Body subgraph (" - << node.param().cond_subg_index << ") Inputs(" << inputs << ")" << std::endl; - std::string outputs; - const auto &output_indices = node.getOutputs(); - for (auto it = std::begin(output_indices); it != std::end(output_indices); ++it) - { - outputs += std::to_string(it->value()); - if (std::next(it) != std::end(output_indices)) - outputs += ", "; - } - VERBOSE(LIR) << " - Output : Outputs(" << outputs << ")" << std::endl; + << node.param().body_subg_index << ") Inputs(" << node.getInputs() << ")" + << std::endl; + VERBOSE(LIR) << " - Output : Outputs(" << node.getOutputs() << ")" << std::endl; } } // namespace ir diff --git a/runtime/onert/core/src/ir/OperationDumper.h b/runtime/onert/core/src/ir/OperationDumper.h index e8ab3b3cd..99bf869d5 100644 --- a/runtime/onert/core/src/ir/OperationDumper.h +++ b/runtime/onert/core/src/ir/OperationDumper.h @@ -31,8 +31,9 @@ public: OperationDumper(const std::string &start_msg); public: - void visit(const operation::ArgMax &) override; + void visit(const operation::ArgMinMax &) override; void visit(const operation::BatchToSpaceND &node) override; + void visit(const operation::BCQFullyConnected &node) 
override; void visit(const operation::BinaryArithmetic &node) override; void visit(const operation::BroadcastTo &) override; void visit(const operation::Comparison &) override; @@ -47,12 +48,14 @@ public: void visit(const operation::ElementwiseUnary &) override; void visit(const operation::EmbeddingLookup &) override; void visit(const operation::ExpandDims &) override; + void visit(const operation::Fill &) override; void visit(const operation::FullyConnected &node) override; void visit(const operation::Gather &) override; void visit(const operation::HashtableLookup &) override; void visit(const operation::InstanceNorm &) override; void visit(const operation::L2Normalization &) override; void visit(const operation::LocalResponseNormalization &) override; + void visit(const operation::Loss &node) override; void visit(const operation::LSTM &) override; void visit(const operation::Pack &) override; void visit(const operation::Pad &) override; @@ -65,6 +68,7 @@ public: void visit(const operation::Reduce &) override; void visit(const operation::Reshape &node) override; void visit(const operation::ResizeBilinear &) override; + void visit(const operation::ResizeNearestNeighbor &) override; void visit(const operation::Reverse &) override; void visit(const operation::RNN &) override; void visit(const operation::Select &node) override; diff --git a/runtime/onert/core/src/ir/OperationValidator.cc b/runtime/onert/core/src/ir/OperationValidator.cc new file mode 100644 index 000000000..5598c4043 --- /dev/null +++ b/runtime/onert/core/src/ir/OperationValidator.cc @@ -0,0 +1,546 @@ +/* + * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "OperationValidator.h" + +#include "ir/Graph.h" +#include "util/logging.h" + +#define OP_REQUIRES(EXP) \ + do \ + { \ + if (!(EXP)) \ + throw std::runtime_error("OperationValidator failed at line " + std::to_string(__LINE__)); \ + } while (0) + +namespace onert +{ +namespace ir +{ + +OperationValidator::OperationValidator(const Graph &graph) + : _operations{graph.operations()}, _operands{graph.operands()} +{ +} + +void OperationValidator::operator()() +{ + _operations.iterate([&](const OperationIndex &, const IOperation &node) { node.accept(*this); }); +} + +DataType OperationValidator::operandType(const OperandIndex &idx) +{ + return _operands.at(idx).typeInfo().type(); +} + +bool OperationValidator::isConstant(const OperandIndex &idx) +{ + return _operands.at(idx).isConstant(); +} + +bool OperationValidator::isSameType(const OperandIndex &idx1, const OperandIndex &idx2) +{ + return operandType(idx1) == operandType(idx2); +} + +bool OperationValidator::isSameQuantParam(const OperandIndex &idx1, const OperandIndex &idx2) +{ + if (_operands.at(idx1).typeInfo().scale() != _operands.at(idx2).typeInfo().scale()) + return false; + + if (_operands.at(idx1).typeInfo().zero_point() != _operands.at(idx2).typeInfo().zero_point()) + return false; + + return true; +} + +bool OperationValidator::isValidType(const OperandIndex &idx, const DataType &type) +{ + return operandType(idx) == type; +} + +bool OperationValidator::isValidType(const OperandIndex &idx, + std::initializer_list<DataType> valid_types) +{ + for (auto &&type_to_check : 
valid_types) + { + if (isValidType(idx, type_to_check)) + { + return true; + } + } + + return false; +} + +void OperationValidator::visit(const operation::AddN &node) +{ + const auto output_index(node.getOutputs().at(0)); + + int size = node.getInputs().size(); + for (int i = 0; i < size; i++) + { + const auto input_index(node.getInputs().at(i)); + OP_REQUIRES(isValidType(input_index, {DataType::FLOAT32, DataType::INT32})); + OP_REQUIRES(isSameType(input_index, output_index)); + } +} + +void OperationValidator::visit(const operation::ArgMinMax &node) +{ + const auto input_index(node.getInputs().at(operation::ArgMinMax::Input::INPUT)); + const auto axis_index(node.getInputs().at(operation::ArgMinMax::Input::AXIS)); + const auto output_index(node.getOutputs().at(0)); + const auto output_type = node.param().output_type; + + OP_REQUIRES(isValidType(input_index, {DataType::FLOAT32, DataType::INT32, DataType::UINT8, + DataType::QUANT_UINT8_ASYMM, DataType::QUANT_INT8_ASYMM})); + OP_REQUIRES(isValidType(axis_index, {DataType::INT32, DataType::INT64})); + OP_REQUIRES(isValidType(output_index, {DataType::INT32, DataType::INT64})); + OP_REQUIRES(isValidType(output_index, output_type)); +} + +void OperationValidator::visit(const operation::BatchMatMul &node) +{ + const auto lhs_index(node.getInputs().at(operation::BatchMatMul::Input::LHS)); + const auto rhs_index(node.getInputs().at(operation::BatchMatMul::Input::RHS)); + const auto output_index(node.getOutputs().at(0)); + + // Constant lhs and rhs is not implemented yet + OP_REQUIRES(!isConstant(lhs_index) && !isConstant(rhs_index)); + + // Allow hybrid quantization (lhs: float / rhs: qint8 / out: float) + OP_REQUIRES(isValidType( + lhs_index, {DataType::FLOAT32, DataType::QUANT_UINT8_ASYMM, DataType::QUANT_INT8_ASYMM})); + OP_REQUIRES(isSameType(lhs_index, rhs_index) || + ((operandType(lhs_index) == DataType::FLOAT32) && + (operandType(rhs_index) == DataType::QUANT_INT8_ASYMM))); + OP_REQUIRES(isSameType(lhs_index, 
output_index)); +} + +void OperationValidator::visit(const operation::BatchToSpaceND &node) +{ + const auto input_index{node.getInputs().at(operation::BatchToSpaceND::Input::INPUT)}; + const auto output_index{node.getOutputs().at(0)}; + + OP_REQUIRES(isSameType(input_index, output_index)); +} + +void OperationValidator::visit(const operation::BinaryArithmetic &node) +{ + const auto output_index{node.getOutputs().at(0)}; + const auto lhs_index{node.getInputs().at(operation::BinaryArithmetic::Input::LHS)}; + const auto rhs_index{node.getInputs().at(operation::BinaryArithmetic::Input::RHS)}; + + OP_REQUIRES(isSameType(lhs_index, rhs_index)); + OP_REQUIRES(isSameType(lhs_index, output_index)); +} + +void OperationValidator::visit(const operation::Comparison &node) +{ + const auto output_index{node.getOutputs().at(0)}; + + const auto lhs_index{node.getInputs().at(operation::Comparison::Input::INPUT0)}; + const auto rhs_index{node.getInputs().at(operation::Comparison::Input::INPUT1)}; + + OP_REQUIRES(isSameType(lhs_index, rhs_index)); + OP_REQUIRES(isValidType(output_index, DataType::BOOL8)); +} + +void OperationValidator::visit(const operation::Concat &node) +{ + const auto output_index{node.getOutputs().at(0)}; + + for (auto &&input_index : node.getInputs()) + { + OP_REQUIRES(isSameType(input_index, output_index)); + + // Int8 quantization requires same scale and zero point + if (isValidType(output_index, DataType::QUANT_INT8_ASYMM)) + { + OP_REQUIRES(isSameQuantParam(input_index, output_index)); + } + } +} + +void OperationValidator::visit(const operation::Conv2D &node) +{ + const auto input_index{node.getInputs().at(operation::Conv2D::Input::INPUT)}; + const auto kernel_index{node.getInputs().at(operation::Conv2D::Input::KERNEL)}; + const auto output_index{node.getOutputs().at(0)}; + + uint32_t stride_horizontal = node.param().stride.horizontal; + uint32_t stride_vertical = node.param().stride.vertical; + uint32_t dilation_width = node.param().dilation.width_factor; 
+ uint32_t dilation_height = node.param().dilation.height_factor; + + OP_REQUIRES((stride_horizontal > 0) && (stride_vertical > 0)); + OP_REQUIRES((dilation_width > 0) && (dilation_height > 0)); + OP_REQUIRES(isSameType(input_index, output_index)); + + if (isConstant(kernel_index) && operandType(kernel_index) == DataType::QUANT_INT8_ASYMM) + { + for (const auto zeropoint : _operands.at(kernel_index).typeInfo().zero_points()) + OP_REQUIRES(zeropoint == 0); + } +} + +void OperationValidator::visit(const operation::DepthToSpace &node) +{ + const auto input_index{node.getInputs().at(operation::DepthToSpace::Input::INPUT)}; + const auto output_index{node.getOutputs().at(0)}; + + int32_t block_size = node.param().block_size; + + OP_REQUIRES(isValidType(input_index, {DataType::FLOAT32, DataType::INT32, DataType::INT64, + DataType::QUANT_UINT8_ASYMM, DataType::QUANT_INT8_ASYMM})); + OP_REQUIRES(isSameType(input_index, output_index)); + + OP_REQUIRES(block_size > 0); +} + +void OperationValidator::visit(const operation::DetectionPostProcess &node) +{ + const auto ¶m = node.param(); + + // FIXME: number of classes should be 1 for now. 
+ OP_REQUIRES(param.num_classes == 1); +} + +void OperationValidator::visit(const operation::DepthwiseConv2D &node) +{ + const auto input_index{node.getInputs().at(operation::DepthwiseConv2D::Input::INPUT)}; + const auto kernel_index{node.getInputs().at(operation::DepthwiseConv2D::Input::KERNEL)}; + const auto output_index{node.getOutputs().at(0)}; + + uint32_t stride_horizontal = node.param().stride.horizontal; + uint32_t stride_vertical = node.param().stride.vertical; + uint32_t dilation_width = node.param().dilation.width_factor; + uint32_t dilation_height = node.param().dilation.height_factor; + + OP_REQUIRES((stride_horizontal > 0) && (stride_vertical > 0)); + OP_REQUIRES((dilation_width > 0) && (dilation_height > 0)); + OP_REQUIRES(isSameType(input_index, output_index)); + + if (isConstant(kernel_index) && operandType(kernel_index) == DataType::QUANT_INT8_ASYMM) + { + for (const auto zeropoint : _operands.at(kernel_index).typeInfo().zero_points()) + OP_REQUIRES(zeropoint == 0); + } +} + +void OperationValidator::visit(const operation::ElementwiseActivation &node) +{ + const auto output_index{node.getOutputs().at(0)}; + const auto input_index{node.getInputs().at(0)}; + + // Check if I/O types match + OP_REQUIRES(isSameType(output_index, input_index)); + + switch (node.param().op_type) + { + case operation::ElementwiseActivation::Type::ELU: + OP_REQUIRES(isValidType(input_index, DataType::FLOAT32)); + break; + case operation::ElementwiseActivation::Type::LEAKY_RELU: + OP_REQUIRES( + isValidType(input_index, {DataType::FLOAT32, DataType::QUANT_UINT8_ASYMM, + DataType::QUANT_INT8_ASYMM, DataType::QUANT_INT16_ASYMM})); + break; + case operation::ElementwiseActivation::Type::LOGISTIC: + OP_REQUIRES( + isValidType(input_index, {DataType::FLOAT32, DataType::QUANT_UINT8_ASYMM, + DataType::QUANT_INT8_ASYMM, DataType::QUANT_INT16_ASYMM})); + break; + case operation::ElementwiseActivation::Type::RELU: + OP_REQUIRES(isValidType( + input_index, {DataType::FLOAT32, 
DataType::QUANT_UINT8_ASYMM, DataType::QUANT_INT8_ASYMM})); + break; + case operation::ElementwiseActivation::Type::TANH: + OP_REQUIRES( + isValidType(input_index, {DataType::FLOAT32, DataType::QUANT_UINT8_ASYMM, + DataType::QUANT_INT8_ASYMM, DataType::QUANT_INT16_ASYMM})); + break; + } +} + +void OperationValidator::visit(const operation::ElementwiseBinary &node) +{ + const auto output_index{node.getOutputs().at(0)}; + const auto lhs_index{node.getInputs().at(operation::ElementwiseBinary::Input::LHS)}; + const auto rhs_index{node.getInputs().at(operation::ElementwiseBinary::Input::RHS)}; + + OP_REQUIRES(isSameType(lhs_index, rhs_index)); + OP_REQUIRES(isSameType(lhs_index, output_index)); + + const auto op_type = node.param().op_type; + if (op_type == operation::ElementwiseBinary::ElementwiseBinaryType::LOGICAL_AND || + op_type == operation::ElementwiseBinary::ElementwiseBinaryType::LOGICAL_OR) + { + OP_REQUIRES(isValidType(lhs_index, DataType::BOOL8)); + } +} + +void OperationValidator::visit(const operation::ElementwiseUnary &node) +{ + const auto output_index{node.getOutputs().at(0)}; + const auto input_index{node.getInputs().at(operation::ElementwiseUnary::Input::INPUT)}; + + // Check if I/O types match + if (node.param().op_type == operation::ElementwiseUnary::Type::DEQUANTIZE) + { + // NNAPI allow QUANT_INT8_SYMM type input + OP_REQUIRES(isValidType(input_index, {DataType::QUANT_UINT8_ASYMM, DataType::QUANT_INT8_SYMM, + DataType::QUANT_INT8_ASYMM})); + OP_REQUIRES(isValidType(output_index, DataType::FLOAT32)); + } + else if (node.param().op_type == operation::ElementwiseUnary::Type::QUANTIZE) + { + OP_REQUIRES(isValidType( + input_index, {DataType::FLOAT32, DataType::QUANT_UINT8_ASYMM, DataType::QUANT_INT8_ASYMM})); + OP_REQUIRES( + isValidType(output_index, {DataType::QUANT_UINT8_ASYMM, DataType::QUANT_INT8_ASYMM})); + } + else if (node.param().op_type == operation::ElementwiseUnary::Type::FLOOR) + { + OP_REQUIRES(isValidType(input_index, 
DataType::FLOAT32)); + OP_REQUIRES(isSameType(output_index, input_index)); + } + else if (node.param().op_type != operation::ElementwiseUnary::Type::CAST) + { + OP_REQUIRES(isSameType(output_index, input_index)); + } +} + +void OperationValidator::visit(const operation::EmbeddingLookup &node) +{ + const auto lookups_index{node.getInputs().at(operation::EmbeddingLookup::Input::LOOKUPS)}; + const auto values_index{node.getInputs().at(operation::EmbeddingLookup::Input::VALUES)}; + const auto output_index{node.getOutputs().at(0)}; + + OP_REQUIRES(isValidType(lookups_index, DataType::INT32)); + + // TFLite: Allow hybrid type - value table & output + // NNAPI: Require same value table and output type + OP_REQUIRES( + isSameType(values_index, output_index) || + (isValidType(output_index, DataType::FLOAT32) && + (isValidType(values_index, {DataType::QUANT_INT8_ASYMM, DataType::QUANT_INT8_SYMM})))); +} + +void OperationValidator::visit(const operation::ExpandDims &node) +{ + const auto output_index{node.getOutputs().at(0)}; + const auto input_index{node.getInputs().at(operation::ExpandDims::Input::INPUT)}; + const auto axis_index{node.getInputs().at(operation::ExpandDims::Input::AXIS)}; + + OP_REQUIRES(isSameType(output_index, input_index)); + OP_REQUIRES(isValidType(axis_index, {DataType::INT32, DataType::INT64})); +} + +void OperationValidator::visit(const operation::Fill &node) +{ + const auto output_index{node.getOutputs().at(0)}; + const auto input_index{node.getInputs().at(operation::Fill::Input::SHAPE)}; + const auto value_index{node.getInputs().at(operation::Fill::Input::VALUE)}; + + OP_REQUIRES(isSameType(output_index, value_index)); + OP_REQUIRES(isValidType(input_index, {DataType::INT32, DataType::INT64})); + OP_REQUIRES(isValidType(output_index, + {DataType::FLOAT32, DataType::INT32, DataType::INT64, DataType::BOOL8})); +} + +void OperationValidator::visit(const operation::HashtableLookup &node) +{ + const auto 
hits_index{node.getOutputs().at(operation::HashtableLookup::Output::HITS)}; + const auto lookups_index{node.getInputs().at(operation::HashtableLookup::Input::LOOKUPS)}; + const auto keys_index{node.getInputs().at(operation::HashtableLookup::Input::KEYS)}; + + OP_REQUIRES(isValidType(lookups_index, DataType::INT32)); + OP_REQUIRES(isValidType(keys_index, DataType::INT32)); + OP_REQUIRES(isValidType(hits_index, DataType::QUANT_UINT8_ASYMM)); +} + +void OperationValidator::visit(const operation::Pack &node) +{ + const auto num{node.param().num}; + + OP_REQUIRES(num == static_cast<int32_t>(node.getInputs().size())); +} + +void OperationValidator::visit(const operation::Pad &node) +{ + const auto output_index{node.getOutputs().at(0)}; + const auto input_index{node.getInputs().at(operation::Pad::Input::INPUT)}; + const auto pad_index{node.getInputs().at(operation::Pad::Input::PAD)}; + bool isQuantType = + isValidType(output_index, {DataType::QUANT_UINT8_ASYMM, DataType::QUANT_INT8_ASYMM}); + bool isPadV2 = node.getInputs().size() == 3 ? true : false; + + OP_REQUIRES(isValidType(pad_index, DataType::INT32)); + OP_REQUIRES(isSameType(input_index, output_index)); + + if (isQuantType) + OP_REQUIRES(isSameQuantParam(input_index, output_index)); + + if (isPadV2) + { + const auto value_index{node.getInputs().at(operation::Pad::Input::VALUE)}; + const bool cond_same = isSameType(input_index, value_index); + const bool cond_same_quant = (!isQuantType || isSameQuantParam(input_index, value_index)); + const auto input_t = operandType(input_index); + const auto value_t = operandType(value_index); + // NNAPI accepts this case. scale and zeroPoint are assumed to be the same as in input0. 
+ const bool cond_quant8 = + ((input_t == DataType::QUANT_UINT8_ASYMM || input_t == DataType::QUANT_INT8_ASYMM) && + value_t == DataType::INT32); + OP_REQUIRES((cond_same && cond_same_quant) || cond_quant8); + } +} + +void OperationValidator::visit(const operation::Rank &node) +{ + const auto output_index{node.getOutputs().at(0)}; + + OP_REQUIRES(isValidType(output_index, DataType::INT32)); +} + +void OperationValidator::visit(const operation::ResizeBilinear &node) +{ + auto align_corners = node.param().align_corners; + auto half_pixel_centers = node.param().half_pixel_centers; + + OP_REQUIRES(!align_corners || !half_pixel_centers); +} + +void OperationValidator::visit(const operation::Reverse &node) +{ + const auto output_index{node.getOutputs().at(0)}; + const auto input_index{node.getInputs().at(operation::Reverse::Input::INPUT)}; + const auto axis_index{node.getInputs().at(operation::Reverse::Input::AXIS)}; + + OP_REQUIRES(isValidType(axis_index, DataType::INT32)); + OP_REQUIRES(isSameType(output_index, input_index)); +} + +void OperationValidator::visit(const operation::Select &node) +{ + const auto condition_index{node.getInputs().at(operation::Select::Input::CONDITION)}; + const auto input_true_index{node.getInputs().at(operation::Select::Input::INPUT_TRUE)}; + const auto input_false_index{node.getInputs().at(operation::Select::Input::INPUT_FALSE)}; + + OP_REQUIRES(isValidType(condition_index, DataType::BOOL8)); + OP_REQUIRES(isSameType(input_true_index, input_false_index)); +} + +void OperationValidator::visit(const operation::Shape &node) +{ + const auto output_index{node.getOutputs().at(0)}; + + OP_REQUIRES(isValidType(output_index, {DataType::UINT32, DataType::INT32, DataType::INT64})); +} + +void OperationValidator::visit(const operation::Slice &node) +{ + const auto begins_index{node.getInputs().at(operation::Slice::BEGINS)}; + const auto sizes_index{node.getInputs().at(operation::Slice::SIZES)}; + + OP_REQUIRES(isValidType(begins_index, 
{DataType::INT32, DataType::INT64})); + OP_REQUIRES(isSameType(begins_index, sizes_index)); +} + +void OperationValidator::visit(const operation::Softmax &node) +{ + const auto output_index{node.getOutputs().at(0)}; + const auto input_index{node.getInputs().at(operation::Softmax::INPUT)}; + + OP_REQUIRES(isSameType(input_index, output_index)); + OP_REQUIRES(isValidType( + output_index, {DataType::FLOAT32, DataType::QUANT_UINT8_ASYMM, DataType::QUANT_INT8_ASYMM})); +} + +void OperationValidator::visit(const operation::SpaceToBatchND &node) +{ + const auto block_size_index{node.getInputs().at(operation::SpaceToBatchND::Input::BLOCK_SIZE)}; + const auto paddings_index{node.getInputs().at(operation::SpaceToBatchND::Input::PADDINGS)}; + + // Non-constant block_size and padding is not implemented yet + OP_REQUIRES(isConstant(block_size_index)); + OP_REQUIRES(isConstant(paddings_index)); +} + +void OperationValidator::visit(const operation::SpaceToDepth &node) +{ + const auto block_size = node.param().block_size; + OP_REQUIRES(block_size >= 1); +} + +void OperationValidator::visit(const operation::Split &node) +{ + const auto num_splits = node.param().num_splits; + + OP_REQUIRES(num_splits > 0 && num_splits <= 0xFFFF); + OP_REQUIRES(node.getOutputs().size() == static_cast<uint32_t>(num_splits)); +} + +void OperationValidator::visit(const operation::SquaredDifference &node) +{ + const auto output_index{node.getOutputs().at(0)}; + const auto lhs_index{node.getInputs().at(operation::SquaredDifference::Input::LHS)}; + const auto rhs_index{node.getInputs().at(operation::SquaredDifference::Input::RHS)}; + + OP_REQUIRES(isSameType(output_index, lhs_index)); + OP_REQUIRES(isSameType(lhs_index, rhs_index)); +} + +void OperationValidator::visit(const operation::StatelessRandomUniform &node) +{ + const auto output_index{node.getOutputs().at(0)}; + const auto shape_index{node.getInputs().at(operation::StatelessRandomUniform::Input::SHAPE)}; + const auto 
seed_index{node.getInputs().at(operation::StatelessRandomUniform::Input::SEED)}; + + OP_REQUIRES(isValidType(output_index, DataType::FLOAT32)); + OP_REQUIRES(isValidType(shape_index, DataType::INT32)); + OP_REQUIRES(isValidType(seed_index, DataType::INT32)); +} + +void OperationValidator::visit(const operation::StridedSlice &node) +{ + const auto output_index{node.getOutputs().at(0)}; + const auto input_index{node.getInputs().at(operation::StridedSlice::Input::INPUT)}; + + OP_REQUIRES(isSameType(output_index, input_index)); +} + +void OperationValidator::visit(const operation::TransposeConv &node) +{ + OP_REQUIRES((node.param().padding.type == PaddingType::SAME) || + (node.param().padding.type == PaddingType::VALID)); +} + +void OperationValidator::visit(const operation::Unpack &node) +{ + const auto num{node.param().num}; + OP_REQUIRES(num == static_cast<int32_t>(node.getOutputs().size())); +} + +void OperationValidator::visit(const operation::While &node) +{ + OP_REQUIRES(node.getInputs().size() == node.getOutputs().size()); +} + +} // namespace ir +} // namespace onert diff --git a/runtime/onert/core/src/ir/OperationValidator.h b/runtime/onert/core/src/ir/OperationValidator.h new file mode 100644 index 000000000..b9bcc4ee8 --- /dev/null +++ b/runtime/onert/core/src/ir/OperationValidator.h @@ -0,0 +1,101 @@ +/* + * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef __ONERT_IR_OPERATION_VALIDATOR_H__ +#define __ONERT_IR_OPERATION_VALIDATOR_H__ + +#include "ir/OperationVisitor.h" +#include "ir/Operations.h" +#include "ir/Operands.h" + +namespace onert +{ +namespace ir +{ +class Graph; +class Operands; +} // namespace ir +} // namespace onert + +namespace onert +{ +namespace ir +{ + +class OperationValidator : public OperationVisitor +{ +public: + OperationValidator(void) = delete; + OperationValidator(const Graph &graph); + +public: + void operator()(); + +public: + void visit(const operation::AddN &node) override; + void visit(const operation::ArgMinMax &node) override; + void visit(const operation::BatchMatMul &node) override; + void visit(const operation::BatchToSpaceND &node) override; + void visit(const operation::BinaryArithmetic &node) override; + void visit(const operation::Comparison &node) override; + void visit(const operation::Concat &node) override; + void visit(const operation::Conv2D &node) override; + void visit(const operation::DepthToSpace &node) override; + void visit(const operation::DepthwiseConv2D &node) override; + void visit(const operation::DetectionPostProcess &node) override; + void visit(const operation::ElementwiseActivation &node) override; + void visit(const operation::ElementwiseBinary &node) override; + void visit(const operation::ElementwiseUnary &node) override; + void visit(const operation::EmbeddingLookup &node) override; + void visit(const operation::ExpandDims &node) override; + void visit(const operation::Fill &node) override; + void visit(const operation::HashtableLookup &node) override; + void visit(const operation::Pack &node) override; + void visit(const operation::Pad &node) override; + void visit(const operation::Rank &node) override; + void visit(const operation::ResizeBilinear &node) override; + void visit(const operation::Reverse &node) override; + void visit(const operation::Select &node) override; + void visit(const operation::Shape &node) override; + void 
visit(const operation::Slice &node) override; + void visit(const operation::Softmax &node) override; + void visit(const operation::SpaceToBatchND &node) override; + void visit(const operation::SpaceToDepth &node) override; + void visit(const operation::Split &node) override; + void visit(const operation::SquaredDifference &node) override; + void visit(const operation::StatelessRandomUniform &node) override; + void visit(const operation::StridedSlice &node) override; + void visit(const operation::TransposeConv &node) override; + void visit(const operation::Unpack &node) override; + void visit(const operation::While &node) override; + +private: + DataType operandType(const OperandIndex &idx); + bool isConstant(const OperandIndex &idx); + bool isSameType(const OperandIndex &idx1, const OperandIndex &idx2); + bool isSameQuantParam(const OperandIndex &idx1, const OperandIndex &idx2); + bool isValidType(const OperandIndex &idx, const DataType &type); + bool isValidType(const OperandIndex &idx, std::initializer_list<DataType> valid_types); + +private: + const Operations &_operations; + const Operands &_operands; +}; + +} // namespace ir +} // namespace onert + +#endif // __ONERT_IR_OPERATION_VALIDATOR_H__ diff --git a/runtime/onert/core/src/ir/Operations.cc b/runtime/onert/core/src/ir/Operations.cc index 64d0bd6f0..1b4691f58 100644 --- a/runtime/onert/core/src/ir/Operations.cc +++ b/runtime/onert/core/src/ir/Operations.cc @@ -25,12 +25,9 @@ namespace ir Operations::Operations(const Operations &obj) { - obj.iterate([&](const OperationIndex &index, const Operation &op) { - OperationCloner cloner; - op.accept(cloner); - _objects.emplace(index, cloner.releaseClone()); - }); - _index_count = obj._index_count; + obj.iterate( + [&](const OperationIndex &index, const IOperation &op) { _objects.emplace(index, clone(op)); }); + _next_index = obj._next_index; } } // namespace ir diff --git a/runtime/onert/core/src/ir/Operations.test.cc b/runtime/onert/core/src/ir/Operations.test.cc 
new file mode 100644 index 000000000..e57872689 --- /dev/null +++ b/runtime/onert/core/src/ir/Operations.test.cc @@ -0,0 +1,42 @@ +/* + * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ir/Operations.h" + +#include "MockNode.h" + +#include <gtest/gtest.h> + +using onert::ir::Operation; +using onert::ir::OperationIndex; +using onert::ir::Operations; + +TEST(ir_Operations, basic) +{ + Operations ops; + ops.push(std::unique_ptr<Operation>(new onert_test::ir::SimpleMock({1, 2, 3, 4}, {5, 6, 7}))); + OperationIndex idx{0u}; + ASSERT_EQ(ops.at(idx).getInputs().size(), 4); + ASSERT_EQ(ops.at(idx).getOutputs().size(), 3); +} + +TEST(ir_Operations, neg_at) +{ + Operations ops; + ops.push(std::unique_ptr<Operation>(new onert_test::ir::SimpleMock({1, 2, 3, 4}, {5, 6, 7}))); + OperationIndex idx{99u}; + EXPECT_THROW(ops.at(idx), std::out_of_range); +} diff --git a/runtime/onert/core/src/ir/Padding.cc b/runtime/onert/core/src/ir/Padding.cc index d74f80217..b2b004e7a 100644 --- a/runtime/onert/core/src/ir/Padding.cc +++ b/runtime/onert/core/src/ir/Padding.cc @@ -66,14 +66,14 @@ inline ExplicitPadding samePaddingUsingIFM(const FeatureShape &ifm_shape, const const int32_t vertical_expected_output = (ifm_shape.H + stride.vertical - 1) / stride.vertical; const int32_t horizontal_expected_output = - (ifm_shape.W + stride.horizontal - 1) / stride.horizontal; + (ifm_shape.W + 
stride.horizontal - 1) / stride.horizontal; const int32_t vertical_needed_input = - (vertical_expected_output - 1) * stride.vertical + effective_filter_h_size; + (vertical_expected_output - 1) * stride.vertical + effective_filter_h_size; const int32_t vertical_total_padding = std::max(0, vertical_needed_input - ifm_shape.H); const int32_t horizontal_needed_input = - (horizontal_expected_output - 1) * stride.horizontal + effective_filter_w_size; + (horizontal_expected_output - 1) * stride.horizontal + effective_filter_w_size; const int32_t horizontal_total_padding = std::max(0, horizontal_needed_input - ifm_shape.W); padding.top = vertical_total_padding / 2; @@ -90,7 +90,7 @@ inline ExplicitPadding samePadding(const FeatureShape &ifm_shape, const FeatureS { const int32_t vertical_expected_output = (ifm_shape.H + stride.vertical - 1) / stride.vertical; const int32_t horizontal_expected_output = - (ifm_shape.W + stride.horizontal - 1) / stride.horizontal; + (ifm_shape.W + stride.horizontal - 1) / stride.horizontal; assert(vertical_expected_output == ofm_shape.H); assert(horizontal_expected_output == ofm_shape.W); @@ -129,7 +129,7 @@ Padding::Padding(PaddingType paddingType) : type{paddingType}, param{0, 0, 0, 0} } Padding::Padding(uint32_t left, uint32_t right, uint32_t top, uint32_t bottom) - : type{PaddingType::EXPLICIT}, param{left, right, top, bottom} + : type{PaddingType::EXPLICIT}, param{left, right, top, bottom} { // DO NOTHING } diff --git a/runtime/onert/core/src/ir/Shape.cc b/runtime/onert/core/src/ir/Shape.cc index 322df7b4c..e4e4c154b 100644 --- a/runtime/onert/core/src/ir/Shape.cc +++ b/runtime/onert/core/src/ir/Shape.cc @@ -26,10 +26,10 @@ namespace onert namespace ir { -int32_t const Shape::UNSPECIFIED_DIM = -1; +int32_t const Shape::kUnspecifiedDim = -1; // NNFW_MAX_RANK is 6 -int32_t const Shape::MAX_RANK = 6; +int32_t const Shape::kMaxRank = 6; FeatureShape Shape::asFeature(Layout layout) const { @@ -80,34 +80,37 @@ uint64_t Shape::num_elements() 
const { // if dimension is 0, it means unspecified and cannot calculate the total number of elements if (std::any_of(_dimensions.begin(), _dimensions.end(), - [](const int32_t &v) { return v == UNSPECIFIED_DIM; })) + [](const int32_t &v) { return v == kUnspecifiedDim; })) throw std::runtime_error("num_elements() cannot calculate when any dimension is unspecified"); return std::accumulate(_dimensions.cbegin(), _dimensions.cend(), UINT64_C(1), std::multiplies<uint64_t>()); } -Shape permuteShape(const Shape &shape, Layout frontend_layout, Layout backend_layout) +Shape permuteShape(const Shape &shape, Layout from, Layout to) { - assert(shape.rank() <= Shape::MAX_RANK); - Shape backend_shape{shape}; - if (shape.rank() >= 4 && frontend_layout == Layout::NHWC && backend_layout == Layout::NCHW) + assert(shape.rank() <= Shape::kMaxRank); + Shape ret{shape}; + if (from == to) + return ret; + if (shape.rank() < 4) + return ret; + // Permutation changing layout beyond 4-D is not supported yet + assert(shape.rank() <= 4); + if (from == Layout::NHWC && to == Layout::NCHW) { - // Permutation changing layout beyond 4-D is not supported yet - assert(shape.rank() <= 4); - backend_shape.dim(1) = shape.dim(3); - backend_shape.dim(2) = shape.dim(1); - backend_shape.dim(3) = shape.dim(2); + ret.dim(1) = shape.dim(3); + ret.dim(2) = shape.dim(1); + ret.dim(3) = shape.dim(2); } - else if (shape.rank() >= 4 && frontend_layout == Layout::NCHW && backend_layout == Layout::NHWC) + else if (from == Layout::NCHW && to == Layout::NHWC) { - // Permutation changing layout beyond 4-D is not supported yet - assert(shape.rank() <= 4); - backend_shape.dim(1) = shape.dim(2); - backend_shape.dim(2) = shape.dim(3); - backend_shape.dim(3) = shape.dim(1); + ret.dim(1) = shape.dim(2); + ret.dim(2) = shape.dim(3); + ret.dim(3) = shape.dim(1); } - return backend_shape; + // Other cases(either `from` or `to` is UNKNOWN), just return the original shape + return ret; } } // namespace ir diff --git 
a/runtime/onert/core/src/ir/Shape.test.cc b/runtime/onert/core/src/ir/Shape.test.cc new file mode 100644 index 000000000..4788522d3 --- /dev/null +++ b/runtime/onert/core/src/ir/Shape.test.cc @@ -0,0 +1,58 @@ +/* + * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ir/Shape.h" + +#include <gtest/gtest.h> + +TEST(ShapeTest, basic_test) +{ + { + onert::ir::Shape shape(3); + + shape.dim(0) = 1; + shape.dim(1) = 2; + shape.dim(2) = 3; + + ASSERT_EQ(shape.rank(), 3); + ASSERT_EQ(shape.num_elements(), 6); + ASSERT_EQ(onert::ir::rankMaybeUnspecified(shape), false); + ASSERT_EQ(shape.hasUnspecifiedDims(), false); + } + { + onert::ir::Shape shape; // scalar or rank is unspecified + + ASSERT_EQ(shape.rank(), 0); + ASSERT_EQ(shape.num_elements(), 1); + ASSERT_EQ(onert::ir::rankMaybeUnspecified(shape), true); + ASSERT_EQ(shape.hasUnspecifiedDims(), false); + } +} + +TEST(ShapeTest, neg_basic_test) +{ + { + onert::ir::Shape shape(2); + + shape.dim(0) = 1; + shape.dim(1) = onert::ir::Shape::kUnspecifiedDim; + + ASSERT_EQ(shape.rank(), 2); + ASSERT_EQ(onert::ir::rankMaybeUnspecified(shape), false); + ASSERT_EQ(shape.hasUnspecifiedDims(), true); + EXPECT_ANY_THROW(shape.num_elements()); + } +} diff --git a/runtime/onert/core/src/ir/TypeInfo.cc b/runtime/onert/core/src/ir/TypeInfo.cc index ab8af287e..5d1c7ba8b 100644 --- a/runtime/onert/core/src/ir/TypeInfo.cc +++ 
b/runtime/onert/core/src/ir/TypeInfo.cc @@ -28,7 +28,7 @@ bool operator==(const TypeInfo &lhs, const TypeInfo &rhs) return false; } - if (lhs.offset() != rhs.offset()) + if (lhs.zero_point() != rhs.zero_point()) { return false; } diff --git a/runtime/onert/core/src/ir/operation/AddN.cc b/runtime/onert/core/src/ir/operation/AddN.cc new file mode 100644 index 000000000..a51e12dff --- /dev/null +++ b/runtime/onert/core/src/ir/operation/AddN.cc @@ -0,0 +1,36 @@ +/* + * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ir/operation/AddN.h" +#include "ir/OperationVisitor.h" + +namespace onert +{ +namespace ir +{ +namespace operation +{ + +void AddN::accept(OperationVisitor &v) const { v.visit(*this); } + +AddN::AddN(const OperandIndexSequence &inputs, const OperandIndexSequence &outputs) + : Operation{OperandConstraint::createExact(inputs.size()), inputs, outputs} +{ +} + +} // namespace operation +} // namespace ir +} // namespace onert diff --git a/runtime/onert/core/src/ir/operation/ArgMax.cc b/runtime/onert/core/src/ir/operation/ArgMinMax.cc index 1275ae43a..2f18ff2e2 100644 --- a/runtime/onert/core/src/ir/operation/ArgMax.cc +++ b/runtime/onert/core/src/ir/operation/ArgMinMax.cc @@ -14,10 +14,7 @@ * limitations under the License. 
*/ -#include "ir/operation/ArgMax.h" - -#include <cassert> - +#include "ir/operation/ArgMinMax.h" #include "ir/OperationVisitor.h" namespace onert @@ -27,11 +24,11 @@ namespace ir namespace operation { -void ArgMax::accept(OperationVisitor &v) const { v.visit(*this); } +void ArgMinMax::accept(OperationVisitor &v) const { v.visit(*this); } -ArgMax::ArgMax(const OperandIndexSequence &inputs, const OperandIndexSequence &outputs, - const Param ¶m) - : Operation{OperandConstraint::createExact(1u), inputs, outputs}, _param{param} +ArgMinMax::ArgMinMax(const OperandIndexSequence &inputs, const OperandIndexSequence &outputs, + const Param ¶m) + : Operation{OperandConstraint::createExact(2u), inputs, outputs}, _param{param} { } diff --git a/runtime/onert/core/src/ir/operation/BCQFullyConnected.cc b/runtime/onert/core/src/ir/operation/BCQFullyConnected.cc index 9dc54e6e9..ccda674ad 100644 --- a/runtime/onert/core/src/ir/operation/BCQFullyConnected.cc +++ b/runtime/onert/core/src/ir/operation/BCQFullyConnected.cc @@ -15,9 +15,6 @@ */ #include "ir/operation/BCQFullyConnected.h" - -#include <cassert> - #include "ir/OperationVisitor.h" namespace onert @@ -31,7 +28,7 @@ void BCQFullyConnected::accept(OperationVisitor &v) const { v.visit(*this); } BCQFullyConnected::BCQFullyConnected(const OperandIndexSequence &inputs, const OperandIndexSequence &outputs, const Param ¶m) - : Operation{OperandConstraint::createExact(5u), inputs, outputs}, _param{param} + : Operation{OperandConstraint::createExact(5u), inputs, outputs}, _param{param} { } diff --git a/runtime/onert/core/src/ir/operation/BCQGather.cc b/runtime/onert/core/src/ir/operation/BCQGather.cc index 80efa6460..1ca5b0c9f 100644 --- a/runtime/onert/core/src/ir/operation/BCQGather.cc +++ b/runtime/onert/core/src/ir/operation/BCQGather.cc @@ -15,9 +15,6 @@ */ #include "ir/operation/BCQGather.h" - -#include <cassert> - #include "ir/OperationVisitor.h" namespace onert @@ -31,7 +28,7 @@ void BCQGather::accept(OperationVisitor &v) 
const { v.visit(*this); } BCQGather::BCQGather(const OperandIndexSequence &inputs, const OperandIndexSequence &outputs, const Param ¶m) - : Operation{OperandConstraint::createExact(4u), inputs, outputs}, _param{param} + : Operation{OperandConstraint::createExact(4u), inputs, outputs}, _param{param} { } diff --git a/runtime/onert/core/src/ir/operation/BatchMatMul.cc b/runtime/onert/core/src/ir/operation/BatchMatMul.cc index b9616158d..20c5682f9 100644 --- a/runtime/onert/core/src/ir/operation/BatchMatMul.cc +++ b/runtime/onert/core/src/ir/operation/BatchMatMul.cc @@ -28,7 +28,7 @@ void BatchMatMul::accept(OperationVisitor &v) const { v.visit(*this); } BatchMatMul::BatchMatMul(const OperandIndexSequence &inputs, const OperandIndexSequence &outputs, const Param ¶m) - : Operation{OperandConstraint::createExact(2u), inputs, outputs}, _param{param} + : Operation{OperandConstraint::createExact(2u), inputs, outputs}, _param{param} { } diff --git a/runtime/onert/core/src/ir/operation/BatchToSpaceND.cc b/runtime/onert/core/src/ir/operation/BatchToSpaceND.cc index 9ef2b125f..3c5578ac4 100644 --- a/runtime/onert/core/src/ir/operation/BatchToSpaceND.cc +++ b/runtime/onert/core/src/ir/operation/BatchToSpaceND.cc @@ -15,9 +15,6 @@ */ #include "ir/operation/BatchToSpaceND.h" - -#include <cassert> - #include "ir/OperationVisitor.h" namespace onert @@ -31,7 +28,7 @@ void BatchToSpaceND::accept(OperationVisitor &v) const { v.visit(*this); } BatchToSpaceND::BatchToSpaceND(const OperandIndexSequence &inputs, const OperandIndexSequence &outputs) - : Operation{OperandConstraint::createExact(3u), inputs, outputs} + : Operation{OperandConstraint::createInRange(2u, 3u), inputs, outputs} { } diff --git a/runtime/onert/core/src/ir/operation/BinaryArithmetic.cc b/runtime/onert/core/src/ir/operation/BinaryArithmetic.cc index 2b1422c73..5eb3fc3d7 100644 --- a/runtime/onert/core/src/ir/operation/BinaryArithmetic.cc +++ b/runtime/onert/core/src/ir/operation/BinaryArithmetic.cc @@ -15,12 +15,10 @@ 
*/ #include "ir/operation/BinaryArithmetic.h" +#include "ir/OperationVisitor.h" -#include <cassert> #include <unordered_map> -#include "ir/OperationVisitor.h" - namespace onert { namespace ir @@ -32,7 +30,7 @@ void BinaryArithmetic::accept(OperationVisitor &v) const { v.visit(*this); } BinaryArithmetic::BinaryArithmetic(const OperandIndexSequence &inputs, const OperandIndexSequence &outputs, const Param ¶m) - : Operation{OperandConstraint::createExact(2u), inputs, outputs}, _param{param} + : Operation{OperandConstraint::createExact(2u), inputs, outputs}, _param{param} { } @@ -40,10 +38,10 @@ std::string BinaryArithmetic::name() const { using ArithmeticType = onert::ir::operation::BinaryArithmetic::ArithmeticType; static const std::unordered_map<ArithmeticType, std::string> name_map{ - {ArithmeticType::ADD, std::string{"Add"}}, - {ArithmeticType::SUB, std::string{"Sub"}}, - {ArithmeticType::MUL, std::string{"Mul"}}, - {ArithmeticType::DIV, std::string{"Div"}}}; + {ArithmeticType::ADD, std::string{"Add"}}, + {ArithmeticType::SUB, std::string{"Sub"}}, + {ArithmeticType::MUL, std::string{"Mul"}}, + {ArithmeticType::DIV, std::string{"Div"}}}; return name_map.at(_param.arithmetic_type); } diff --git a/runtime/onert/core/src/ir/operation/BroadcastTo.cc b/runtime/onert/core/src/ir/operation/BroadcastTo.cc index a8f5e59cf..eab6c0611 100644 --- a/runtime/onert/core/src/ir/operation/BroadcastTo.cc +++ b/runtime/onert/core/src/ir/operation/BroadcastTo.cc @@ -15,9 +15,6 @@ */ #include "ir/operation/BroadcastTo.h" - -#include <cassert> - #include "ir/OperationVisitor.h" namespace onert @@ -29,7 +26,7 @@ namespace operation void BroadcastTo::accept(OperationVisitor &v) const { v.visit(*this); } BroadcastTo::BroadcastTo(const OperandIndexSequence &inputs, const OperandIndexSequence &outputs) - : Operation{OperandConstraint::createExact(2u), inputs, outputs} + : Operation{OperandConstraint::createExact(2u), inputs, outputs} { } diff --git 
a/runtime/onert/core/src/ir/operation/Bulk.cc b/runtime/onert/core/src/ir/operation/Bulk.cc new file mode 100644 index 000000000..4b96c9d94 --- /dev/null +++ b/runtime/onert/core/src/ir/operation/Bulk.cc @@ -0,0 +1,36 @@ +/* + * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ir/operation/Bulk.h" +#include "ir/OperationVisitor.h" + +namespace onert +{ +namespace ir +{ +namespace operation +{ +void Bulk::accept(OperationVisitor &v) const { v.visit(*this); } + +Bulk::Bulk(const OperandIndexSequence &inputs, const OperandIndexSequence &outputs, + const Bulk::Param ¶m) + : Operation{OperandConstraint::createAny(), inputs, outputs}, _param{param} +{ +} + +} // namespace operation +} // namespace ir +} // namespace onert diff --git a/runtime/onert/core/src/ir/operation/Comparison.cc b/runtime/onert/core/src/ir/operation/Comparison.cc index 2f6775411..33365657c 100644 --- a/runtime/onert/core/src/ir/operation/Comparison.cc +++ b/runtime/onert/core/src/ir/operation/Comparison.cc @@ -15,9 +15,6 @@ */ #include "ir/operation/Comparison.h" - -#include <cassert> - #include "ir/OperationVisitor.h" namespace onert @@ -31,7 +28,7 @@ void Comparison::accept(OperationVisitor &v) const { v.visit(*this); } Comparison::Comparison(const OperandIndexSequence &inputs, const OperandIndexSequence &outputs, const Param ¶m) - : Operation{OperandConstraint::createExact(2u), inputs, outputs}, 
_param{param} + : Operation{OperandConstraint::createExact(2u), inputs, outputs}, _param{param} { } diff --git a/runtime/onert/core/src/ir/operation/Concat.cc b/runtime/onert/core/src/ir/operation/Concat.cc index 608bc29a6..3a21e36f2 100644 --- a/runtime/onert/core/src/ir/operation/Concat.cc +++ b/runtime/onert/core/src/ir/operation/Concat.cc @@ -15,9 +15,6 @@ */ #include "ir/operation/Concat.h" - -#include <cassert> - #include "ir/OperationVisitor.h" namespace onert @@ -31,7 +28,7 @@ void Concat::accept(OperationVisitor &v) const { v.visit(*this); } Concat::Concat(const OperandIndexSequence &inputs, const OperandIndexSequence &outputs, const Param ¶m) - : Operation{OperandConstraint::createAtLeast(1u), inputs, outputs}, _param{param} + : Operation{OperandConstraint::createAtLeast(1u), inputs, outputs}, _param{param} { } diff --git a/runtime/onert/core/src/ir/operation/Conv2D.cc b/runtime/onert/core/src/ir/operation/Conv2D.cc index 3a2e1d1fe..d615ae416 100644 --- a/runtime/onert/core/src/ir/operation/Conv2D.cc +++ b/runtime/onert/core/src/ir/operation/Conv2D.cc @@ -15,9 +15,6 @@ */ #include "ir/operation/Conv2D.h" - -#include <cassert> - #include "ir/OperationVisitor.h" namespace onert @@ -31,7 +28,7 @@ void Conv2D::accept(OperationVisitor &v) const { v.visit(*this); } Conv2D::Conv2D(const OperandIndexSequence &inputs, const OperandIndexSequence &outputs, const Param ¶m) - : Operation{OperandConstraint::createExact(3u), inputs, outputs}, _param{param} + : Operation{OperandConstraint::createExact(3u), inputs, outputs}, _param{param} { } diff --git a/runtime/onert/core/src/ir/operation/ConvertFp16ToFp32.cc b/runtime/onert/core/src/ir/operation/ConvertFp16ToFp32.cc index 676e039fa..365745ea8 100644 --- a/runtime/onert/core/src/ir/operation/ConvertFp16ToFp32.cc +++ b/runtime/onert/core/src/ir/operation/ConvertFp16ToFp32.cc @@ -15,9 +15,6 @@ */ #include "ir/operation/ConvertFp16ToFp32.h" - -#include <cassert> - #include "ir/OperationVisitor.h" namespace onert @@ -31,7 
+28,7 @@ void ConvertFp16ToFp32::accept(OperationVisitor &v) const { v.visit(*this); } ConvertFp16ToFp32::ConvertFp16ToFp32(const OperandIndexSequence &inputs, const OperandIndexSequence &outputs) - : Operation{OperandConstraint::createExact(1u), inputs, outputs} + : Operation{OperandConstraint::createExact(1u), inputs, outputs} { } diff --git a/runtime/onert/core/src/ir/operation/ConvertFp32ToFp16.cc b/runtime/onert/core/src/ir/operation/ConvertFp32ToFp16.cc index bcfcbfc04..d4fc7031c 100644 --- a/runtime/onert/core/src/ir/operation/ConvertFp32ToFp16.cc +++ b/runtime/onert/core/src/ir/operation/ConvertFp32ToFp16.cc @@ -15,9 +15,6 @@ */ #include "ir/operation/ConvertFp32ToFp16.h" - -#include <cassert> - #include "ir/OperationVisitor.h" namespace onert @@ -31,7 +28,7 @@ void ConvertFp32ToFp16::accept(OperationVisitor &v) const { v.visit(*this); } ConvertFp32ToFp16::ConvertFp32ToFp16(const OperandIndexSequence &inputs, const OperandIndexSequence &outputs) - : Operation{OperandConstraint::createExact(1u), inputs, outputs} + : Operation{OperandConstraint::createExact(1u), inputs, outputs} { } diff --git a/runtime/onert/core/src/ir/operation/Custom.cc b/runtime/onert/core/src/ir/operation/Custom.cc index 25c53e1ba..06c84f81a 100644 --- a/runtime/onert/core/src/ir/operation/Custom.cc +++ b/runtime/onert/core/src/ir/operation/Custom.cc @@ -29,7 +29,7 @@ void Custom::accept(OperationVisitor &v) const { v.visit(*this); } Custom::Custom(OperandConstraint input_constr, const OperandIndexSequence &inputs, const OperandIndexSequence &outputs, std::string id, const Userdata &userdata) - : Operation{input_constr, inputs, outputs}, _id(std::move(id)), _userdata(userdata) + : Operation{input_constr, inputs, outputs}, _id(std::move(id)), _userdata(userdata) { } diff --git a/runtime/onert/core/src/ir/operation/DepthToSpace.cc b/runtime/onert/core/src/ir/operation/DepthToSpace.cc index f2d6c7c1b..e3edea777 100644 --- a/runtime/onert/core/src/ir/operation/DepthToSpace.cc +++ 
b/runtime/onert/core/src/ir/operation/DepthToSpace.cc @@ -15,9 +15,6 @@ */ #include "ir/operation/DepthToSpace.h" - -#include <cassert> - #include "ir/OperationVisitor.h" namespace onert @@ -31,7 +28,7 @@ void DepthToSpace::accept(OperationVisitor &v) const { v.visit(*this); } DepthToSpace::DepthToSpace(const OperandIndexSequence &inputs, const OperandIndexSequence &outputs, const Param ¶m) - : Operation{OperandConstraint::createExact(1u), inputs, outputs}, _param{param} + : Operation{OperandConstraint::createExact(1u), inputs, outputs}, _param{param} { } diff --git a/runtime/onert/core/src/ir/operation/DepthwiseConv2D.cc b/runtime/onert/core/src/ir/operation/DepthwiseConv2D.cc index d587a5591..0e7137306 100644 --- a/runtime/onert/core/src/ir/operation/DepthwiseConv2D.cc +++ b/runtime/onert/core/src/ir/operation/DepthwiseConv2D.cc @@ -15,9 +15,6 @@ */ #include "ir/operation/DepthwiseConv2D.h" - -#include <cassert> - #include "ir/OperationVisitor.h" namespace onert @@ -31,7 +28,7 @@ void DepthwiseConv2D::accept(OperationVisitor &v) const { v.visit(*this); } DepthwiseConv2D::DepthwiseConv2D(const OperandIndexSequence &inputs, const OperandIndexSequence &outputs, const Param ¶m) - : Operation{OperandConstraint::createExact(3u), inputs, outputs}, _param{param} + : Operation{OperandConstraint::createExact(3u), inputs, outputs}, _param{param} { } diff --git a/runtime/onert/core/src/ir/operation/DetectionPostProcess.cc b/runtime/onert/core/src/ir/operation/DetectionPostProcess.cc new file mode 100644 index 000000000..cd708796d --- /dev/null +++ b/runtime/onert/core/src/ir/operation/DetectionPostProcess.cc @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ir/operation/DetectionPostProcess.h" +#include "ir/OperationVisitor.h" + +namespace onert +{ +namespace ir +{ +namespace operation +{ + +DetectionPostProcess::DetectionPostProcess(const OperandIndexSequence &inputs, + const OperandIndexSequence &outputs, const Param ¶m) + : Operation(OperandConstraint::createExact(3u), inputs, outputs), _param(param) +{ +} + +void DetectionPostProcess::accept(OperationVisitor &v) const { v.visit(*this); } + +} // namespace operation +} // namespace ir +} // namespace onert diff --git a/runtime/onert/core/src/ir/operation/Einsum.cc b/runtime/onert/core/src/ir/operation/Einsum.cc index 3c1473aaa..b50f070e7 100644 --- a/runtime/onert/core/src/ir/operation/Einsum.cc +++ b/runtime/onert/core/src/ir/operation/Einsum.cc @@ -28,7 +28,7 @@ void Einsum::accept(OperationVisitor &v) const { v.visit(*this); } Einsum::Einsum(const OperandIndexSequence &inputs, const OperandIndexSequence &outputs, const Param ¶m) - : Operation{OperandConstraint::createAtLeast(1u), inputs, outputs}, _param{param} + : Operation{OperandConstraint::createAtLeast(1u), inputs, outputs}, _param{param} { } diff --git a/runtime/onert/core/src/ir/operation/ElementwiseActivation.cc b/runtime/onert/core/src/ir/operation/ElementwiseActivation.cc index f6718b656..e83c26e28 100644 --- a/runtime/onert/core/src/ir/operation/ElementwiseActivation.cc +++ b/runtime/onert/core/src/ir/operation/ElementwiseActivation.cc @@ -15,12 +15,10 @@ */ #include "ir/operation/ElementwiseActivation.h" +#include "ir/OperationVisitor.h" -#include <cassert> #include 
<unordered_map> -#include "ir/OperationVisitor.h" - namespace onert { namespace ir @@ -33,13 +31,14 @@ void ElementwiseActivation::accept(OperationVisitor &v) const { v.visit(*this); ElementwiseActivation::ElementwiseActivation(const OperandIndexSequence &inputs, const OperandIndexSequence &outputs, const Param ¶m) - : Operation{OperandConstraint::createExact(1u), inputs, outputs}, _param{param} + : Operation{OperandConstraint::createExact(1u), inputs, outputs}, _param{param} { if (param.op_type == Type::LOGISTIC) { - assert(param.alpha == 0.0f && param.beta == 0.0f && "Logistic will be supported only as " - "sigmoid function(L=1, k=1, x0=0). So, do " - "not use alpha and beta"); + assert(param.alpha == 0.0f && param.beta == 0.0f && + "Logistic will be supported only as " + "sigmoid function(L=1, k=1, x0=0). So, do " + "not use alpha and beta"); } else if (param.op_type == Type::RELU) { @@ -47,9 +46,10 @@ ElementwiseActivation::ElementwiseActivation(const OperandIndexSequence &inputs, } else if (param.op_type == Type::TANH) { - assert(param.alpha == 1.0f && param.beta == 1.0f && "f(x) = alpha * tanh(beta * x), Tanh is " - "supported only the values of alpha and " - "beta are 1.f"); + assert(param.alpha == 1.0f && param.beta == 1.0f && + "f(x) = alpha * tanh(beta * x), Tanh is " + "supported only the values of alpha and " + "beta are 1.f"); } } @@ -57,11 +57,11 @@ std::string ElementwiseActivation::name() const { using ElementwiseActivationType = onert::ir::operation::ElementwiseActivation::Type; static const std::unordered_map<Type, std::string> name_map{ - {ElementwiseActivationType::ELU, "ELU"}, - {ElementwiseActivationType::LOGISTIC, "Logistic"}, - {ElementwiseActivationType::RELU, "ReLU"}, - {ElementwiseActivationType::TANH, "Tanh"}, - {ElementwiseActivationType::LEAKY_RELU, "LeakyRelu"}}; + {ElementwiseActivationType::ELU, "ELU"}, + {ElementwiseActivationType::LOGISTIC, "Logistic"}, + {ElementwiseActivationType::RELU, "ReLU"}, + 
{ElementwiseActivationType::TANH, "Tanh"}, + {ElementwiseActivationType::LEAKY_RELU, "LeakyRelu"}}; return name_map.at(_param.op_type); } diff --git a/runtime/onert/core/src/ir/operation/ElementwiseBinary.cc b/runtime/onert/core/src/ir/operation/ElementwiseBinary.cc index 3287fc0a3..d445171fb 100644 --- a/runtime/onert/core/src/ir/operation/ElementwiseBinary.cc +++ b/runtime/onert/core/src/ir/operation/ElementwiseBinary.cc @@ -15,12 +15,10 @@ */ #include "ir/operation/ElementwiseBinary.h" +#include "ir/OperationVisitor.h" -#include <cassert> #include <unordered_map> -#include "ir/OperationVisitor.h" - namespace onert { namespace ir @@ -32,7 +30,7 @@ void ElementwiseBinary::accept(OperationVisitor &v) const { v.visit(*this); } ElementwiseBinary::ElementwiseBinary(const OperandIndexSequence &inputs, const OperandIndexSequence &outputs, const Param ¶m) - : Operation{OperandConstraint::createExact(2u), inputs, outputs}, _param{param} + : Operation{OperandConstraint::createExact(2u), inputs, outputs}, _param{param} { } @@ -40,10 +38,12 @@ std::string ElementwiseBinary::name() const { using ElementwiseBinaryType = onert::ir::operation::ElementwiseBinary::ElementwiseBinaryType; static const std::unordered_map<ElementwiseBinaryType, std::string> name_map{ - {ElementwiseBinaryType::LOGICAL_AND, std::string{"LogicalAnd"}}, - {ElementwiseBinaryType::LOGICAL_OR, std::string{"LogicalOr"}}, - {ElementwiseBinaryType::MAX, std::string{"Max"}}, - {ElementwiseBinaryType::MIN, std::string{"Min"}}}; + {ElementwiseBinaryType::FLOOR_DIV, std::string{"FloorDiv"}}, + {ElementwiseBinaryType::FLOOR_MOD, std::string{"FloorMod"}}, + {ElementwiseBinaryType::LOGICAL_AND, std::string{"LogicalAnd"}}, + {ElementwiseBinaryType::LOGICAL_OR, std::string{"LogicalOr"}}, + {ElementwiseBinaryType::MAX, std::string{"Max"}}, + {ElementwiseBinaryType::MIN, std::string{"Min"}}}; return name_map.at(_param.op_type); } diff --git a/runtime/onert/core/src/ir/operation/ElementwiseUnary.cc 
b/runtime/onert/core/src/ir/operation/ElementwiseUnary.cc index 7dfcd4a98..fd463e0fe 100644 --- a/runtime/onert/core/src/ir/operation/ElementwiseUnary.cc +++ b/runtime/onert/core/src/ir/operation/ElementwiseUnary.cc @@ -15,12 +15,10 @@ */ #include "ir/operation/ElementwiseUnary.h" +#include "ir/OperationVisitor.h" -#include <cassert> #include <unordered_map> -#include "ir/OperationVisitor.h" - namespace onert { namespace ir @@ -32,7 +30,9 @@ void ElementwiseUnary::accept(OperationVisitor &v) const { v.visit(*this); } ElementwiseUnary::ElementwiseUnary(const OperandIndexSequence &inputs, const OperandIndexSequence &outputs, const Param ¶m) - : Operation{OperandConstraint::createExact(1u), inputs, outputs}, _param{param} + : Operation{OperandConstraint::createExact(1u), inputs, outputs, + OperandConstraint::createExact(1u)}, + _param{param} { } @@ -40,23 +40,23 @@ std::string ElementwiseUnary::name() const { using ElementwiseUnaryType = onert::ir::operation::ElementwiseUnary::Type; static const std::unordered_map<ElementwiseUnaryType, std::string> name_map{ - {ElementwiseUnaryType::ABS, std::string{"Abs"}}, - {ElementwiseUnaryType::CAST, std::string{"Cast"}}, - {ElementwiseUnaryType::COS, std::string{"Cos"}}, - {ElementwiseUnaryType::DEQUANTIZE, std::string{"Dequantize"}}, - {ElementwiseUnaryType::ERF, std::string{"Erf"}}, - {ElementwiseUnaryType::EXP, std::string{"Exp"}}, - {ElementwiseUnaryType::FLOOR, std::string{"Floor"}}, - {ElementwiseUnaryType::LOG, std::string{"Log"}}, - {ElementwiseUnaryType::LOGICAL_NOT, std::string{"LogicalNot"}}, - {ElementwiseUnaryType::NEG, std::string{"Neg"}}, - {ElementwiseUnaryType::QUANTIZE, std::string{"Quantize"}}, - {ElementwiseUnaryType::ROUND, std::string{"Round"}}, - {ElementwiseUnaryType::RSQRT, std::string{"RSqrt"}}, - {ElementwiseUnaryType::SIN, std::string{"Sin"}}, - {ElementwiseUnaryType::SQRT, std::string{"Sqrt"}}, - {ElementwiseUnaryType::SQURE, std::string{"Squre"}}, - {ElementwiseUnaryType::ZEROS_LIKE, 
std::string{"ZerosLike"}}}; + {ElementwiseUnaryType::ABS, std::string{"Abs"}}, + {ElementwiseUnaryType::CAST, std::string{"Cast"}}, + {ElementwiseUnaryType::COS, std::string{"Cos"}}, + {ElementwiseUnaryType::DEQUANTIZE, std::string{"Dequantize"}}, + {ElementwiseUnaryType::ERF, std::string{"Erf"}}, + {ElementwiseUnaryType::EXP, std::string{"Exp"}}, + {ElementwiseUnaryType::FLOOR, std::string{"Floor"}}, + {ElementwiseUnaryType::LOG, std::string{"Log"}}, + {ElementwiseUnaryType::LOGICAL_NOT, std::string{"LogicalNot"}}, + {ElementwiseUnaryType::NEG, std::string{"Neg"}}, + {ElementwiseUnaryType::QUANTIZE, std::string{"Quantize"}}, + {ElementwiseUnaryType::ROUND, std::string{"Round"}}, + {ElementwiseUnaryType::RSQRT, std::string{"RSqrt"}}, + {ElementwiseUnaryType::SIN, std::string{"Sin"}}, + {ElementwiseUnaryType::SQRT, std::string{"Sqrt"}}, + {ElementwiseUnaryType::SQUARE, std::string{"Square"}}, + {ElementwiseUnaryType::ZEROS_LIKE, std::string{"ZerosLike"}}}; return name_map.at(_param.op_type); } diff --git a/runtime/onert/core/src/ir/operation/EmbeddingLookup.cc b/runtime/onert/core/src/ir/operation/EmbeddingLookup.cc index b300b004e..66b80b2c5 100644 --- a/runtime/onert/core/src/ir/operation/EmbeddingLookup.cc +++ b/runtime/onert/core/src/ir/operation/EmbeddingLookup.cc @@ -15,9 +15,6 @@ */ #include "ir/operation/EmbeddingLookup.h" - -#include <cassert> - #include "ir/OperationVisitor.h" namespace onert @@ -31,7 +28,7 @@ void EmbeddingLookup::accept(OperationVisitor &v) const { v.visit(*this); } EmbeddingLookup::EmbeddingLookup(const OperandIndexSequence &inputs, const OperandIndexSequence &outputs) - : Operation{OperandConstraint::createExact(2u), inputs, outputs} + : Operation{OperandConstraint::createExact(2u), inputs, outputs} { } diff --git a/runtime/onert/core/src/ir/operation/ExpandDims.cc b/runtime/onert/core/src/ir/operation/ExpandDims.cc index 3f555bd23..e421bc383 100644 --- a/runtime/onert/core/src/ir/operation/ExpandDims.cc +++ 
b/runtime/onert/core/src/ir/operation/ExpandDims.cc @@ -15,9 +15,6 @@ */ #include "ir/operation/ExpandDims.h" - -#include <cassert> - #include "ir/OperationVisitor.h" namespace onert @@ -30,7 +27,7 @@ namespace operation void ExpandDims::accept(OperationVisitor &v) const { v.visit(*this); } ExpandDims::ExpandDims(const OperandIndexSequence &inputs, const OperandIndexSequence &outputs) - : Operation{OperandConstraint::createExact(2u), inputs, outputs} + : Operation{OperandConstraint::createExact(2u), inputs, outputs} { } diff --git a/runtime/onert/core/src/ir/operation/Fill.cc b/runtime/onert/core/src/ir/operation/Fill.cc index c44f45aab..60355c609 100644 --- a/runtime/onert/core/src/ir/operation/Fill.cc +++ b/runtime/onert/core/src/ir/operation/Fill.cc @@ -15,9 +15,6 @@ */ #include "ir/operation/Fill.h" - -#include <cassert> - #include "ir/OperationVisitor.h" namespace onert @@ -30,7 +27,7 @@ namespace operation void Fill::accept(OperationVisitor &v) const { v.visit(*this); } Fill::Fill(const OperandIndexSequence &inputs, const OperandIndexSequence &outputs) - : Operation{OperandConstraint::createExact(1u), inputs, outputs} + : Operation{OperandConstraint::createExact(2u), inputs, outputs} { } diff --git a/runtime/onert/core/src/ir/operation/FullyConnected.cc b/runtime/onert/core/src/ir/operation/FullyConnected.cc index 118ae554a..3533df097 100644 --- a/runtime/onert/core/src/ir/operation/FullyConnected.cc +++ b/runtime/onert/core/src/ir/operation/FullyConnected.cc @@ -15,9 +15,6 @@ */ #include "ir/operation/FullyConnected.h" - -#include <cassert> - #include "ir/OperationVisitor.h" namespace onert @@ -31,7 +28,7 @@ void FullyConnected::accept(OperationVisitor &v) const { v.visit(*this); } FullyConnected::FullyConnected(const OperandIndexSequence &inputs, const OperandIndexSequence &outputs, const Param ¶m) - : Operation{OperandConstraint::createExact(3u), inputs, outputs}, _param{param} + : Operation{OperandConstraint::createInRange(2u, 3u), inputs, outputs}, 
_param{param} { } diff --git a/runtime/onert/core/src/ir/operation/FusedBatchNorm.cc b/runtime/onert/core/src/ir/operation/FusedBatchNorm.cc index 7b9301ea6..b5679f308 100644 --- a/runtime/onert/core/src/ir/operation/FusedBatchNorm.cc +++ b/runtime/onert/core/src/ir/operation/FusedBatchNorm.cc @@ -28,7 +28,7 @@ void FusedBatchNorm::accept(OperationVisitor &v) const { v.visit(*this); } FusedBatchNorm::FusedBatchNorm(const OperandIndexSequence &inputs, const OperandIndexSequence &outputs, const Param ¶m) - : Operation{OperandConstraint::createAtLeast(5u), inputs, outputs}, _param{param} + : Operation{OperandConstraint::createAtLeast(5u), inputs, outputs}, _param{param} { } diff --git a/runtime/onert/core/src/ir/operation/Gather.cc b/runtime/onert/core/src/ir/operation/Gather.cc index 11d46e75b..e0c4630a0 100644 --- a/runtime/onert/core/src/ir/operation/Gather.cc +++ b/runtime/onert/core/src/ir/operation/Gather.cc @@ -15,9 +15,6 @@ */ #include "ir/operation/Gather.h" - -#include <cassert> - #include "ir/OperationVisitor.h" namespace onert @@ -31,7 +28,7 @@ void Gather::accept(OperationVisitor &v) const { v.visit(*this); } Gather::Gather(const OperandIndexSequence &inputs, const OperandIndexSequence &outputs, const Param ¶m) - : Operation{OperandConstraint::createExact(2u), inputs, outputs}, _param{param} + : Operation{OperandConstraint::createExact(2u), inputs, outputs}, _param{param} { } diff --git a/runtime/onert/core/src/ir/operation/HashtableLookup.cc b/runtime/onert/core/src/ir/operation/HashtableLookup.cc index e9a7a82ff..5d1589cd1 100644 --- a/runtime/onert/core/src/ir/operation/HashtableLookup.cc +++ b/runtime/onert/core/src/ir/operation/HashtableLookup.cc @@ -15,9 +15,6 @@ */ #include "ir/operation/HashtableLookup.h" - -#include <cassert> - #include "ir/OperationVisitor.h" namespace onert @@ -31,7 +28,7 @@ void HashtableLookup::accept(OperationVisitor &v) const { v.visit(*this); } HashtableLookup::HashtableLookup(const OperandIndexSequence &inputs, const 
OperandIndexSequence &outputs) - : Operation{OperandConstraint::createExact(3u), inputs, outputs} + : Operation{OperandConstraint::createExact(3u), inputs, outputs} { } diff --git a/runtime/onert/core/src/ir/operation/If.cc b/runtime/onert/core/src/ir/operation/If.cc index 599751dfd..380c87dbe 100644 --- a/runtime/onert/core/src/ir/operation/If.cc +++ b/runtime/onert/core/src/ir/operation/If.cc @@ -24,7 +24,7 @@ namespace operation { void If::accept(OperationVisitor &v) const { v.visit(*this); } If::If(const OperandIndexSequence &inputs, const OperandIndexSequence &outputs, const Param ¶m) - : Operation{OperandConstraint::createAny(), inputs, outputs}, _param{param} + : Operation{OperandConstraint::createAny(), inputs, outputs}, _param{param} { } } // namespace operation diff --git a/runtime/onert/core/src/ir/operation/InstanceNorm.cc b/runtime/onert/core/src/ir/operation/InstanceNorm.cc index 2334560ef..9fb55383e 100644 --- a/runtime/onert/core/src/ir/operation/InstanceNorm.cc +++ b/runtime/onert/core/src/ir/operation/InstanceNorm.cc @@ -15,9 +15,6 @@ */ #include "ir/operation/InstanceNorm.h" - -#include <cassert> - #include "ir/OperationVisitor.h" namespace onert @@ -31,7 +28,7 @@ void InstanceNorm::accept(OperationVisitor &v) const { v.visit(*this); } InstanceNorm::InstanceNorm(const OperandIndexSequence &inputs, const OperandIndexSequence &outputs, const Param ¶m) - : Operation{OperandConstraint::createExact(3u), inputs, outputs}, _param{param} + : Operation{OperandConstraint::createExact(3u), inputs, outputs}, _param{param} { } diff --git a/runtime/onert/core/src/ir/operation/L2Normalization.cc b/runtime/onert/core/src/ir/operation/L2Normalization.cc index 9a7d3eb61..6725df596 100644 --- a/runtime/onert/core/src/ir/operation/L2Normalization.cc +++ b/runtime/onert/core/src/ir/operation/L2Normalization.cc @@ -15,9 +15,6 @@ */ #include "ir/operation/L2Normalization.h" - -#include <cassert> - #include "ir/OperationVisitor.h" namespace onert @@ -31,7 +28,7 @@ void 
L2Normalization::accept(OperationVisitor &v) const { v.visit(*this); } L2Normalization::L2Normalization(const OperandIndexSequence &inputs, const OperandIndexSequence &outputs) - : Operation{OperandConstraint::createExact(1u), inputs, outputs} + : Operation{OperandConstraint::createExact(1u), inputs, outputs} { } diff --git a/runtime/onert/core/src/ir/operation/LSTM.cc b/runtime/onert/core/src/ir/operation/LSTM.cc index 30a865326..06e66158b 100644 --- a/runtime/onert/core/src/ir/operation/LSTM.cc +++ b/runtime/onert/core/src/ir/operation/LSTM.cc @@ -15,9 +15,6 @@ */ #include "ir/operation/LSTM.h" - -#include <cassert> - #include "ir/OperationVisitor.h" namespace onert @@ -31,8 +28,16 @@ void LSTM::accept(OperationVisitor &v) const { v.visit(*this); } LSTM::LSTM(const OperandIndexSequence &inputs, const OperandIndexSequence &outputs, const Param ¶m) - : Operation{OperandConstraint::createExact(23u), inputs, outputs}, _param{param} + : Operation{OperandConstraint::createInRange(20u, 24u), inputs, outputs}, _param{param} +{ +} + +std::string LSTM::name() const { + if (getOutputs().at(Output::SCRATCH_BUFFER).undefined()) + return std::string{"UnidirectionalSequenceLSTM"}; + else + return Operation::name(); } } // namespace operation diff --git a/runtime/onert/core/src/ir/operation/LocalResponseNormalization.cc b/runtime/onert/core/src/ir/operation/LocalResponseNormalization.cc index 1ae97c142..73fca9938 100644 --- a/runtime/onert/core/src/ir/operation/LocalResponseNormalization.cc +++ b/runtime/onert/core/src/ir/operation/LocalResponseNormalization.cc @@ -15,9 +15,6 @@ */ #include "ir/operation/LocalResponseNormalization.h" - -#include <cassert> - #include "ir/OperationVisitor.h" namespace onert @@ -32,7 +29,7 @@ void LocalResponseNormalization::accept(OperationVisitor &v) const { v.visit(*th LocalResponseNormalization::LocalResponseNormalization(const OperandIndexSequence &inputs, const OperandIndexSequence &outputs, const Param ¶m) - : 
Operation{OperandConstraint::createExact(1u), inputs, outputs}, _param{param} + : Operation{OperandConstraint::createExact(1u), inputs, outputs}, _param{param} { } diff --git a/runtime/onert/core/src/ir/operation/LogSoftmax.cc b/runtime/onert/core/src/ir/operation/LogSoftmax.cc index 73c6580ec..d580e63e1 100644 --- a/runtime/onert/core/src/ir/operation/LogSoftmax.cc +++ b/runtime/onert/core/src/ir/operation/LogSoftmax.cc @@ -15,9 +15,6 @@ */ #include "ir/operation/LogSoftmax.h" - -#include <cassert> - #include "ir/OperationVisitor.h" namespace onert @@ -31,7 +28,7 @@ void LogSoftmax::accept(OperationVisitor &v) const { v.visit(*this); } LogSoftmax::LogSoftmax(const OperandIndexSequence &inputs, const OperandIndexSequence &outputs, const Param ¶m) - : Operation{OperandConstraint::createExact(1u), inputs, outputs}, _param{param} + : Operation{OperandConstraint::createExact(1u), inputs, outputs}, _param{param} { } diff --git a/runtime/onert/core/src/ir/operation/Loss.cc b/runtime/onert/core/src/ir/operation/Loss.cc new file mode 100644 index 000000000..2a0d6c4c8 --- /dev/null +++ b/runtime/onert/core/src/ir/operation/Loss.cc @@ -0,0 +1,39 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "ir/operation/Loss.h" +#include "ir/OperationVisitor.h" + +#include <unordered_map> + +namespace onert +{ +namespace ir +{ +namespace operation +{ + +void Loss::accept(OperationVisitor &v) const { v.visit(*this); } + +Loss::Loss(const OperandIndexSequence &inputs, const OperandIndexSequence &outputs) + : Operation{OperandConstraint::createAtLeast(2u), inputs, outputs} +{ + assert(inputs.size() == 2); +} + +} // namespace operation +} // namespace ir +} // namespace onert diff --git a/runtime/onert/core/src/ir/operation/MatrixBandPart.cc b/runtime/onert/core/src/ir/operation/MatrixBandPart.cc index bac31f13e..e52bddc1f 100644 --- a/runtime/onert/core/src/ir/operation/MatrixBandPart.cc +++ b/runtime/onert/core/src/ir/operation/MatrixBandPart.cc @@ -15,9 +15,6 @@ */ #include "ir/operation/MatrixBandPart.h" - -#include <cassert> - #include "ir/OperationVisitor.h" namespace onert @@ -31,7 +28,7 @@ void MatrixBandPart::accept(OperationVisitor &v) const { v.visit(*this); } MatrixBandPart::MatrixBandPart(const OperandIndexSequence &inputs, const OperandIndexSequence &outputs) - : Operation{OperandConstraint::createExact(3u), inputs, outputs} + : Operation{OperandConstraint::createExact(3u), inputs, outputs} { } diff --git a/runtime/onert/core/src/ir/operation/OneHot.cc b/runtime/onert/core/src/ir/operation/OneHot.cc index 22935e7d6..90898f1ed 100644 --- a/runtime/onert/core/src/ir/operation/OneHot.cc +++ b/runtime/onert/core/src/ir/operation/OneHot.cc @@ -28,7 +28,7 @@ void OneHot::accept(OperationVisitor &v) const { v.visit(*this); } OneHot::OneHot(const OperandIndexSequence &inputs, const OperandIndexSequence &outputs, const Param ¶m) - : Operation{OperandConstraint::createExact(4u), inputs, outputs}, _param{param} + : Operation{OperandConstraint::createExact(4u), inputs, outputs}, _param{param} { } diff --git a/runtime/onert/core/src/ir/operation/PReLU.cc b/runtime/onert/core/src/ir/operation/PReLU.cc index a2e37e0ad..87bd12e60 100644 --- 
a/runtime/onert/core/src/ir/operation/PReLU.cc +++ b/runtime/onert/core/src/ir/operation/PReLU.cc @@ -15,9 +15,6 @@ */ #include "ir/operation/PReLU.h" - -#include <cassert> - #include "ir/OperationVisitor.h" namespace onert @@ -30,7 +27,7 @@ namespace operation void PReLU::accept(OperationVisitor &v) const { v.visit(*this); } PReLU::PReLU(const OperandIndexSequence &inputs, const OperandIndexSequence &outputs) - : Operation{OperandConstraint::createExact(2u), inputs, outputs} + : Operation{OperandConstraint::createExact(2u), inputs, outputs} { } diff --git a/runtime/onert/core/src/ir/operation/Pack.cc b/runtime/onert/core/src/ir/operation/Pack.cc index f0908a2c6..00feadfb0 100644 --- a/runtime/onert/core/src/ir/operation/Pack.cc +++ b/runtime/onert/core/src/ir/operation/Pack.cc @@ -25,7 +25,7 @@ namespace operation void Pack::accept(OperationVisitor &v) const { v.visit(*this); } Pack::Pack(const OperandIndexSequence &inputs, const OperandIndexSequence &outputs, const Param ¶m) - : Operation{OperandConstraint::createAtLeast(3u), inputs, outputs}, _param{param} + : Operation{OperandConstraint::createAtLeast(1u), inputs, outputs}, _param{param} { } } // namespace operation diff --git a/runtime/onert/core/src/ir/operation/Pad.cc b/runtime/onert/core/src/ir/operation/Pad.cc index 0c56e92e3..a3f2d9752 100644 --- a/runtime/onert/core/src/ir/operation/Pad.cc +++ b/runtime/onert/core/src/ir/operation/Pad.cc @@ -30,7 +30,7 @@ void Pad::accept(OperationVisitor &v) const { v.visit(*this); } // PAD: 2 inputs // PADV2: 3 inputs Pad::Pad(const OperandIndexSequence &inputs, const OperandIndexSequence &outputs) - : Operation{OperandConstraint::createInRange(2u, 3u), inputs, outputs} + : Operation{OperandConstraint::createInRange(2u, 3u), inputs, outputs} { } diff --git a/runtime/onert/core/src/ir/operation/Permute.cc b/runtime/onert/core/src/ir/operation/Permute.cc index eefb6c542..813fbaf30 100644 --- a/runtime/onert/core/src/ir/operation/Permute.cc +++ 
b/runtime/onert/core/src/ir/operation/Permute.cc @@ -15,9 +15,6 @@ */ #include "ir/operation/Permute.h" - -#include <cassert> - #include "ir/OperationVisitor.h" namespace onert @@ -30,7 +27,7 @@ namespace operation void Permute::accept(OperationVisitor &v) const { v.visit(*this); } Permute::Permute(const OperandIndex &input, const OperandIndex &output, Type type) - : Operation{OperandConstraint::createExact(1u)}, _type{type} + : Operation{OperandConstraint::createExact(1u)}, _type{type} { setInputs({input}); setOutputs({output}); diff --git a/runtime/onert/core/src/ir/operation/Pool2D.cc b/runtime/onert/core/src/ir/operation/Pool2D.cc index 761d14c3d..e32b876e6 100644 --- a/runtime/onert/core/src/ir/operation/Pool2D.cc +++ b/runtime/onert/core/src/ir/operation/Pool2D.cc @@ -15,12 +15,10 @@ */ #include "ir/operation/Pool2D.h" +#include "ir/OperationVisitor.h" -#include <cassert> #include <unordered_map> -#include "ir/OperationVisitor.h" - namespace onert { namespace ir @@ -32,7 +30,7 @@ void Pool2D::accept(OperationVisitor &v) const { v.visit(*this); } Pool2D::Pool2D(const OperandIndexSequence &inputs, const OperandIndexSequence &outputs, const Param ¶m) - : Operation{OperandConstraint::createExact(1u), inputs, outputs}, _param{param} + : Operation{OperandConstraint::createExact(1u), inputs, outputs}, _param{param} { } @@ -40,9 +38,9 @@ std::string Pool2D::name() const { using PoolType = onert::ir::operation::Pool2D::PoolType; static const std::unordered_map<PoolType, std::string> name_map{ - {PoolType::AVG, "Avg" + std::string{toString(opcode())}}, - {PoolType::L2, "L2" + std::string{toString(opcode())}}, - {PoolType::MAX, "Max" + std::string{toString(opcode())}}}; + {PoolType::AVG, "Avg" + std::string{toString(opcode())}}, + {PoolType::L2, "L2" + std::string{toString(opcode())}}, + {PoolType::MAX, "Max" + std::string{toString(opcode())}}}; return name_map.at(_param.op_type); } diff --git a/runtime/onert/core/src/ir/operation/Pow.cc 
b/runtime/onert/core/src/ir/operation/Pow.cc index 940b1391a..f7c159a12 100644 --- a/runtime/onert/core/src/ir/operation/Pow.cc +++ b/runtime/onert/core/src/ir/operation/Pow.cc @@ -15,9 +15,6 @@ */ #include "ir/operation/Pow.h" - -#include <cassert> - #include "ir/OperationVisitor.h" namespace onert @@ -30,7 +27,7 @@ namespace operation void Pow::accept(OperationVisitor &v) const { v.visit(*this); } Pow::Pow(const OperandIndexSequence &inputs, const OperandIndexSequence &outputs) - : Operation{OperandConstraint::createExact(2u), inputs, outputs} + : Operation{OperandConstraint::createExact(2u), inputs, outputs} { } diff --git a/runtime/onert/core/src/ir/operation/RNN.cc b/runtime/onert/core/src/ir/operation/RNN.cc index 298c5e745..988a50669 100644 --- a/runtime/onert/core/src/ir/operation/RNN.cc +++ b/runtime/onert/core/src/ir/operation/RNN.cc @@ -15,9 +15,6 @@ */ #include "ir/operation/RNN.h" - -#include <cassert> - #include "ir/OperationVisitor.h" namespace onert @@ -31,7 +28,7 @@ void RNN::accept(OperationVisitor &v) const { v.visit(*this); } RNN::RNN(const OperandIndexSequence &inputs, const OperandIndexSequence &outputs, const Param ¶m) - : Operation{OperandConstraint::createExact(5u), inputs, outputs}, _param{param} + : Operation{OperandConstraint::createExact(5u), inputs, outputs}, _param{param} { } diff --git a/runtime/onert/core/src/ir/operation/Range.cc b/runtime/onert/core/src/ir/operation/Range.cc index 96ab04c1b..8ced92a0b 100644 --- a/runtime/onert/core/src/ir/operation/Range.cc +++ b/runtime/onert/core/src/ir/operation/Range.cc @@ -15,9 +15,6 @@ */ #include "ir/operation/Range.h" - -#include <cassert> - #include "ir/OperationVisitor.h" namespace onert @@ -30,7 +27,7 @@ namespace operation void Range::accept(OperationVisitor &v) const { v.visit(*this); } Range::Range(const OperandIndexSequence &inputs, const OperandIndexSequence &outputs) - : Operation{OperandConstraint::createExact(3u), inputs, outputs} + : 
Operation{OperandConstraint::createExact(3u), inputs, outputs} { } diff --git a/runtime/onert/core/src/ir/operation/Rank.cc b/runtime/onert/core/src/ir/operation/Rank.cc index c357e9018..40797bf29 100644 --- a/runtime/onert/core/src/ir/operation/Rank.cc +++ b/runtime/onert/core/src/ir/operation/Rank.cc @@ -15,9 +15,6 @@ */ #include "ir/operation/Rank.h" - -#include <cassert> - #include "ir/OperationVisitor.h" namespace onert @@ -30,7 +27,7 @@ namespace operation void Rank::accept(OperationVisitor &v) const { v.visit(*this); } Rank::Rank(const OperandIndexSequence &inputs, const OperandIndexSequence &outputs) - : Operation{OperandConstraint::createExact(1u), inputs, outputs} + : Operation{OperandConstraint::createExact(1u), inputs, outputs} { } diff --git a/runtime/onert/core/src/ir/operation/Reduce.cc b/runtime/onert/core/src/ir/operation/Reduce.cc index d6a1d953c..8da1940fa 100644 --- a/runtime/onert/core/src/ir/operation/Reduce.cc +++ b/runtime/onert/core/src/ir/operation/Reduce.cc @@ -15,12 +15,10 @@ */ #include "ir/operation/Reduce.h" +#include "ir/OperationVisitor.h" -#include <cassert> #include <unordered_map> -#include "ir/OperationVisitor.h" - namespace onert { namespace ir @@ -32,7 +30,7 @@ void Reduce::accept(OperationVisitor &v) const { v.visit(*this); } Reduce::Reduce(const OperandIndexSequence &inputs, const OperandIndexSequence &outputs, const Param ¶m) - : Operation{OperandConstraint::createExact(2u), inputs, outputs}, _param{param} + : Operation{OperandConstraint::createExact(2u), inputs, outputs}, _param{param} { } @@ -40,13 +38,13 @@ std::string Reduce::name() const { using ReduceType = onert::ir::operation::Reduce::ReduceType; static const std::unordered_map<ReduceType, std::string> name_map{ - {ReduceType::ALL, std::string{toString(opcode())} + "All"}, - {ReduceType::ANY, std::string{toString(opcode())} + "Any"}, - {ReduceType::MAX, std::string{toString(opcode())} + "Max"}, - {ReduceType::MEAN, std::string{toString(opcode())} + "Mean"}, - 
{ReduceType::MIN, std::string{toString(opcode())} + "Min"}, - {ReduceType::PROD, std::string{toString(opcode())} + "Prod"}, - {ReduceType::SUM, std::string{toString(opcode())} + "SUM"}}; + {ReduceType::ALL, std::string{toString(opcode())} + "All"}, + {ReduceType::ANY, std::string{toString(opcode())} + "Any"}, + {ReduceType::MAX, std::string{toString(opcode())} + "Max"}, + {ReduceType::MEAN, std::string{toString(opcode())} + "Mean"}, + {ReduceType::MIN, std::string{toString(opcode())} + "Min"}, + {ReduceType::PROD, std::string{toString(opcode())} + "Prod"}, + {ReduceType::SUM, std::string{toString(opcode())} + "SUM"}}; return name_map.at(_param.reduce_type); // return std::string(toString(opcode())) + reduce_type_str_map.at(_param.reduce_type); } diff --git a/runtime/onert/core/src/ir/operation/Reshape.cc b/runtime/onert/core/src/ir/operation/Reshape.cc index 92aa89ac6..0ed4affa1 100644 --- a/runtime/onert/core/src/ir/operation/Reshape.cc +++ b/runtime/onert/core/src/ir/operation/Reshape.cc @@ -15,9 +15,6 @@ */ #include "ir/operation/Reshape.h" - -#include <cassert> - #include "ir/OperationVisitor.h" namespace onert @@ -31,7 +28,7 @@ void Reshape::accept(OperationVisitor &v) const { v.visit(*this); } Reshape::Reshape(const OperandIndexSequence &inputs, const OperandIndexSequence &outputs, const Param ¶m) - : Operation{OperandConstraint::createExact(2u), inputs, outputs}, _param(param) + : Operation{OperandConstraint::createExact(2u), inputs, outputs}, _param(param) { } diff --git a/runtime/onert/core/src/ir/operation/ResizeBilinear.cc b/runtime/onert/core/src/ir/operation/ResizeBilinear.cc index d0d89f45f..7d256f447 100644 --- a/runtime/onert/core/src/ir/operation/ResizeBilinear.cc +++ b/runtime/onert/core/src/ir/operation/ResizeBilinear.cc @@ -15,9 +15,6 @@ */ #include "ir/operation/ResizeBilinear.h" - -#include <cassert> - #include "ir/OperationVisitor.h" namespace onert @@ -31,7 +28,7 @@ void ResizeBilinear::accept(OperationVisitor &v) const { v.visit(*this); } 
ResizeBilinear::ResizeBilinear(const OperandIndexSequence &inputs, const OperandIndexSequence &outputs, const Param ¶m) - : Operation{OperandConstraint::createExact(1u), inputs, outputs}, _param{param} + : Operation{OperandConstraint::createInRange(1u, 2u), inputs, outputs}, _param{param} { } diff --git a/runtime/onert/core/src/ir/operation/ResizeNearestNeighbor.cc b/runtime/onert/core/src/ir/operation/ResizeNearestNeighbor.cc index 9f17af97c..58be87b95 100644 --- a/runtime/onert/core/src/ir/operation/ResizeNearestNeighbor.cc +++ b/runtime/onert/core/src/ir/operation/ResizeNearestNeighbor.cc @@ -15,9 +15,6 @@ */ #include "ir/operation/ResizeNearestNeighbor.h" - -#include <cassert> - #include "ir/OperationVisitor.h" namespace onert @@ -32,7 +29,7 @@ void ResizeNearestNeighbor::accept(OperationVisitor &v) const { v.visit(*this); ResizeNearestNeighbor::ResizeNearestNeighbor(const OperandIndexSequence &inputs, const OperandIndexSequence &outputs, const Param ¶m) - : Operation{OperandConstraint::createExact(1u), inputs, outputs}, _param{param} + : Operation{OperandConstraint::createInRange(1u, 2u), inputs, outputs}, _param{param} { } diff --git a/runtime/onert/core/src/ir/operation/Reverse.cc b/runtime/onert/core/src/ir/operation/Reverse.cc index 4b3c1e1af..6c3746426 100644 --- a/runtime/onert/core/src/ir/operation/Reverse.cc +++ b/runtime/onert/core/src/ir/operation/Reverse.cc @@ -15,9 +15,6 @@ */ #include "ir/operation/Reverse.h" - -#include <cassert> - #include "ir/OperationVisitor.h" namespace onert @@ -30,7 +27,7 @@ namespace operation void Reverse::accept(OperationVisitor &v) const { v.visit(*this); } Reverse::Reverse(const OperandIndexSequence &inputs, const OperandIndexSequence &outputs) - : Operation{OperandConstraint::createExact(2u), inputs, outputs} + : Operation{OperandConstraint::createExact(2u), inputs, outputs} { } diff --git a/runtime/onert/core/src/ir/operation/Select.cc b/runtime/onert/core/src/ir/operation/Select.cc index 1f22b5234..59684190c 100644 
--- a/runtime/onert/core/src/ir/operation/Select.cc +++ b/runtime/onert/core/src/ir/operation/Select.cc @@ -28,7 +28,7 @@ namespace operation void Select::accept(OperationVisitor &v) const { v.visit(*this); } Select::Select(const OperandIndexSequence &inputs, const OperandIndexSequence &outputs) - : Operation{OperandConstraint::createExact(3u), inputs, outputs} + : Operation{OperandConstraint::createExact(3u), inputs, outputs} { } diff --git a/runtime/onert/core/src/ir/operation/Shape.cc b/runtime/onert/core/src/ir/operation/Shape.cc index 2a63d6dcf..f90924488 100644 --- a/runtime/onert/core/src/ir/operation/Shape.cc +++ b/runtime/onert/core/src/ir/operation/Shape.cc @@ -15,9 +15,6 @@ */ #include "ir/operation/Shape.h" - -#include <cassert> - #include "ir/OperationVisitor.h" namespace onert @@ -30,7 +27,7 @@ namespace operation void Shape::accept(OperationVisitor &v) const { v.visit(*this); } Shape::Shape(const OperandIndexSequence &inputs, const OperandIndexSequence &outputs) - : Operation{OperandConstraint::createExact(1u), inputs, outputs} + : Operation{OperandConstraint::createExact(1u), inputs, outputs} { } diff --git a/runtime/onert/core/src/ir/operation/Slice.cc b/runtime/onert/core/src/ir/operation/Slice.cc index 888b563fb..1362c0f91 100644 --- a/runtime/onert/core/src/ir/operation/Slice.cc +++ b/runtime/onert/core/src/ir/operation/Slice.cc @@ -27,7 +27,7 @@ namespace operation void Slice::accept(OperationVisitor &v) const { v.visit(*this); } Slice::Slice(const OperandIndexSequence &inputs, const OperandIndexSequence &outputs) - : Operation{OperandConstraint::createExact(3u), inputs, outputs} + : Operation{OperandConstraint::createExact(3u), inputs, outputs} { } diff --git a/runtime/onert/core/src/ir/operation/Softmax.cc b/runtime/onert/core/src/ir/operation/Softmax.cc index 3f1aa0af1..c06c85309 100644 --- a/runtime/onert/core/src/ir/operation/Softmax.cc +++ b/runtime/onert/core/src/ir/operation/Softmax.cc @@ -15,9 +15,6 @@ */ #include 
"ir/operation/Softmax.h" - -#include <cassert> - #include "ir/OperationVisitor.h" namespace onert @@ -31,7 +28,7 @@ void Softmax::accept(OperationVisitor &v) const { v.visit(*this); } Softmax::Softmax(const OperandIndexSequence &inputs, const OperandIndexSequence &outputs, const Param ¶m) - : Operation{OperandConstraint::createExact(1u), inputs, outputs}, _param{param} + : Operation{OperandConstraint::createExact(1u), inputs, outputs}, _param{param} { } diff --git a/runtime/onert/core/src/ir/operation/SpaceToBatchND.cc b/runtime/onert/core/src/ir/operation/SpaceToBatchND.cc index 53fab4fa9..94acccb0c 100644 --- a/runtime/onert/core/src/ir/operation/SpaceToBatchND.cc +++ b/runtime/onert/core/src/ir/operation/SpaceToBatchND.cc @@ -15,9 +15,6 @@ */ #include "ir/operation/SpaceToBatchND.h" - -#include <cassert> - #include "ir/OperationVisitor.h" namespace onert @@ -31,7 +28,7 @@ void SpaceToBatchND::accept(OperationVisitor &v) const { v.visit(*this); } SpaceToBatchND::SpaceToBatchND(const OperandIndexSequence &inputs, const OperandIndexSequence &outputs) - : Operation{OperandConstraint::createExact(3u), inputs, outputs} + : Operation{OperandConstraint::createExact(3u), inputs, outputs} { } diff --git a/runtime/onert/core/src/ir/operation/SpaceToDepth.cc b/runtime/onert/core/src/ir/operation/SpaceToDepth.cc index d8a45aee5..08e7e5190 100644 --- a/runtime/onert/core/src/ir/operation/SpaceToDepth.cc +++ b/runtime/onert/core/src/ir/operation/SpaceToDepth.cc @@ -15,9 +15,6 @@ */ #include "ir/operation/SpaceToDepth.h" - -#include <cassert> - #include "ir/OperationVisitor.h" namespace onert @@ -31,7 +28,7 @@ void SpaceToDepth::accept(OperationVisitor &v) const { v.visit(*this); } SpaceToDepth::SpaceToDepth(const OperandIndexSequence &inputs, const OperandIndexSequence &outputs, const Param ¶m) - : Operation{OperandConstraint::createExact(1u), inputs, outputs}, _param{param} + : Operation{OperandConstraint::createExact(1u), inputs, outputs}, _param{param} { } diff --git 
a/runtime/onert/core/src/ir/operation/Split.cc b/runtime/onert/core/src/ir/operation/Split.cc index 244884e41..3e371188d 100644 --- a/runtime/onert/core/src/ir/operation/Split.cc +++ b/runtime/onert/core/src/ir/operation/Split.cc @@ -13,9 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + #include "ir/operation/Split.h" -#include <cassert> #include "ir/OperationVisitor.h" + namespace onert { namespace ir @@ -25,7 +26,7 @@ namespace operation void Split::accept(OperationVisitor &v) const { v.visit(*this); } Split::Split(const OperandIndexSequence &inputs, const OperandIndexSequence &outputs, const Param ¶m) - : Operation{OperandConstraint::createExact(1u), inputs, outputs}, _param{param} + : Operation{OperandConstraint::createExact(2u), inputs, outputs}, _param{param} { } } // namespace operation diff --git a/runtime/onert/core/src/ir/operation/SplitV.cc b/runtime/onert/core/src/ir/operation/SplitV.cc index e638c9ac9..be13f167e 100644 --- a/runtime/onert/core/src/ir/operation/SplitV.cc +++ b/runtime/onert/core/src/ir/operation/SplitV.cc @@ -13,9 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ + #include "ir/operation/SplitV.h" -#include <cassert> #include "ir/OperationVisitor.h" + namespace onert { namespace ir @@ -25,7 +26,7 @@ namespace operation void SplitV::accept(OperationVisitor &v) const { v.visit(*this); } SplitV::SplitV(const OperandIndexSequence &inputs, const OperandIndexSequence &outputs, const Param ¶m) - : Operation{OperandConstraint::createExact(3u), inputs, outputs}, _param{param} + : Operation{OperandConstraint::createExact(3u), inputs, outputs}, _param{param} { } } // namespace operation diff --git a/runtime/onert/core/src/ir/operation/SquaredDifference.cc b/runtime/onert/core/src/ir/operation/SquaredDifference.cc index 49e58aaf2..db93903c7 100644 --- a/runtime/onert/core/src/ir/operation/SquaredDifference.cc +++ b/runtime/onert/core/src/ir/operation/SquaredDifference.cc @@ -15,9 +15,6 @@ */ #include "ir/operation/SquaredDifference.h" - -#include <cassert> - #include "ir/OperationVisitor.h" namespace onert @@ -31,7 +28,7 @@ void SquaredDifference::accept(OperationVisitor &v) const { v.visit(*this); } SquaredDifference::SquaredDifference(const OperandIndexSequence &inputs, const OperandIndexSequence &outputs) - : Operation{OperandConstraint::createExact(2u), inputs, outputs} + : Operation{OperandConstraint::createExact(2u), inputs, outputs} { } diff --git a/runtime/onert/core/src/ir/operation/Squeeze.cc b/runtime/onert/core/src/ir/operation/Squeeze.cc index 8cf928fb4..e059c4bee 100644 --- a/runtime/onert/core/src/ir/operation/Squeeze.cc +++ b/runtime/onert/core/src/ir/operation/Squeeze.cc @@ -28,7 +28,7 @@ void Squeeze::accept(OperationVisitor &v) const { v.visit(*this); } Squeeze::Squeeze(const OperandIndexSequence &inputs, const OperandIndexSequence &outputs, const Param ¶m) - : Operation{OperandConstraint::createExact(1u), inputs, outputs}, _param(param) + : Operation{OperandConstraint::createExact(1u), inputs, outputs}, _param(param) { } diff --git a/runtime/onert/core/src/ir/operation/StatelessRandomUniform.cc 
b/runtime/onert/core/src/ir/operation/StatelessRandomUniform.cc index cbb0ff251..94be0be86 100644 --- a/runtime/onert/core/src/ir/operation/StatelessRandomUniform.cc +++ b/runtime/onert/core/src/ir/operation/StatelessRandomUniform.cc @@ -15,9 +15,6 @@ */ #include "ir/operation/StatelessRandomUniform.h" - -#include <cassert> - #include "ir/OperationVisitor.h" namespace onert @@ -30,7 +27,7 @@ void StatelessRandomUniform::accept(OperationVisitor &v) const { v.visit(*this); StatelessRandomUniform::StatelessRandomUniform(const OperandIndexSequence &inputs, const OperandIndexSequence &outputs) - : Operation{OperandConstraint::createExact(2u), inputs, outputs} + : Operation{OperandConstraint::createExact(2u), inputs, outputs} { } diff --git a/runtime/onert/core/src/ir/operation/StridedSlice.cc b/runtime/onert/core/src/ir/operation/StridedSlice.cc index 2a7905995..a38282c93 100644 --- a/runtime/onert/core/src/ir/operation/StridedSlice.cc +++ b/runtime/onert/core/src/ir/operation/StridedSlice.cc @@ -15,9 +15,6 @@ */ #include "ir/operation/StridedSlice.h" - -#include <cassert> - #include "ir/OperationVisitor.h" namespace onert @@ -31,7 +28,7 @@ void StridedSlice::accept(OperationVisitor &v) const { v.visit(*this); } StridedSlice::StridedSlice(const OperandIndexSequence &inputs, const OperandIndexSequence &outputs, const Param ¶m) - : Operation{OperandConstraint::createExact(4u), inputs, outputs}, _param{param} + : Operation{OperandConstraint::createExact(4u), inputs, outputs}, _param{param} { } diff --git a/runtime/onert/core/src/ir/operation/Tile.cc b/runtime/onert/core/src/ir/operation/Tile.cc index 5ba3df2ad..51c1ff1dc 100644 --- a/runtime/onert/core/src/ir/operation/Tile.cc +++ b/runtime/onert/core/src/ir/operation/Tile.cc @@ -15,9 +15,6 @@ */ #include "ir/operation/Tile.h" - -#include <cassert> - #include "ir/OperationVisitor.h" namespace onert @@ -30,7 +27,7 @@ namespace operation void Tile::accept(OperationVisitor &v) const { v.visit(*this); } Tile::Tile(const 
OperandIndexSequence &inputs, const OperandIndexSequence &outputs) - : Operation{OperandConstraint::createExact(2u), inputs, outputs} + : Operation{OperandConstraint::createExact(2u), inputs, outputs} { } diff --git a/runtime/onert/core/src/ir/operation/TopKV2.cc b/runtime/onert/core/src/ir/operation/TopKV2.cc index a5e6c6a85..e1723d180 100644 --- a/runtime/onert/core/src/ir/operation/TopKV2.cc +++ b/runtime/onert/core/src/ir/operation/TopKV2.cc @@ -15,9 +15,6 @@ */ #include "ir/operation/TopKV2.h" - -#include <cassert> - #include "ir/OperationVisitor.h" namespace onert @@ -31,7 +28,7 @@ void TopKV2::accept(OperationVisitor &v) const { v.visit(*this); } TopKV2::TopKV2(const OperandIndexSequence &inputs, const OperandIndexSequence &outputs, const Param ¶m) - : Operation{OperandConstraint::createExact(1u), inputs, outputs}, _param{param} + : Operation{OperandConstraint::createExact(1u), inputs, outputs}, _param{param} { } diff --git a/runtime/onert/core/src/ir/operation/Transpose.cc b/runtime/onert/core/src/ir/operation/Transpose.cc index 3a663fbce..dbc5ef2aa 100644 --- a/runtime/onert/core/src/ir/operation/Transpose.cc +++ b/runtime/onert/core/src/ir/operation/Transpose.cc @@ -15,9 +15,6 @@ */ #include "ir/operation/Transpose.h" - -#include <cassert> - #include "ir/OperationVisitor.h" namespace onert @@ -29,9 +26,8 @@ namespace operation void Transpose::accept(OperationVisitor &v) const { v.visit(*this); } -Transpose::Transpose(const OperandIndexSequence &inputs, const OperandIndexSequence &outputs, - const Param ¶m) - : Operation{OperandConstraint::createExact(2u), inputs, outputs}, _param{param} +Transpose::Transpose(const OperandIndexSequence &inputs, const OperandIndexSequence &outputs) + : Operation{OperandConstraint::createExact(2u), inputs, outputs} { } diff --git a/runtime/onert/core/src/ir/operation/TransposeConv.cc b/runtime/onert/core/src/ir/operation/TransposeConv.cc index 7f29ca44e..944cc365d 100644 --- 
a/runtime/onert/core/src/ir/operation/TransposeConv.cc +++ b/runtime/onert/core/src/ir/operation/TransposeConv.cc @@ -15,9 +15,6 @@ */ #include "ir/operation/TransposeConv.h" - -#include <cassert> - #include "ir/OperationVisitor.h" namespace onert @@ -31,7 +28,7 @@ void TransposeConv::accept(OperationVisitor &v) const { v.visit(*this); } TransposeConv::TransposeConv(const OperandIndexSequence &inputs, const OperandIndexSequence &outputs, const Param ¶m) - : Operation{OperandConstraint::createExact(3u), inputs, outputs}, _param{param} + : Operation{OperandConstraint::createExact(3u), inputs, outputs}, _param{param} { } diff --git a/runtime/onert/core/src/ir/operation/Unpack.cc b/runtime/onert/core/src/ir/operation/Unpack.cc index 67aa54ab5..185eddce3 100644 --- a/runtime/onert/core/src/ir/operation/Unpack.cc +++ b/runtime/onert/core/src/ir/operation/Unpack.cc @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + #include "ir/operation/Unpack.h" #include "ir/OperationVisitor.h" @@ -25,7 +26,7 @@ namespace operation void Unpack::accept(OperationVisitor &v) const { v.visit(*this); } Unpack::Unpack(const OperandIndexSequence &inputs, const OperandIndexSequence &outputs, const Param ¶m) - : Operation{OperandConstraint::createExact(1u), inputs, outputs}, _param{param} + : Operation{OperandConstraint::createExact(1u), inputs, outputs}, _param{param} { } } // namespace operation diff --git a/runtime/onert/core/src/ir/operation/While.cc b/runtime/onert/core/src/ir/operation/While.cc index 2505c60e3..f35996b07 100644 --- a/runtime/onert/core/src/ir/operation/While.cc +++ b/runtime/onert/core/src/ir/operation/While.cc @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ + #include "ir/operation/While.h" #include "ir/OperationVisitor.h" @@ -25,7 +26,7 @@ namespace operation void While::accept(OperationVisitor &v) const { v.visit(*this); } While::While(const OperandIndexSequence &inputs, const OperandIndexSequence &outputs, const Param ¶m) - : Operation{OperandConstraint::createAny(), inputs, outputs}, _param{param} + : Operation{OperandConstraint::createAny(), inputs, outputs}, _param{param} { } } // namespace operation diff --git a/runtime/onert/core/src/ir/train/LossCode.cc b/runtime/onert/core/src/ir/train/LossCode.cc new file mode 100644 index 000000000..eccae8cd7 --- /dev/null +++ b/runtime/onert/core/src/ir/train/LossCode.cc @@ -0,0 +1,39 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "ir/train/LossCode.h" + +#include <unordered_map> + +namespace onert +{ +namespace ir +{ +namespace train +{ + +std::string toString(LossCode code) +{ + static const std::unordered_map<LossCode, const char *> map{ + {LossCode::Undefined, "Undefined"}, + {LossCode::MeanSquaredError, "MeanSquaredError"}, + {LossCode::CategoricalCrossentropy, "CategoricalCrossentropy"}}; + return map.at(code); +} + +} // namespace train +} // namespace ir +} // namespace onert diff --git a/runtime/onert/core/src/ir/train/OptimizerCode.cc b/runtime/onert/core/src/ir/train/OptimizerCode.cc new file mode 100644 index 000000000..4ab689085 --- /dev/null +++ b/runtime/onert/core/src/ir/train/OptimizerCode.cc @@ -0,0 +1,39 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "ir/train/OptimizerCode.h" + +#include <unordered_map> + +namespace onert +{ +namespace ir +{ +namespace train +{ + +std::string toString(OptimizerCode code) +{ + static const std::unordered_map<OptimizerCode, const char *> map{ + {OptimizerCode::Undefined, "Undefined"}, + {OptimizerCode::SGD, "SGD"}, + {OptimizerCode::Adam, "Adam"}}; + return map.at(code); +} + +} // namespace train +} // namespace ir +} // namespace onert diff --git a/runtime/onert/core/src/ir/train/TrainableGraph.cc b/runtime/onert/core/src/ir/train/TrainableGraph.cc new file mode 100644 index 000000000..5ecdcc2cb --- /dev/null +++ b/runtime/onert/core/src/ir/train/TrainableGraph.cc @@ -0,0 +1,337 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "ir/train/TrainableGraph.h" + +#include "ir/OperandIndexMap.h" +#include "util/Utils.h" +#include "util/Set.h" +#include "../verifier/Verifier.h" + +#include <algorithm> +#include <set> +#include <map> +#include <misc/polymorphic_downcast.h> + +namespace onert +{ +namespace ir +{ +namespace train +{ + +TrainableGraph::TrainableGraph() : _graph{} {} + +TrainableGraph::TrainableGraph(const TrainableGraph &tgraph) + : _graph{tgraph._graph}, _backward_operands{tgraph._backward_operands}, + _training_defuses{tgraph._training_defuses}, _losses{tgraph._losses} +{ + tgraph.operations().iterate( + [&](const onert::ir::OperationIndex &index, const onert::ir::IOperation &op) { + replaceOperation(index, dynamic_cast<const ITrainableOperation &>(op).clone()); + }); +} + +TrainableGraph::TrainableGraph(const Graph &graph) : _graph{graph} {} + +OperandIndex TrainableGraph::addOperand(const Shape &shape, const TypeInfo &type) +{ + return _graph.addOperand(shape, type); +} + +OperandIndex TrainableGraph::addOperand(OperandIndex index, std::unique_ptr<Operand> &&operand) +{ + return _graph.addOperand(index, std::move(operand)); +} + +OperationIndex TrainableGraph::addOperation(std::unique_ptr<ITrainableOperation> &&operation) +{ + return _graph.addOperation(std::move(operation)); +} + +OperationIndex TrainableGraph::replaceOperation(OperationIndex index, + std::unique_ptr<ITrainableOperation> &&operation) +{ + return _graph.replaceOperation(index, std::move(operation)); +} + +OperandIndex TrainableGraph::addBackwardOperand(OperandIndex index, + std::unique_ptr<Operand> &&bwd_operand) +{ + return _backward_operands.push(std::move(bwd_operand), index); +} + +IOIndex TrainableGraph::getInputIndex(const std::string &name) const +{ + return _graph.getInputIndex(name); +} + +IOIndex TrainableGraph::getOutputIndex(const std::string &name) const +{ + return _graph.getOutputIndex(name); +} + +void TrainableGraph::changeShape(const OperandIndex &index, const ir::Shape 
&new_shape) +{ + _graph.changeShape(index, new_shape); +} + +void TrainableGraph::changeBackwardShape(const OperandIndex &index, const ir::Shape &new_shape) +{ + assert(_backward_operands.exist(index)); + _backward_operands.at(index).info().shape(new_shape); +} + +void TrainableGraph::addInput(const OperandIndex &ind, const std::string &name) +{ + _graph.addInput(ind, name); +} + +void TrainableGraph::addOutput(const OperandIndex &ind, const std::string &name) +{ + _graph.addOutput(ind, name); +} + +void TrainableGraph::verify(void) const +{ + _graph.verify(); + + operations().iterate([](const onert::ir::OperationIndex &, const onert::ir::IOperation &op) { + try + { + UNUSED_RELEASE(dynamic_cast<const onert::ir::train::ITrainableOperation &>(op)); + } + catch (const std::bad_cast &) + { + throw std::runtime_error("TrainableGraph: " + op.name() + " is not a trainable operation"); + } + }); + + verifyTrainingUseDefs(); +} + +void TrainableGraph::removeOperand(const OperandIndex &ind) { _graph.removeOperand(ind); } + +void TrainableGraph::setLayout(Layout layout) { _graph.setLayout(layout); } + +const ITrainableOperation &TrainableGraph::operation(OperationIndex index) const +{ + // NOTE Virtual inherited objects cannot be static_casted. 
+ return dynamic_cast<const ITrainableOperation &>(_graph.operations().at(index)); +} + +void TrainableGraph::enableBackward(const OperationIndex &index) +{ + auto op = dynamic_cast<ir::train::ITrainableOperation *>(&_graph.operations().at(index)); + assert(op); + op->enableBackward(); +} + +void TrainableGraph::disableBackward(const OperationIndex &index) +{ + auto &op = dynamic_cast<ir::train::ITrainableOperation &>(_graph.operations().at(index)); + op.disableBackward(); +} + +void TrainableGraph::setTrainingUseDefs(const UseDefChains &training_defuses) +{ + _training_defuses.clear(); + // TODO Replace this loop with `std::unordered_map::insert_range` since C++23 + for (const auto &defuse_chain : training_defuses) + { + _training_defuses.emplace(defuse_chain.first, defuse_chain.second); + } +} + +void TrainableGraph::validateTopologicalOrder(std::vector<ir::OperationIndex> order, + bool is_forward) const +{ + if (!is_forward) + std::reverse(order.begin(), order.end()); + + const std::string order_type = is_forward ? 
"forward" : "backward"; + + std::map<ir::OperationIndex, uint32_t> position; + for (uint32_t p = 0; p < order.size(); ++p) + { + auto index = order[p]; + // TODO: replace this with `std::map::contains` after C++20 + if (position.find(index) != position.end()) + throw std::runtime_error{"Invalid " + order_type + " topological order: duplicate node @" + + std::to_string(index.value())}; + + position[index] = p; + } + + operations().iterate([&](const ir::OperationIndex &index, const ir::IOperation &op) { + if (position.count(index) == 0) + return; + + uint32_t p = position[index]; + + for (const auto &output : op.getOutputs() | ir::Remove::DUPLICATED | ir::Remove::UNDEFINED) + { + const auto &operand = operands().at(output); + for (const auto &use : operand.getUses()) + { + if (position.count(use) == 0) + continue; + + uint32_t q = position[use]; + if (p > q) + throw std::runtime_error{ + "Invalid " + order_type + " topological order: inversion between @" + + std::to_string(index.value()) + " and @" + std::to_string(use.value())}; + } + } + }); +} + +void TrainableGraph::validateForwardTopologicalOrder( + const std::vector<ir::OperationIndex> &order) const +{ + validateTopologicalOrder(order, true); +} + +void TrainableGraph::validateBackwardTopologicalOrder( + const std::vector<ir::OperationIndex> &order) const +{ + validateTopologicalOrder(order, false); +} + +void TrainableGraph::verifyTrainingUseDefs() const +{ + if (!verifier::DAGChecker().verify(_training_defuses)) + throw std::runtime_error{"The training def-uses is cyclic."}; + assert(verifier::EdgeChecker().verify(_training_defuses)); +} + +std::vector<ir::OperationIndex> TrainableGraph::topolSortOperations() const +{ + auto ret = _graph.topolSortOperations(); + validateForwardTopologicalOrder(ret); + + return ret; +} + +std::vector<ir::OperationIndex> TrainableGraph::btopolSortOperations() const +{ + std::vector<ir::OperationIndex> ret; + util::Set<ir::OperationIndex> unvisited; + ir::OperationIndex 
loss_idx; + operations().iterate([&](const ir::OperationIndex &index, const ir::IOperation &op) { + unvisited.add(index); + if (op.opcode() == ir::OpCode::Loss) + { + assert(!loss_idx.valid()); // Should be only one loss + loss_idx = index; + } + }); + + std::function<void(const ir::OperationIndex &, const ir::IOperation &)> dfs = + [&](const ir::OperationIndex &index, const ir::IOperation &op) -> void { + if (!unvisited.contains(index)) + return; + unvisited.remove(index); + + for (const auto &input : op.getInputs() | ir::Remove::DUPLICATED | ir::Remove::UNDEFINED) + { + const auto &operand = operands().at(input); + const auto &def = operand.getDef(); + if (!def.valid()) + continue; + dfs(def, operations().at(def)); + } + + ret.push_back(index); + }; + + dfs(loss_idx, operations().at(loss_idx)); + std::reverse(ret.begin(), ret.end()); + validateBackwardTopologicalOrder(ret); + + return ret; +} + +std::vector<ir::OperationIndex> TrainableGraph::essentialBackwardOrder() const +{ + auto backward_order = btopolSortOperations(); + // get rid of all nodes not reachable from a node with trainable parameters + backward_order = truncateBackwardOrder(backward_order, [&](const OperationIndex &index) { + return operation(index).isRequiredForBackward(); + }); + + return truncateBackwardOrder(backward_order); +} + +std::vector<ir::OperationIndex> TrainableGraph::truncateBackwardOrder( + std::vector<ir::OperationIndex> backward_order, + std::function<bool(const ir::OperationIndex &)> alive_cond) const +{ + auto forward_order = backward_order; + std::reverse(forward_order.begin(), forward_order.end()); + std::set<ir::OperationIndex> alive; + + for (const auto &index : forward_order) + { + if (alive_cond(index)) + alive.insert(index); + + // TODO: replace this with `std::set::contains` after C++20 + if (alive.find(index) != alive.end()) + { + const auto &op = operations().at(index); + for (const auto &output : op.getOutputs()) + { + const auto &operand = operands().at(output); + 
for (const auto &use : operand.getUses()) + alive.insert(use); + } + } + } + + // TODO: replace this with `std::erase_if(std::vector)` after C++20 + backward_order.erase( + std::remove_if(backward_order.begin(), backward_order.end(), + [&](const auto &index) { return alive.find(index) == alive.end(); }), + backward_order.end()); + return backward_order; +} + +std::vector<ir::OperationIndex> +TrainableGraph::truncateBackwardOrder(const std::vector<ir::OperationIndex> &backward_order) const +{ + return truncateBackwardOrder(backward_order, [&](const ir::OperationIndex &index) { + const auto &trainable_op = operation(index); + + return trainable_op.hasTrainableParameter(); + }); +} + +void TrainableGraph::addLoss(const OperandIndex &loss_ind, const IOIndex &pred_ioind) +{ + _losses.emplace(pred_ioind, loss_ind); +} + +OperandIndex TrainableGraph::getLossIndex(const IOIndex &pred_ioind) const +{ + auto itr = _losses.find(pred_ioind); + return (itr == _losses.end()) ? OperandIndex{} : itr->second; +} + +} // namespace train +} // namespace ir +} // namespace onert diff --git a/runtime/onert/core/src/ir/train/TrainableGraph.test.cc b/runtime/onert/core/src/ir/train/TrainableGraph.test.cc new file mode 100644 index 000000000..84df22890 --- /dev/null +++ b/runtime/onert/core/src/ir/train/TrainableGraph.test.cc @@ -0,0 +1,378 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "ir/train/TrainableGraph.h" +#include "ir/train/operation/BinaryArithmetic.h" +#include "ir/train/operation/ElementwiseActivation.h" +#include "ir/train/operation/FullyConnected.h" +#include "ir/train/operation/Loss.h" +#include "ir/train/LossInfo.h" + +#include <gtest/gtest.h> + +using namespace onert::ir; + +OperationIndex addAddOperation(train::TrainableGraph &tgraph, const OperandIndexSequence inputs, + const OperandIndexSequence outputs) +{ + // Add "ADD" operation + operation::BinaryArithmetic::Param param; + param.arithmetic_type = operation::BinaryArithmetic::ArithmeticType::ADD; + param.activation = Activation::NONE; + auto add_op = operation::BinaryArithmetic(inputs, outputs, param); + return tgraph.addOperation(std::make_unique<train::operation::BinaryArithmetic>(add_op)); +} + +OperationIndex addElementwiseActivationOperation(train::TrainableGraph &tgraph, + const OperandIndexSequence inputs, + const OperandIndexSequence outputs) +{ + // Add "ElementwiseActivation" operation + operation::ElementwiseActivation::Param param; + auto ea_op = operation::ElementwiseActivation(inputs, outputs, param); + return tgraph.addOperation(std::make_unique<train::operation::ElementwiseActivation>(ea_op)); +} + +OperationIndex addFullyConnectedOperation(train::TrainableGraph &tgraph, + const OperandIndexSequence inputs, + const OperandIndexSequence outputs) +{ + // Add "FullyConnected" operation + operation::FullyConnected::Param param; + param.weights_format = FullyConnectedWeightsFormat::Default; + param.activation = Activation::NONE; + auto fc_op = operation::FullyConnected(inputs, outputs, param); + return tgraph.addOperation(std::make_unique<train::operation::FullyConnected>(fc_op)); +} + +OperationIndex addLossOperation(train::TrainableGraph &tgraph, const OperandIndexSequence inputs, + const OperandIndexSequence outputs) +{ + // Add "Loss" operation + auto loss_op = operation::Loss(inputs, outputs); + return 
tgraph.addOperation(std::make_unique<train::operation::Loss>(loss_op, train::LossInfo{})); +} + +TEST(TrainableGraph, topological_sort_linear) +{ + train::TrainableGraph tgraph; + + Shape shape{1, 2, 2, 1}; + TypeInfo type{DataType::FLOAT32}; + + /* + (input) ⎼[EA]⎼> (y_pred) + ╲ + [Loss]⎼> (output) + ╱ + (y_true) + */ + + auto input = tgraph.addOperand(shape, type); + auto y_pred = tgraph.addOperand(shape, type); + auto y_true = tgraph.addOperand(shape, type); + auto output = tgraph.addOperand(shape, type); + + tgraph.addInput({input}); + tgraph.addInput({y_true}); + tgraph.addOutput({output}); + + addElementwiseActivationOperation(tgraph, {input}, {y_pred}); + addLossOperation(tgraph, {y_pred, y_true}, {output}); + + EXPECT_NO_THROW(tgraph.topolSortOperations()); + EXPECT_NO_THROW(tgraph.btopolSortOperations()); +} + +TEST(TrainableGraph, topological_sort_nonlinear) +{ + train::TrainableGraph tgraph; + + Shape shape{1, 2, 2, 1}; + TypeInfo type{DataType::FLOAT32}; + + /* + [EA]⎼> (lhs) + ╱ ╲ + (input) ⎼[EA]⎼> (split) [Add]⎼> (y_pred) + ╲ ╱ ╲ + [EA]⎼> (rhs) [Loss]⎼> (output) + ╱ + (y_true) + */ + + auto input = tgraph.addOperand(shape, type); + auto split = tgraph.addOperand(shape, type); + auto lhs = tgraph.addOperand(shape, type); + auto rhs = tgraph.addOperand(shape, type); + auto y_pred = tgraph.addOperand(shape, type); + auto y_true = tgraph.addOperand(shape, type); + auto output = tgraph.addOperand(shape, type); + + tgraph.addInput({input}); + tgraph.addInput({y_true}); + tgraph.addOutput({output}); + + addElementwiseActivationOperation(tgraph, {input}, {split}); + addElementwiseActivationOperation(tgraph, {split}, {lhs}); + addElementwiseActivationOperation(tgraph, {split}, {rhs}); + addAddOperation(tgraph, {lhs, rhs}, {y_pred}); + addLossOperation(tgraph, {y_pred, y_true}, {output}); + + EXPECT_NO_THROW(tgraph.topolSortOperations()); + EXPECT_NO_THROW(tgraph.btopolSortOperations()); +} + +TEST(TrainableGraph, neg_topological_sort_cycle) +{ + 
train::TrainableGraph tgraph; + + Shape shape{1, 2, 2, 1}; + TypeInfo type{DataType::FLOAT32}; + + /* + (input) ⎼[Add]⎼> (v) ⎼[EA] + | | + v + (u) <⎼[EA]⎼ (y_pred) + ╲ + [Loss]⎼> (output) + ╱ + (y_true) + */ + + auto input = tgraph.addOperand(shape, type); + auto u = tgraph.addOperand(shape, type); + auto v = tgraph.addOperand(shape, type); + auto y_pred = tgraph.addOperand(shape, type); + auto y_true = tgraph.addOperand(shape, type); + auto output = tgraph.addOperand(shape, type); + + tgraph.addInput({input}); + tgraph.addInput({y_true}); + tgraph.addOutput({output}); + + addAddOperation(tgraph, {input, u}, {v}); + addElementwiseActivationOperation(tgraph, {v}, {y_pred}); + addElementwiseActivationOperation(tgraph, {y_pred}, {u}); + addLossOperation(tgraph, {y_pred, y_true}, {output}); + + EXPECT_ANY_THROW(tgraph.topolSortOperations()); + EXPECT_ANY_THROW(tgraph.btopolSortOperations()); +} + +TEST(TrainableGraph, truncating_backward_topological_order_nonlinear) +{ + { + train::TrainableGraph tgraph; + + Shape shape{1, 2, 2, 1}; + TypeInfo type{DataType::FLOAT32}; + + /* + [EA1]⎼> (u) + ╱ ╲ + ╱ (weight1) ⎼[FC1]⎼> (v) + ╱ ╱ ╲ + ╱ (bias1) [Add]⎼> (y_pred) + (input) ╱ ╲ + ╲ ╱ [Loss]⎼> (output) + [EA2]⎼> (w) ╱ ╱ + ╲ ╱ (y_true) + (weight2) ⎼[FC2]⎼> (x) + ╱ + (bias2) + */ + + auto input = tgraph.addOperand(shape, type); + auto u = tgraph.addOperand(shape, type); + auto weight1 = tgraph.addOperand(shape, type); + auto bias1 = tgraph.addOperand(shape, type); + auto v = tgraph.addOperand(shape, type); + auto w = tgraph.addOperand(shape, type); + auto weight2 = tgraph.addOperand(shape, type); + auto bias2 = tgraph.addOperand(shape, type); + auto x = tgraph.addOperand(shape, type); + auto y_pred = tgraph.addOperand(shape, type); + auto y_true = tgraph.addOperand(shape, type); + auto output = tgraph.addOperand(shape, type); + + tgraph.addInput({input}); + tgraph.addInput({weight1}); + tgraph.addInput({bias1}); + tgraph.addInput({weight2}); + tgraph.addInput({bias2}); + 
tgraph.addInput({y_true}); + tgraph.addOutput({output}); + + auto ea1 = addElementwiseActivationOperation(tgraph, {input}, {u}); + auto fc1 = addFullyConnectedOperation(tgraph, {u, weight1, bias1}, {v}); + auto ea2 = addElementwiseActivationOperation(tgraph, {input}, {w}); + auto fc2 = addFullyConnectedOperation(tgraph, {w, weight2, bias2}, {x}); + auto add = addAddOperation(tgraph, {v, x}, {y_pred}); + auto loss = addLossOperation(tgraph, {y_pred, y_true}, {output}); + + std::vector<OperationIndex> expected_truncation_1{loss, add, fc1, fc2}; + std::vector<OperationIndex> expected_truncation_2{loss, add, fc2, fc1}; + std::vector<OperationIndex> truncation = + tgraph.truncateBackwardOrder(tgraph.btopolSortOperations()); + + ASSERT_TRUE(truncation == expected_truncation_1 || truncation == expected_truncation_2); + } + + { + train::TrainableGraph tgraph; + + Shape shape{1, 2, 2, 1}; + TypeInfo type{DataType::FLOAT32}; + + /* + (input1) ⎼[FC3]⎼> (r) ⎼⎼[Add]⎼> (s) ⎼[EA1]⎼> (u) + ╱ ╱ ╲ + (weight3) ╱ (weight1) ⎼[FC1]⎼> (v) + ╱ ╱ ╲ + ╱ ╱ ╲ + ╱ (bias1) [Add]⎼> (y_pred) + (input) ╱ ╲ + ╲ ╱ [Loss]⎼> (output) + ╲ ╱ ╱ + [Add]⎼> (t) ⎼[EA2]⎼> (w) ╱ ╱ + ╱ ╲ ╱ (y_true) + (input2) (weight2) ⎼[FC2]⎼> (x) + ╱ + (bias2) + */ + + auto input1 = tgraph.addOperand(shape, type); + auto weight3 = tgraph.addOperand(shape, type); + auto r = tgraph.addOperand(shape, type); + auto input = tgraph.addOperand(shape, type); + auto s = tgraph.addOperand(shape, type); + auto input2 = tgraph.addOperand(shape, type); + auto t = tgraph.addOperand(shape, type); + auto u = tgraph.addOperand(shape, type); + auto weight1 = tgraph.addOperand(shape, type); + auto bias1 = tgraph.addOperand(shape, type); + auto v = tgraph.addOperand(shape, type); + auto w = tgraph.addOperand(shape, type); + auto weight2 = tgraph.addOperand(shape, type); + auto bias2 = tgraph.addOperand(shape, type); + auto x = tgraph.addOperand(shape, type); + auto y_pred = tgraph.addOperand(shape, type); + auto y_true = tgraph.addOperand(shape, 
type); + auto output = tgraph.addOperand(shape, type); + + tgraph.addInput({input}); + tgraph.addInput({weight1}); + tgraph.addInput({bias1}); + tgraph.addInput({weight2}); + tgraph.addInput({bias2}); + tgraph.addInput({y_true}); + tgraph.addOutput({output}); + + auto fc3 = addFullyConnectedOperation(tgraph, {input1, weight3}, {r}); + auto add1 = addAddOperation(tgraph, {r, input}, {s}); + auto add2 = addAddOperation(tgraph, {input, input2}, {t}); + auto ea1 = addElementwiseActivationOperation(tgraph, {s}, {u}); + auto fc1 = addFullyConnectedOperation(tgraph, {u, weight1, bias1}, {v}); + auto ea2 = addElementwiseActivationOperation(tgraph, {t}, {w}); + auto fc2 = addFullyConnectedOperation(tgraph, {w, weight2, bias2}, {x}); + auto add = addAddOperation(tgraph, {v, x}, {y_pred}); + auto loss = addLossOperation(tgraph, {y_pred, y_true}, {output}); + + // This expected indices are base on dfs + std::vector<OperationIndex> expected_truncation_1{loss, add, fc1, ea1, add1, fc3, fc2}; + std::vector<OperationIndex> expected_truncation_2{loss, add, fc2, fc1, ea1, add1, fc3}; + std::vector<OperationIndex> truncation = + tgraph.truncateBackwardOrder(tgraph.btopolSortOperations()); + + ASSERT_TRUE(truncation == expected_truncation_1 || truncation == expected_truncation_2); + } +} + +TEST(TrainableGraph, essential_backward_topological_order_nonlinear) +{ + { + train::TrainableGraph tgraph; + + Shape shape{1, 2, 2, 1}; + TypeInfo type{DataType::FLOAT32}; + + /* + (input1) ⎼[FC3]⎼> (r) ⎼⎼[Add]⎼> (s) ⎼[EA1]⎼> (u) + ╱ ╱ ╲ + (weight3) ╱ (weight1) ⎼[FC1]⎼> (v) + ╱ ╱ ╲ + ╱ ╱ ╲ + ╱ (bias1) [Add]⎼> (y_pred) + (input) ╱ ╲ + ╲ ╱ [Loss]⎼> (output) + ╲ ╱ ╱ + [Add]⎼> (t) ⎼[EA2]⎼> (w) ╱ ╱ + ╱ ╲ ╱ (y_true) + (input2) (weight2) ⎼[FC2]⎼> (x) + ╱ + (bias2) + */ + + auto input1 = tgraph.addOperand(shape, type); + auto weight3 = tgraph.addOperand(shape, type); + auto r = tgraph.addOperand(shape, type); + auto input = tgraph.addOperand(shape, type); + auto s = tgraph.addOperand(shape, type); + auto 
input2 = tgraph.addOperand(shape, type); + auto t = tgraph.addOperand(shape, type); + auto u = tgraph.addOperand(shape, type); + auto weight1 = tgraph.addOperand(shape, type); + auto bias1 = tgraph.addOperand(shape, type); + auto v = tgraph.addOperand(shape, type); + auto w = tgraph.addOperand(shape, type); + auto weight2 = tgraph.addOperand(shape, type); + auto bias2 = tgraph.addOperand(shape, type); + auto x = tgraph.addOperand(shape, type); + auto y_pred = tgraph.addOperand(shape, type); + auto y_true = tgraph.addOperand(shape, type); + auto output = tgraph.addOperand(shape, type); + + tgraph.addInput({input}); + tgraph.addInput({weight1}); + tgraph.addInput({bias1}); + tgraph.addInput({weight2}); + tgraph.addInput({bias2}); + tgraph.addInput({y_true}); + tgraph.addOutput({output}); + + auto fc3 = addFullyConnectedOperation(tgraph, {input1, weight3}, {r}); + auto add1 = addAddOperation(tgraph, {r, input}, {s}); + auto add2 = addAddOperation(tgraph, {input, input2}, {t}); + auto ea1 = addElementwiseActivationOperation(tgraph, {s}, {u}); + auto fc1 = addFullyConnectedOperation(tgraph, {u, weight1, bias1}, {v}); + auto ea2 = addElementwiseActivationOperation(tgraph, {t}, {w}); + auto fc2 = addFullyConnectedOperation(tgraph, {w, weight2, bias2}, {x}); + auto add = addAddOperation(tgraph, {v, x}, {y_pred}); + auto loss = addLossOperation(tgraph, {y_pred, y_true}, {output}); + + tgraph.enableBackward(fc2); + tgraph.enableBackward(fc3); + + // These expected indices are base on dfs + std::vector<OperationIndex> expected_truncation_1{loss, add, fc1, ea1, add1, fc3, fc2}; + std::vector<OperationIndex> expected_truncation_2{loss, add, fc2, fc1, ea1, add1, fc3}; + std::vector<OperationIndex> essential = tgraph.essentialBackwardOrder(); + + ASSERT_TRUE(essential == expected_truncation_1 || essential == expected_truncation_2); + } +} diff --git a/runtime/onert/core/src/ir/train/TrainingInfo.cc b/runtime/onert/core/src/ir/train/TrainingInfo.cc new file mode 100644 index 
000000000..102781173 --- /dev/null +++ b/runtime/onert/core/src/ir/train/TrainingInfo.cc @@ -0,0 +1,49 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ir/train/TrainingInfo.h" + +namespace onert +{ +namespace ir +{ +namespace train +{ + +bool TrainingInfo::isValid() const +{ + if (_batch_size == 0) + return false; + + if (_optimizer_info.optim_code == OptimizerCode::Undefined) + return false; + + if (_optimizer_info.learning_rate <= 0.0f) + return false; + + if (_loss_info.loss_code == LossCode::Undefined) + return false; + + if (_loss_info.reduction_type == LossReductionType::Undefined) + return false; + + // If there are invalid combination, add more condition-check here + return true; +} + +} // namespace train +} // namespace ir +} // namespace onert diff --git a/runtime/onert/core/src/ir/train/UseDefChain.cc b/runtime/onert/core/src/ir/train/UseDefChain.cc new file mode 100644 index 000000000..9cb9bb7c9 --- /dev/null +++ b/runtime/onert/core/src/ir/train/UseDefChain.cc @@ -0,0 +1,52 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ir/train/UseDefChain.h" + +namespace onert +{ +namespace ir +{ +namespace train +{ + +void UseDefChain::insertTrainingUse(const TrainingOperationIndex &idx) { _uses.insert(idx); } + +void UseDefChain::removeTrainingUse(const TrainingOperationIndex &idx) { _uses.erase(idx); } + +void UseDefChain::insertTrainingDef(const TrainingOperationIndex &idx) +{ + // defs must be valid + assert(idx.valid()); + _defs.insert(idx); +} + +void UseDefChain::removeTrainingDef(const TrainingOperationIndex &idx) { _defs.erase(idx); } + +void UseDefChain::clearTrainingUseDefs() +{ + _uses.clear(); + _defs.clear(); +} + +bool UseDefChain::operator==(const UseDefChain &other) const +{ + return &_operand == &other._operand && _uses == other._uses && _defs == other._defs; +} + +} // namespace train +} // namespace ir +} // namespace onert diff --git a/runtime/onert/core/src/ir/train/UseDefGenerator.cc b/runtime/onert/core/src/ir/train/UseDefGenerator.cc new file mode 100644 index 000000000..615b1650c --- /dev/null +++ b/runtime/onert/core/src/ir/train/UseDefGenerator.cc @@ -0,0 +1,187 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "UseDefGenerator.h" + +#include "ir/train/TrainableGraph.h" +#include "ir/train/Index.h" +#include "../verifier/Verifier.h" + +#include <cassert> +#include <memory> + +// TODO Reduce duplicate code + +namespace onert +{ +namespace ir +{ +namespace train +{ + +UseDefGenerator::UseDefGenerator(const TrainableGraph &tgraph) + : _tgraph{tgraph}, _node_to_idx{}, _training_usedefs{} +{ + const auto order = _tgraph.topolSortOperations(); + for (const auto &index : order) + { + const auto &node = _tgraph.operation(index); + assert(_node_to_idx.find(&node) == _node_to_idx.end()); + _node_to_idx[&node] = index; + } + + // Check whether loss exists + assert(std::any_of(order.begin(), order.end(), + [&](const auto &index) { + return _tgraph.operation(index).opcode() == ir::OpCode::Loss; + }) && + "Loss does not exist"); +} + +UseDefChains UseDefGenerator::operator()() +{ + const auto &graph = _tgraph.graph(); + assert(ir::verifier::EdgeChecker().verify(graph)); + + _training_usedefs.clear(); + graph.operands().iterate([&](const ir::OperandIndex &idx, const ir::Operand &operand) { + // Initialize as emtpy UseDefChain + const auto empty_usedef_chain = UseDefChain{operand}; + _training_usedefs.emplace(TrainingOperandIndex{idx, true}, empty_usedef_chain); + _training_usedefs.emplace(TrainingOperandIndex{idx, false}, empty_usedef_chain); + }); + + initForForwardingNodes(); + + initForBackwardingNodes(); + + return _training_usedefs; +} + +void UseDefGenerator::visit(const train::operation::Loss &node) +{ + assert(_node_to_idx.find(&node) != 
_node_to_idx.end()); + const auto &op_index = _node_to_idx.at(&node); + const auto backwarding_op_index = TrainingOperationIndex{op_index, false}; + + for (const auto &in_index : node.getInputs() | ir::Remove::UNDEFINED | ir::Remove::DUPLICATED) + { + // Insert use of forwarding inputs + const auto in_forwarding_index = TrainingOperandIndex{in_index, true}; + insertUse(in_forwarding_index, backwarding_op_index); + } + + // Set def of backwarding(backprop) y_pred + const auto &y_pred_index = node.getInputs().at(train::operation::Loss::Input::Y_PRED); + assert(!_tgraph.operands().at(y_pred_index).isConstant()); + const auto y_pred_outgoing_index = TrainingOperandIndex{y_pred_index, false}; + insertBackPropDef(y_pred_outgoing_index, backwarding_op_index); + + // Set def of backwarding(backprop) y_true + const auto &y_true_index = node.getInputs().at(train::operation::Loss::Input::Y_TRUE); + assert(!_tgraph.operands().at(y_true_index).isConstant()); + const auto y_true_outgoing_index = TrainingOperandIndex{y_true_index, false}; + insertBackPropDef(y_true_outgoing_index, backwarding_op_index); + + // Remove use of backwarding output + const auto &out_index = node.getOutputs().at(0); + const auto incoming_index = TrainingOperandIndex{out_index, false}; + auto &usedef_chain = _training_usedefs.at(incoming_index); + usedef_chain.removeTrainingUse(backwarding_op_index); +} + +void UseDefGenerator::insertUse(const TrainingOperandIndex &operand_index, + const TrainingOperationIndex &op_index) +{ + assert(_training_usedefs.find(operand_index) != _training_usedefs.end()); + auto &usedef_chain = _training_usedefs.at(operand_index); + usedef_chain.insertTrainingUse(op_index); +} + +void UseDefGenerator::insertDef(const TrainingOperandIndex &operand_index, + const TrainingOperationIndex &op_index) +{ + assert(operand_index.valid()); + + assert(_training_usedefs.find(operand_index) != _training_usedefs.end()); + auto &usedef_chain = _training_usedefs.at(operand_index); + 
usedef_chain.insertTrainingDef(op_index); +} + +void UseDefGenerator::insertBackPropDef(const TrainingOperandIndex &operand_index, + const TrainingOperationIndex &op_index) +{ + // NOTE There is no need to set def of constant backwarding(backprop) inputs + // because it won't be back-propagated. + if (!_tgraph.operands().at(operand_index.index()).isConstant()) + { + insertDef(operand_index, op_index); + } +} + +void UseDefGenerator::initForForwardingNodes() +{ + // Initialize training def-uses of forwarding operands for only forwarding nodes + // (i.e. forwarding nodes that do not have any backwarding node) + _tgraph.operands().iterate([&](const ir::OperandIndex &idx, const ir::Operand &operand) { + // Append forwarding def-uses as it is + const bool is_forward = true; + const auto forwarding_operand_index = TrainingOperandIndex{idx, is_forward}; + + const auto def = operand.getDef(); + if (def.valid()) + { + insertDef(forwarding_operand_index, TrainingOperationIndex{def, is_forward}); + auto &usedef_chain = _training_usedefs.at(forwarding_operand_index); + usedef_chain.insertTrainingDef(TrainingOperationIndex{def, is_forward}); + } + + assert(_training_usedefs.at(forwarding_operand_index).getTrainingUses().size() == 0); + const auto uses = operand.getUses(); + for (const auto &use : uses) + insertUse(forwarding_operand_index, TrainingOperationIndex{use, is_forward}); + }); +} + +void UseDefGenerator::initForBackwardingNodes() +{ + const auto backward_order = _tgraph.essentialBackwardOrder(); + // Initialize training uses of forwarding operands and def-uses of backwarding operands for + // backwarding nodes (i.e. 
backwarding nodes that do not have any forwarding node) + for (const auto &op_index : backward_order) + { + const auto &node = _tgraph.operation(op_index); + + // Insert use of backwarding operands(only output) + { + if (node.getOutputs().size() > 1) + throw std::runtime_error( + "UseDefGenerator does not support multiple outputs of training operation"); + + const auto &output = node.getOutputs().at(0); + const auto backwarding_op_index = TrainingOperationIndex{op_index, false}; + const auto incoming_index = TrainingOperandIndex{output, false}; + insertUse(incoming_index, backwarding_op_index); + } + + // Insert uses of forwarding operands and insert defs of backwarding operands + node.accept(*this); + } +} + +} // namespace train +} // namespace ir +} // namespace onert diff --git a/runtime/onert/core/src/ir/train/UseDefGenerator.h b/runtime/onert/core/src/ir/train/UseDefGenerator.h new file mode 100644 index 000000000..369d9a223 --- /dev/null +++ b/runtime/onert/core/src/ir/train/UseDefGenerator.h @@ -0,0 +1,87 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef __ONERT_IR_TRAIN_USEDEFINITIALIZER_H__ +#define __ONERT_IR_TRAIN_USEDEFINITIALIZER_H__ + +#include "ir/train/TrainableOperationVisitor.h" + +#include "ir/train/UseDefChains.h" +#include "ir/train/Operations.Include.h" + +namespace onert +{ +namespace ir +{ +namespace train +{ +class TrainableGraph; +} // namespace train +} // namespace ir +} // namespace onert + +namespace onert +{ +namespace ir +{ +namespace train +{ + +struct UseDefGeneratorBase : public TrainableOperationVisitor +{ + virtual ~UseDefGeneratorBase() = default; + +protected: +#define OP(InternalName) \ + virtual void visit(const operation::InternalName &) override \ + { \ + throw std::runtime_error("UseDefGenerator: NYI for operation '" #InternalName "'"); \ + } +#include "ir/train/Operations.lst" +#undef OP +}; + +class UseDefGenerator : public UseDefGeneratorBase +{ +public: + UseDefGenerator(void) = delete; + UseDefGenerator(const TrainableGraph &tgraph); + +public: + UseDefChains operator()(); + +public: + void visit(const train::operation::Loss &node) override; + +private: + void insertUse(const TrainingOperandIndex &operand_index, const TrainingOperationIndex &op_index); + void insertDef(const TrainingOperandIndex &operand_index, const TrainingOperationIndex &op_index); + void insertBackPropDef(const TrainingOperandIndex &operand_index, + const TrainingOperationIndex &op_index); + void initForForwardingNodes(); + void initForBackwardingNodes(); + +private: + const TrainableGraph &_tgraph; + std::unordered_map<const ITrainableOperation *, OperationIndex> _node_to_idx; + UseDefChains _training_usedefs; +}; + +} // namespace train +} // namespace ir +} // namespace onert + +#endif // __ONERT_IR_TRAIN_USEDEFINITIALIZER_H__ diff --git a/runtime/onert/core/src/ir/train/operation/BinaryArithmetic.cc b/runtime/onert/core/src/ir/train/operation/BinaryArithmetic.cc new file mode 100644 index 000000000..473d38735 --- /dev/null +++ 
b/runtime/onert/core/src/ir/train/operation/BinaryArithmetic.cc @@ -0,0 +1,49 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ir/train/operation/BinaryArithmetic.h" + +#include "ir/OperationVisitor.h" +#include "ir/train/TrainableOperationVisitor.h" + +namespace onert +{ +namespace ir +{ +namespace train +{ +namespace operation +{ + +std::unique_ptr<ITrainableOperation> BinaryArithmetic::clone() const +{ + return std::make_unique<BinaryArithmetic>(*this); +} + +void BinaryArithmetic::accept(OperationVisitor &v) const { v.visit(*this); } + +void BinaryArithmetic::accept(TrainableOperationVisitor &v) const { v.visit(*this); } + +BinaryArithmetic::BinaryArithmetic(const OperationType &operation) + : OperationType{operation.getInputs(), operation.getOutputs(), operation.param()} +{ + // DO NOTHING +} + +} // namespace operation +} // namespace train +} // namespace ir +} // namespace onert diff --git a/runtime/onert/core/src/ir/train/operation/Conv2D.cc b/runtime/onert/core/src/ir/train/operation/Conv2D.cc new file mode 100644 index 000000000..923861ae3 --- /dev/null +++ b/runtime/onert/core/src/ir/train/operation/Conv2D.cc @@ -0,0 +1,49 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ir/train/operation/Conv2D.h" + +#include "ir/OperationVisitor.h" +#include "ir/train/TrainableOperationVisitor.h" + +namespace onert +{ +namespace ir +{ +namespace train +{ +namespace operation +{ + +std::unique_ptr<ITrainableOperation> Conv2D::clone() const +{ + return std::make_unique<Conv2D>(*this); +} + +void Conv2D::accept(OperationVisitor &v) const { v.visit(*this); } + +void Conv2D::accept(TrainableOperationVisitor &v) const { v.visit(*this); } + +Conv2D::Conv2D(const OperationType &operation) + : OperationType{operation.getInputs(), operation.getOutputs(), operation.param()} +{ + // DO NOTHING +} + +} // namespace operation +} // namespace train +} // namespace ir +} // namespace onert diff --git a/runtime/onert/core/src/ir/train/operation/DepthwiseConv2D.cc b/runtime/onert/core/src/ir/train/operation/DepthwiseConv2D.cc new file mode 100644 index 000000000..2a7289619 --- /dev/null +++ b/runtime/onert/core/src/ir/train/operation/DepthwiseConv2D.cc @@ -0,0 +1,49 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ir/train/operation/DepthwiseConv2D.h" + +#include "ir/OperationVisitor.h" +#include "ir/train/TrainableOperationVisitor.h" + +namespace onert +{ +namespace ir +{ +namespace train +{ +namespace operation +{ + +std::unique_ptr<ITrainableOperation> DepthwiseConv2D::clone() const +{ + return std::make_unique<DepthwiseConv2D>(*this); +} + +void DepthwiseConv2D::accept(OperationVisitor &v) const { v.visit(*this); } + +void DepthwiseConv2D::accept(TrainableOperationVisitor &v) const { v.visit(*this); } + +DepthwiseConv2D::DepthwiseConv2D(const OperationType &operation) + : OperationType{operation.getInputs(), operation.getOutputs(), operation.param()} +{ + // DO NOTHING +} + +} // namespace operation +} // namespace train +} // namespace ir +} // namespace onert diff --git a/runtime/onert/core/src/ir/train/operation/ElementwiseActivation.cc b/runtime/onert/core/src/ir/train/operation/ElementwiseActivation.cc new file mode 100644 index 000000000..1dae3f674 --- /dev/null +++ b/runtime/onert/core/src/ir/train/operation/ElementwiseActivation.cc @@ -0,0 +1,49 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "ir/train/operation/ElementwiseActivation.h" + +#include "ir/OperationVisitor.h" +#include "ir/train/TrainableOperationVisitor.h" + +namespace onert +{ +namespace ir +{ +namespace train +{ +namespace operation +{ + +std::unique_ptr<ITrainableOperation> ElementwiseActivation::clone() const +{ + return std::make_unique<ElementwiseActivation>(*this); +} + +void ElementwiseActivation::accept(OperationVisitor &v) const { v.visit(*this); } + +void ElementwiseActivation::accept(TrainableOperationVisitor &v) const { v.visit(*this); } + +ElementwiseActivation::ElementwiseActivation(const OperationType &operation) + : OperationType{operation.getInputs(), operation.getOutputs(), operation.param()} +{ + // DO NOTHING +} + +} // namespace operation +} // namespace train +} // namespace ir +} // namespace onert diff --git a/runtime/onert/core/src/ir/train/operation/FullyConnected.cc b/runtime/onert/core/src/ir/train/operation/FullyConnected.cc new file mode 100644 index 000000000..a26f7c489 --- /dev/null +++ b/runtime/onert/core/src/ir/train/operation/FullyConnected.cc @@ -0,0 +1,49 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "ir/train/operation/FullyConnected.h" + +#include "ir/OperationVisitor.h" +#include "ir/train/TrainableOperationVisitor.h" + +namespace onert +{ +namespace ir +{ +namespace train +{ +namespace operation +{ + +std::unique_ptr<ITrainableOperation> FullyConnected::clone() const +{ + return std::make_unique<FullyConnected>(*this); +} + +void FullyConnected::accept(OperationVisitor &v) const { v.visit(*this); } + +void FullyConnected::accept(TrainableOperationVisitor &v) const { v.visit(*this); } + +FullyConnected::FullyConnected(const OperationType &operation) + : OperationType{operation.getInputs(), operation.getOutputs(), operation.param()} +{ + // DO NOTHING +} + +} // namespace operation +} // namespace train +} // namespace ir +} // namespace onert diff --git a/runtime/onert/core/src/ir/train/operation/Loss.cc b/runtime/onert/core/src/ir/train/operation/Loss.cc new file mode 100644 index 000000000..3a89e0ff6 --- /dev/null +++ b/runtime/onert/core/src/ir/train/operation/Loss.cc @@ -0,0 +1,48 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "ir/train/operation/Loss.h" + +#include "ir/OperationVisitor.h" +#include "ir/train/TrainableOperationVisitor.h" + +#include <misc/polymorphic_downcast.h> + +namespace onert +{ +namespace ir +{ +namespace train +{ +namespace operation +{ + +std::unique_ptr<ITrainableOperation> Loss::clone() const { return std::make_unique<Loss>(*this); } + +void Loss::accept(OperationVisitor &v) const { v.visit(*this); } + +void Loss::accept(TrainableOperationVisitor &v) const { v.visit(*this); } + +Loss::Loss(const OperationType &operation, const LossInfo ¶m) + : OperationType{operation.getInputs(), operation.getOutputs()}, _param{param} +{ + // DO NOTHING +} + +} // namespace operation +} // namespace train +} // namespace ir +} // namespace onert diff --git a/runtime/onert/core/src/ir/train/operation/Pad.cc b/runtime/onert/core/src/ir/train/operation/Pad.cc new file mode 100644 index 000000000..56394f5ef --- /dev/null +++ b/runtime/onert/core/src/ir/train/operation/Pad.cc @@ -0,0 +1,46 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "ir/train/operation/Pad.h" + +#include "ir/OperationVisitor.h" +#include "ir/train/TrainableOperationVisitor.h" + +namespace onert +{ +namespace ir +{ +namespace train +{ +namespace operation +{ + +std::unique_ptr<ITrainableOperation> Pad::clone() const { return std::make_unique<Pad>(*this); } + +void Pad::accept(OperationVisitor &v) const { v.visit(*this); } + +void Pad::accept(TrainableOperationVisitor &v) const { v.visit(*this); } + +Pad::Pad(const OperationType &operation) + : OperationType{operation.getInputs(), operation.getOutputs()} +{ + // DO NOTHING +} + +} // namespace operation +} // namespace train +} // namespace ir +} // namespace onert diff --git a/runtime/onert/core/src/ir/train/operation/Permute.cc b/runtime/onert/core/src/ir/train/operation/Permute.cc new file mode 100644 index 000000000..adc23aa49 --- /dev/null +++ b/runtime/onert/core/src/ir/train/operation/Permute.cc @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "ir/train/operation/Permute.h" + +#include "ir/OperationVisitor.h" +#include "ir/train/TrainableOperationVisitor.h" + +namespace onert +{ +namespace ir +{ +namespace train +{ +namespace operation +{ + +std::unique_ptr<ITrainableOperation> Permute::clone() const +{ + return std::make_unique<Permute>(*this); +} + +void Permute::accept(OperationVisitor &v) const { v.visit(*this); } + +void Permute::accept(TrainableOperationVisitor &v) const { v.visit(*this); } + +Permute::Permute(const OperationType &operation) + : OperationType{operation.getInputs().at(0), operation.getOutputs().at(0), + operation.getPermuteType()} +{ + // DO NOTHING +} + +} // namespace operation +} // namespace train +} // namespace ir +} // namespace onert diff --git a/runtime/onert/core/src/ir/train/operation/Pool2D.cc b/runtime/onert/core/src/ir/train/operation/Pool2D.cc new file mode 100644 index 000000000..021574f19 --- /dev/null +++ b/runtime/onert/core/src/ir/train/operation/Pool2D.cc @@ -0,0 +1,49 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "ir/train/operation/Pool2D.h" + +#include "ir/OperationVisitor.h" +#include "ir/train/TrainableOperationVisitor.h" + +namespace onert +{ +namespace ir +{ +namespace train +{ +namespace operation +{ + +std::unique_ptr<ITrainableOperation> Pool2D::clone() const +{ + return std::make_unique<Pool2D>(*this); +} + +void Pool2D::accept(OperationVisitor &v) const { v.visit(*this); } + +void Pool2D::accept(TrainableOperationVisitor &v) const { v.visit(*this); } + +Pool2D::Pool2D(const OperationType &operation) + : OperationType{operation.getInputs(), operation.getOutputs(), operation.param()} +{ + // DO NOTHING +} + +} // namespace operation +} // namespace train +} // namespace ir +} // namespace onert diff --git a/runtime/onert/core/src/ir/train/operation/Reduce.cc b/runtime/onert/core/src/ir/train/operation/Reduce.cc new file mode 100644 index 000000000..51986a0c2 --- /dev/null +++ b/runtime/onert/core/src/ir/train/operation/Reduce.cc @@ -0,0 +1,49 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "ir/train/operation/Reduce.h" + +#include "ir/OperationVisitor.h" +#include "ir/train/TrainableOperationVisitor.h" + +namespace onert +{ +namespace ir +{ +namespace train +{ +namespace operation +{ + +std::unique_ptr<ITrainableOperation> Reduce::clone() const +{ + return std::make_unique<Reduce>(*this); +} + +void Reduce::accept(OperationVisitor &v) const { v.visit(*this); } + +void Reduce::accept(TrainableOperationVisitor &v) const { v.visit(*this); } + +Reduce::Reduce(const OperationType &operation) + : OperationType{operation.getInputs(), operation.getOutputs(), operation.param()} +{ + // DO NOTHING +} + +} // namespace operation +} // namespace train +} // namespace ir +} // namespace onert diff --git a/runtime/onert/core/src/ir/train/operation/Reshape.cc b/runtime/onert/core/src/ir/train/operation/Reshape.cc new file mode 100644 index 000000000..c76158607 --- /dev/null +++ b/runtime/onert/core/src/ir/train/operation/Reshape.cc @@ -0,0 +1,49 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "ir/train/operation/Reshape.h" + +#include "ir/OperationVisitor.h" +#include "ir/train/TrainableOperationVisitor.h" + +namespace onert +{ +namespace ir +{ +namespace train +{ +namespace operation +{ + +std::unique_ptr<ITrainableOperation> Reshape::clone() const +{ + return std::make_unique<Reshape>(*this); +} + +void Reshape::accept(OperationVisitor &v) const { v.visit(*this); } + +void Reshape::accept(TrainableOperationVisitor &v) const { v.visit(*this); } + +Reshape::Reshape(const OperationType &operation) + : OperationType{operation.getInputs(), operation.getOutputs(), operation.param()} +{ + // DO NOTHING +} + +} // namespace operation +} // namespace train +} // namespace ir +} // namespace onert diff --git a/runtime/onert/core/src/ir/train/operation/Softmax.cc b/runtime/onert/core/src/ir/train/operation/Softmax.cc new file mode 100644 index 000000000..dbd403879 --- /dev/null +++ b/runtime/onert/core/src/ir/train/operation/Softmax.cc @@ -0,0 +1,49 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "ir/train/operation/Softmax.h" + +#include "ir/OperationVisitor.h" +#include "ir/train/TrainableOperationVisitor.h" + +namespace onert +{ +namespace ir +{ +namespace train +{ +namespace operation +{ + +std::unique_ptr<ITrainableOperation> Softmax::clone() const +{ + return std::make_unique<Softmax>(*this); +} + +void Softmax::accept(OperationVisitor &v) const { v.visit(*this); } + +void Softmax::accept(TrainableOperationVisitor &v) const { v.visit(*this); } + +Softmax::Softmax(const OperationType &operation) + : OperationType{operation.getInputs(), operation.getOutputs(), operation.param()} +{ + // DO NOTHING +} + +} // namespace operation +} // namespace train +} // namespace ir +} // namespace onert diff --git a/runtime/onert/core/src/ir/train/operation/UntrainableOperation.test.cc b/runtime/onert/core/src/ir/train/operation/UntrainableOperation.test.cc new file mode 100644 index 000000000..e3472ec51 --- /dev/null +++ b/runtime/onert/core/src/ir/train/operation/UntrainableOperation.test.cc @@ -0,0 +1,1239 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "ir/train/operation/UntrainableOperation.h" + +#include "ir/Operations.Include.h" + +#include <gtest/gtest.h> + +using namespace ::onert::ir; + +operation::AddN generateAddN() +{ + return operation::AddN{OperandIndexSequence{1, 2}, OperandIndexSequence{0}}; +} + +operation::ArgMinMax generateArgMinMax() +{ + operation::ArgMinMax::Param param; + param.output_type = DataType::FLOAT32; + + return operation::ArgMinMax{OperandIndexSequence{1, 2}, OperandIndexSequence{0}, param}; +} + +operation::BatchMatMul generateBatchMatMul() +{ + operation::BatchMatMul::Param param; + param.adj_x = true; + param.adj_y = true; + + return operation::BatchMatMul{OperandIndexSequence{1, 2}, OperandIndexSequence{0}, param}; +} + +operation::BatchToSpaceND generateBatchToSpaceND() +{ + return operation::BatchToSpaceND{OperandIndexSequence{1, 2}, OperandIndexSequence{0}}; +} + +operation::BCQFullyConnected generateBCQFullyConnected() +{ + operation::BCQFullyConnected::Param param; + param.activation = Activation::NONE; + param.weights_hidden_size = 1; + + return operation::BCQFullyConnected{OperandIndexSequence{1, 2, 3, 4, 5}, OperandIndexSequence{0}, + param}; +} + +operation::BCQGather generateBCQGather() +{ + operation::BCQGather::Param param; + param.axis = 0; + param.input_hidden_size = 1; + + return operation::BCQGather{OperandIndexSequence{1, 2, 3, 4}, OperandIndexSequence{0}, param}; +} + +operation::BinaryArithmetic generateBinaryArithmetic() +{ + operation::BinaryArithmetic::Param param; + param.activation = Activation::NONE; + param.arithmetic_type = operation::BinaryArithmetic::ArithmeticType::ADD; + + return operation::BinaryArithmetic{OperandIndexSequence{1, 2}, OperandIndexSequence{0}, param}; +} + +operation::BroadcastTo generateBroadcastTo() +{ + return operation::BroadcastTo{OperandIndexSequence{1, 2}, OperandIndexSequence{0}}; +} + +operation::Bulk generateBulk() +{ + operation::Bulk::Param param; + param.binary_path = ""; + param.origin_input_shapes = 
std::vector<onert::ir::Shape>{}; + param.origin_output_shapes = std::vector<onert::ir::Shape>{}; + + return operation::Bulk{OperandIndexSequence{1, 2}, OperandIndexSequence{0}, param}; +} + +operation::Comparison generateComparison() +{ + operation::Comparison::Param param; + param.comparison_type = operation::Comparison::ComparisonType::Equal; + + return operation::Comparison{OperandIndexSequence{1, 2}, OperandIndexSequence{0}, param}; +} + +operation::Concat generateConcat() +{ + operation::Concat::Param param; + param.axis = 0; + + return operation::Concat{OperandIndexSequence{1, 2, 3}, OperandIndexSequence{0}, param}; +} + +operation::Conv2D generateConv2D() +{ + operation::Conv2D::Param param; + param.activation = Activation::NONE; + param.dilation = Dilation{}; + param.padding = Padding{}; + param.stride = Stride{}; + + return operation::Conv2D{OperandIndexSequence{1, 2, 3}, OperandIndexSequence{0}, param}; +} + +operation::ConvertFp16ToFp32 generateConvertFp16ToFp32() +{ + return operation::ConvertFp16ToFp32{OperandIndexSequence{1}, OperandIndexSequence{0}}; +} + +operation::ConvertFp32ToFp16 generateConvertFp32ToFp16() +{ + return operation::ConvertFp32ToFp16{OperandIndexSequence{1}, OperandIndexSequence{0}}; +} + +operation::Custom generateCustom() +{ + return operation::Custom{OperandConstraint::createExact(1u), OperandIndexSequence{1}, + OperandIndexSequence{0}, std::string("id"), + operation::Custom::Userdata{}}; +} + +operation::DepthToSpace generateDepthToSpace() +{ + operation::DepthToSpace::Param param; + param.block_size = 1; + + return operation::DepthToSpace{OperandIndexSequence{1}, OperandIndexSequence{0}, param}; +} + +operation::DepthwiseConv2D generateDepthwiseConv2D() +{ + operation::DepthwiseConv2D::Param param; + param.activation = Activation::NONE; + param.dilation = Dilation{}; + param.multiplier = 1u; + param.padding = Padding{}; + param.stride = Stride{}; + + return operation::DepthwiseConv2D{OperandIndexSequence{1, 2, 3}, 
OperandIndexSequence{0}, param}; +} + +operation::DetectionPostProcess generateDetectionPostProcess() +{ + operation::DetectionPostProcess::Param param; + + return operation::DetectionPostProcess{OperandIndexSequence{1, 2, 3}, OperandIndexSequence{0}, + param}; +} + +operation::Einsum generateEinsum() +{ + operation::Einsum::Param param; + param.equation = ""; + + return operation::Einsum{OperandIndexSequence{1}, OperandIndexSequence{0}, param}; +} + +operation::ElementwiseActivation generateElementwiseActivation() +{ + operation::ElementwiseActivation::Param param; + param.alpha = 0.f; + param.beta = 0.f; + param.op_type = operation::ElementwiseActivation::Type::ELU; + + return operation::ElementwiseActivation{OperandIndexSequence{1}, OperandIndexSequence{0}, param}; +} + +operation::ElementwiseBinary generateElementwiseBinary() +{ + operation::ElementwiseBinary::Param param; + param.op_type = operation::ElementwiseBinary::ElementwiseBinaryType::FLOOR_DIV; + + return operation::ElementwiseBinary{OperandIndexSequence{1, 2}, OperandIndexSequence{0}, param}; +} + +operation::ElementwiseUnary generateElementwiseUnary() +{ + operation::ElementwiseUnary::Param param; + param.op_type = operation::ElementwiseUnary::Type::ABS; + + return operation::ElementwiseUnary{OperandIndexSequence{1}, OperandIndexSequence{0}, param}; +} + +operation::EmbeddingLookup generateEmbeddingLookup() +{ + return operation::EmbeddingLookup{OperandIndexSequence{1, 2}, OperandIndexSequence{0}}; +} + +operation::ExpandDims generateExpandDims() +{ + return operation::ExpandDims{OperandIndexSequence{1, 2}, OperandIndexSequence{0}}; +} + +operation::Fill generateFill() +{ + return operation::Fill{OperandIndexSequence{1, 2}, OperandIndexSequence{0}}; +} + +operation::FullyConnected generateFullyConnected() +{ + operation::FullyConnected::Param param; + param.activation = Activation::NONE; + param.weights_format = FullyConnectedWeightsFormat::Default; + + return 
operation::FullyConnected{OperandIndexSequence{1, 2, 3}, OperandIndexSequence{0}, param}; +} + +operation::FusedBatchNorm generateFusedBatchNorm() +{ + operation::FusedBatchNorm::Param param; + param.is_training = false; + param.epsilon = 0.f; + param.data_format = ""; + + return operation::FusedBatchNorm{OperandIndexSequence{1, 2, 3, 4, 5}, OperandIndexSequence{0}, + param}; +} + +operation::Gather generateGather() +{ + operation::Gather::Param param; + param.axis = 0; + + return operation::Gather{OperandIndexSequence{1, 2}, OperandIndexSequence{0}, param}; +} + +operation::HashtableLookup generateHashtableLookup() +{ + return operation::HashtableLookup{OperandIndexSequence{2, 3, 4}, OperandIndexSequence{0, 1}}; +} + +operation::If generateIf() +{ + operation::If::Param param; + param.else_subg_index = 1; + param.then_subg_index = 2; + + return operation::If{OperandIndexSequence{1, 2, 3}, OperandIndexSequence{0}, param}; +} + +operation::InstanceNorm generateInstanceNorm() +{ + operation::InstanceNorm::Param param; + param.activation = Activation::NONE; + param.epsilon = 0.f; + + return operation::InstanceNorm{OperandIndexSequence{1, 2, 3}, OperandIndexSequence{0}, param}; +} + +operation::L2Normalization generateL2Normalization() +{ + return operation::L2Normalization{OperandIndexSequence{1}, OperandIndexSequence{0}}; +} + +operation::LocalResponseNormalization generateLocalResponseNormalization() +{ + operation::LocalResponseNormalization::Param param; + param.alpha = 0.f; + param.beta = 0.f; + param.bias = 0.f; + param.radius = 1; + + return operation::LocalResponseNormalization{OperandIndexSequence{1}, OperandIndexSequence{0}, + param}; +} + +operation::LogSoftmax generateLogSoftmax() +{ + operation::LogSoftmax::Param param; + param.axis = 0; + param.beta = 0.f; + + return operation::LogSoftmax{OperandIndexSequence{1}, OperandIndexSequence{0}, param}; +} + +operation::LSTM generateLSTM() +{ + operation::LSTM::Param param; + param.activation = Activation::NONE; 
+ param.cell_threshold = 1.f; + param.projection_threshold = 1.f; + param.time_major = true; + + return operation::LSTM{ + OperandIndexSequence{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20}, + OperandIndexSequence{0}, param}; +} + +operation::MatrixBandPart generateMatrixBandPart() +{ + return operation::MatrixBandPart{OperandIndexSequence{1, 2, 3}, OperandIndexSequence{0}}; +} + +operation::OneHot generateOneHot() +{ + operation::OneHot::Param param; + param.axis = 0; + + return operation::OneHot{OperandIndexSequence{1, 2, 3, 4}, OperandIndexSequence{0}, param}; +} + +operation::Pack generatePack() +{ + operation::Pack::Param param; + param.axis = 0; + param.num = 1; + + return operation::Pack{OperandIndexSequence{1}, OperandIndexSequence{0}, param}; +} + +operation::Pad generatePad() +{ + return operation::Pad{OperandIndexSequence{1, 2, 3}, OperandIndexSequence{0}}; +} + +operation::Permute generatePermute() +{ + return operation::Permute{OperandIndex{1}, OperandIndex{0}, operation::Permute::Type::COPY}; +} + +operation::Pool2D generatePool2D() +{ + operation::Pool2D::Param param; + param.activation = Activation::NONE; + param.kh = 1; + param.kw = 1; + param.op_type = operation::Pool2D::PoolType::AVG; + param.padding = Padding{}; + param.stride = Stride{}; + + return operation::Pool2D{OperandIndexSequence{1}, OperandIndexSequence{0}, param}; +} + +operation::Pow generatePow() +{ + return operation::Pow{OperandIndexSequence{1, 2}, OperandIndexSequence{0}}; +} + +operation::PReLU generatePReLU() +{ + return operation::PReLU{OperandIndexSequence{1, 2}, OperandIndexSequence{0}}; +} + +operation::Range generateRange() +{ + return operation::Range{OperandIndexSequence{1, 2, 3}, OperandIndexSequence{0}}; +} + +operation::Rank generateRank() +{ + return operation::Rank{OperandIndexSequence{1}, OperandIndexSequence{0}}; +} + +operation::Reduce generateReduce() +{ + operation::Reduce::Param param; + param.keep_dims = true; + param.reduce_type = 
operation::Reduce::ReduceType::ALL; + + return operation::Reduce{OperandIndexSequence{1, 2}, OperandIndexSequence{0}, param}; +} + +operation::Reshape generateReshape() +{ + operation::Reshape::Param param; + param.new_shape = std::vector<int32_t>{1}; + + return operation::Reshape{OperandIndexSequence{1, 2}, OperandIndexSequence{0}, param}; +} + +operation::ResizeBilinear generateResizeBilinear() +{ + operation::ResizeBilinear::Param param; + param.align_corners = true; + param.half_pixel_centers = true; + param.height_out = 1; + param.width_out = 1; + + return operation::ResizeBilinear{OperandIndexSequence{1, 2}, OperandIndexSequence{0}, param}; +} + +operation::ResizeNearestNeighbor generateResizeNearestNeighbor() +{ + operation::ResizeNearestNeighbor::Param param; + param.align_corners = true; + param.height_out = 1; + param.width_out = 1; + + return operation::ResizeNearestNeighbor{OperandIndexSequence{1, 2}, OperandIndexSequence{0}, + param}; +} + +operation::Reverse generateReverse() +{ + return operation::Reverse{OperandIndexSequence{1, 2}, OperandIndexSequence{0}}; +} + +operation::RNN generateRNN() +{ + operation::RNN::Param param; + param.activation = Activation::NONE; + + return operation::RNN{OperandIndexSequence{1, 2, 3, 4, 5}, OperandIndexSequence{0}, param}; +} + +operation::Select generateSelect() +{ + return operation::Select{OperandIndexSequence{1, 2, 3}, OperandIndexSequence{0}}; +} + +operation::Shape generateShape() +{ + return operation::Shape{OperandIndexSequence{1}, OperandIndexSequence{0}}; +} + +operation::Slice generateSlice() +{ + return operation::Slice{OperandIndexSequence{1, 2, 3}, OperandIndexSequence{0}}; +} + +operation::Softmax generateSoftmax() +{ + operation::Softmax::Param param; + param.beta = 0.1f; + + return operation::Softmax{OperandIndexSequence{1}, OperandIndexSequence{0}, param}; +} + +operation::SpaceToBatchND generateSpaceToBatchND() +{ + return operation::SpaceToBatchND{OperandIndexSequence{1, 2, 3}, 
OperandIndexSequence{0}}; +} + +operation::SpaceToDepth generateSpaceToDepth() +{ + operation::SpaceToDepth::Param param; + param.block_size = 1; + + return operation::SpaceToDepth{OperandIndexSequence{1}, OperandIndexSequence{0}, param}; +} + +operation::Split generateSplit() +{ + operation::Split::Param param; + param.num_splits = 1; + + return operation::Split{OperandIndexSequence{1, 2}, OperandIndexSequence{0}, param}; +} + +operation::SplitV generateSplitV() +{ + operation::SplitV::Param param; + param.num_splits = 1; + + return operation::SplitV{OperandIndexSequence{1, 2, 3}, OperandIndexSequence{0}, param}; +} + +operation::SquaredDifference generateSquaredDifference() +{ + return operation::SquaredDifference{OperandIndexSequence{1, 2}, OperandIndexSequence{0}}; +} + +operation::Squeeze generateSqueeze() +{ + operation::Squeeze::Param param; + param.dims[0] = 1; + param.ndim = 1; + + return operation::Squeeze{OperandIndexSequence{1}, OperandIndexSequence{0}, param}; +} + +operation::StatelessRandomUniform generateStatelessRandomUniform() +{ + return operation::StatelessRandomUniform{OperandIndexSequence{1, 2}, OperandIndexSequence{0}}; +} + +operation::StridedSlice generateStridedSlice() +{ + operation::StridedSlice::Param param; + param.begin_mask = 1; + param.end_mask = 1; + param.shrink_axis_mask = 1; + + return operation::StridedSlice{OperandIndexSequence{1, 2, 3, 4}, OperandIndexSequence{0}, param}; +} + +operation::Tile generateTile() +{ + return operation::Tile{OperandIndexSequence{1, 2}, OperandIndexSequence{0}}; +} + +operation::TopKV2 generateTopKV2() +{ + operation::TopKV2::Param param; + param.k = 1; + + return operation::TopKV2{OperandIndexSequence{1}, OperandIndexSequence{0}, param}; +} + +operation::Transpose generateTranspose() +{ + return operation::Transpose{OperandIndexSequence{1, 2}, OperandIndexSequence{0}}; +} + +operation::TransposeConv generateTransposeConv() +{ + operation::TransposeConv::Param param; + param.padding = Padding(); + 
param.stride = Stride(); + + return operation::TransposeConv{OperandIndexSequence{1, 2, 3}, OperandIndexSequence{0}, param}; +} + +operation::Unpack generateUnpack() +{ + operation::Unpack::Param param; + param.axis = 0; + param.num = 1; + + return operation::Unpack{OperandIndexSequence{1}, OperandIndexSequence{0}, param}; +} + +operation::While generateWhile() +{ + operation::While::Param param; + param.cond_subg_index = 1; + param.body_subg_index = 2; + + return operation::While{OperandIndexSequence{1, 2, 3}, OperandIndexSequence{0}, param}; +} + +class MockOperationVisitor : public OperationVisitor +{ +public: + void invoke(Operation &op) { op.accept(*this); } + +#define OP(InternalName) \ + virtual void visit(const operation::InternalName &) override { visit_flag = true; } +#include "ir/Operations.lst" +#undef OP + +public: + // TODO Replace this flag with using GMOCK if necessary + bool visit_flag{false}; +}; + +template <typename OperationType> auto generateUntrainableOperation(const OperationType &op) +{ + return std::make_unique<train::operation::UntrainableOperation<OperationType>>(op); +} + +template <typename OperationType> void verifyOp(const OperationType &op) +{ + auto untrainable = generateUntrainableOperation(op); + EXPECT_EQ(untrainable->opcode(), op.opcode()); + EXPECT_EQ(untrainable->getInputs(), op.getInputs()); + EXPECT_EQ(untrainable->getOutputs(), op.getOutputs()); + + // Check clone + auto clone = untrainable->clone(); + EXPECT_TRUE(clone != nullptr); + EXPECT_EQ(clone->hasTrainableParameter(), untrainable->hasTrainableParameter()); + EXPECT_EQ(clone->opcode(), untrainable->opcode()); + EXPECT_EQ(clone->getInputs(), untrainable->getInputs()); + EXPECT_EQ(clone->getOutputs(), untrainable->getOutputs()); + + // Check downcast + const auto derived = + dynamic_cast<train::operation::UntrainableOperation<OperationType> *>(clone.get()); + EXPECT_TRUE(derived != nullptr); + EXPECT_EQ(clone->hasTrainableParameter(), 
untrainable->hasTrainableParameter()); + EXPECT_EQ(derived->opcode(), op.opcode()); + EXPECT_EQ(derived->getInputs(), op.getInputs()); + EXPECT_EQ(derived->getOutputs(), op.getOutputs()); + + // Check visitor + MockOperationVisitor visitor; + + visitor.visit_flag = false; + visitor.invoke(*untrainable); + EXPECT_TRUE(visitor.visit_flag); +} + +TEST(UntrainableOperation, testAllOps) +{ + const auto addn = generateAddN(); + verifyOp(addn); + + const auto argminmax = generateArgMinMax(); + verifyOp(argminmax); + + const auto batch_matmul = generateBatchMatMul(); + verifyOp(batch_matmul); + + const auto batch_to_spacend = generateBatchToSpaceND(); + verifyOp(batch_to_spacend); + + const auto bcq_fc = generateBCQFullyConnected(); + verifyOp(bcq_fc); + + const auto bcq_gather = generateBCQGather(); + verifyOp(bcq_gather); + + const auto binary_arithmetic = generateBinaryArithmetic(); + verifyOp(binary_arithmetic); + + const auto broadcast = generateBroadcastTo(); + verifyOp(broadcast); + + const auto bulk = generateBulk(); + verifyOp(bulk); + + const auto comparison = generateComparison(); + verifyOp(comparison); + + const auto concat = generateConcat(); + verifyOp(concat); + + const auto conv2d = generateConv2D(); + verifyOp(conv2d); + + const auto fp16_to_fp32 = generateConvertFp16ToFp32(); + verifyOp(fp16_to_fp32); + + const auto fp32_to_fp16 = generateConvertFp32ToFp16(); + verifyOp(fp32_to_fp16); + + const auto custom = generateCustom(); + verifyOp(custom); + + const auto depth_to_space = generateDepthToSpace(); + verifyOp(depth_to_space); + + const auto depthwise_conv2d = generateDepthwiseConv2D(); + verifyOp(depthwise_conv2d); + + const auto detection = generateDetectionPostProcess(); + verifyOp(detection); + + const auto einsum = generateEinsum(); + verifyOp(einsum); + + const auto activation = generateElementwiseActivation(); + verifyOp(activation); + + const auto binary = generateElementwiseBinary(); + verifyOp(binary); + + const auto unary = 
generateElementwiseUnary(); + verifyOp(unary); + + const auto embed = generateEmbeddingLookup(); + verifyOp(embed); + + const auto expand_dims = generateExpandDims(); + verifyOp(expand_dims); + + const auto fill = generateFill(); + verifyOp(fill); + + const auto fc = generateFullyConnected(); + verifyOp(fc); + + const auto fused_batch_norm = generateFusedBatchNorm(); + verifyOp(fused_batch_norm); + + const auto gather = generateGather(); + verifyOp(gather); + + const auto hashtable = generateHashtableLookup(); + verifyOp(hashtable); + + const auto if_op = generateIf(); + verifyOp(if_op); + + const auto in_norm = generateInstanceNorm(); + verifyOp(in_norm); + + const auto l2_norm = generateL2Normalization(); + verifyOp(l2_norm); + + const auto local_norm = generateLocalResponseNormalization(); + verifyOp(local_norm); + + const auto log_softmax = generateLogSoftmax(); + verifyOp(log_softmax); + + const auto lstm = generateLSTM(); + verifyOp(lstm); + + const auto maxrix_band_part = generateMatrixBandPart(); + verifyOp(maxrix_band_part); + + const auto one_hot = generateOneHot(); + verifyOp(one_hot); + + const auto pack = generatePack(); + verifyOp(pack); + + const auto pad = generatePad(); + verifyOp(pad); + + const auto permute = generatePermute(); + verifyOp(permute); + + const auto pool2d = generatePool2D(); + verifyOp(pool2d); + + const auto pow = generatePow(); + verifyOp(pow); + + const auto prelu = generatePReLU(); + verifyOp(prelu); + + const auto range = generateRange(); + verifyOp(range); + + const auto rank = generateRank(); + verifyOp(rank); + + const auto reduce = generateReduce(); + verifyOp(reduce); + + const auto reshape = generateReshape(); + verifyOp(reshape); + + const auto resize_bilinear = generateResizeBilinear(); + verifyOp(resize_bilinear); + + const auto resize_nearest_neighbor = generateResizeNearestNeighbor(); + verifyOp(resize_nearest_neighbor); + + const auto reverse = generateReverse(); + verifyOp(reverse); + + const auto rnn = 
generateRNN(); + verifyOp(rnn); + + const auto select = generateSelect(); + verifyOp(select); + + const auto shape = generateShape(); + verifyOp(shape); + + const auto slice = generateSlice(); + verifyOp(slice); + + const auto softmax = generateSoftmax(); + verifyOp(softmax); + + const auto space_to_batchnd = generateSpaceToBatchND(); + verifyOp(space_to_batchnd); + + const auto space_to_depth = generateSpaceToDepth(); + verifyOp(space_to_depth); + + const auto split = generateSplit(); + verifyOp(split); + + const auto splitv = generateSplitV(); + verifyOp(splitv); + + const auto squared_diff = generateSquaredDifference(); + verifyOp(squared_diff); + + const auto squeeze = generateSqueeze(); + verifyOp(squeeze); + + const auto stateless_random_uniform = generateStatelessRandomUniform(); + verifyOp(stateless_random_uniform); + + const auto strided_slice = generateStridedSlice(); + verifyOp(strided_slice); + + const auto tile = generateTile(); + verifyOp(tile); + + const auto topkv2 = generateTopKV2(); + verifyOp(topkv2); + + const auto transpose = generateTranspose(); + verifyOp(transpose); + + const auto transpose_conv = generateTransposeConv(); + verifyOp(transpose_conv); + + const auto unpack = generateUnpack(); + verifyOp(unpack); + + const auto while_op = generateWhile(); + verifyOp(while_op); +} + +class MockTrainableOperationVisitor : public train::TrainableOperationVisitor +{ +public: + void invoke(train::ITrainableOperation &op) { op.accept(*this); } + +#define OP(InternalName) \ + virtual void visit(const train::operation::InternalName &) override {} +#include "ir/train/ITrainableOperation.h" +#undef OP +}; + +TEST(UntrainableOperation, neg_TrainableOperationVisitor) +{ + MockTrainableOperationVisitor visitor; + + { + const auto addn = generateAddN(); + auto untrainable = generateUntrainableOperation(addn); + EXPECT_ANY_THROW(visitor.invoke(*untrainable)); + } + + { + auto argminmax = generateArgMinMax(); + auto untrainable = 
generateUntrainableOperation(argminmax); + EXPECT_ANY_THROW(visitor.invoke(*untrainable)); + } + + { + const auto batch_matmul = generateBatchMatMul(); + auto untrainable = generateUntrainableOperation(batch_matmul); + EXPECT_ANY_THROW(visitor.invoke(*untrainable)); + } + + { + const auto batch_to_spacend = generateBatchToSpaceND(); + auto untrainable = generateUntrainableOperation(batch_to_spacend); + EXPECT_ANY_THROW(visitor.invoke(*untrainable)); + } + + { + const auto bcq_fc = generateBCQFullyConnected(); + auto untrainable = generateUntrainableOperation(bcq_fc); + EXPECT_ANY_THROW(visitor.invoke(*untrainable)); + } + + { + const auto bcq_gather = generateBCQGather(); + auto untrainable = generateUntrainableOperation(bcq_gather); + EXPECT_ANY_THROW(visitor.invoke(*untrainable)); + } + + { + const auto binary_arithmetic = generateBinaryArithmetic(); + auto untrainable = generateUntrainableOperation(binary_arithmetic); + EXPECT_ANY_THROW(visitor.invoke(*untrainable)); + } + + { + const auto broadcast = generateBroadcastTo(); + auto untrainable = generateUntrainableOperation(broadcast); + EXPECT_ANY_THROW(visitor.invoke(*untrainable)); + } + + { + const auto bulk = generateBulk(); + auto untrainable = generateUntrainableOperation(bulk); + EXPECT_ANY_THROW(visitor.invoke(*untrainable)); + } + + { + const auto comparison = generateComparison(); + auto untrainable = generateUntrainableOperation(comparison); + EXPECT_ANY_THROW(visitor.invoke(*untrainable)); + } + + { + const auto concat = generateConcat(); + auto untrainable = generateUntrainableOperation(concat); + EXPECT_ANY_THROW(visitor.invoke(*untrainable)); + } + + { + const auto conv2d = generateConv2D(); + auto untrainable = generateUntrainableOperation(conv2d); + EXPECT_ANY_THROW(visitor.invoke(*untrainable)); + } + + { + const auto fp16_to_fp32 = generateConvertFp16ToFp32(); + auto untrainable = generateUntrainableOperation(fp16_to_fp32); + EXPECT_ANY_THROW(visitor.invoke(*untrainable)); + } + + { + const 
auto fp32_to_fp16 = generateConvertFp32ToFp16(); + auto untrainable = generateUntrainableOperation(fp32_to_fp16); + EXPECT_ANY_THROW(visitor.invoke(*untrainable)); + } + + { + const auto custom = generateCustom(); + auto untrainable = generateUntrainableOperation(custom); + EXPECT_ANY_THROW(visitor.invoke(*untrainable)); + } + + { + const auto depth_to_space = generateDepthToSpace(); + auto untrainable = generateUntrainableOperation(depth_to_space); + EXPECT_ANY_THROW(visitor.invoke(*untrainable)); + } + + { + const auto depthwise_conv2d = generateDepthwiseConv2D(); + auto untrainable = generateUntrainableOperation(depthwise_conv2d); + EXPECT_ANY_THROW(visitor.invoke(*untrainable)); + } + + { + const auto detection = generateDetectionPostProcess(); + auto untrainable = generateUntrainableOperation(detection); + EXPECT_ANY_THROW(visitor.invoke(*untrainable)); + } + + { + const auto einsum = generateEinsum(); + auto untrainable = generateUntrainableOperation(einsum); + EXPECT_ANY_THROW(visitor.invoke(*untrainable)); + } + + { + const auto activation = generateElementwiseActivation(); + auto untrainable = generateUntrainableOperation(activation); + EXPECT_ANY_THROW(visitor.invoke(*untrainable)); + } + + { + const auto binary = generateElementwiseBinary(); + auto untrainable = generateUntrainableOperation(binary); + EXPECT_ANY_THROW(visitor.invoke(*untrainable)); + } + + { + const auto unary = generateElementwiseUnary(); + auto untrainable = generateUntrainableOperation(unary); + EXPECT_ANY_THROW(visitor.invoke(*untrainable)); + } + + { + const auto embed = generateEmbeddingLookup(); + auto untrainable = generateUntrainableOperation(embed); + EXPECT_ANY_THROW(visitor.invoke(*untrainable)); + } + + { + const auto expand_dims = generateExpandDims(); + auto untrainable = generateUntrainableOperation(expand_dims); + EXPECT_ANY_THROW(visitor.invoke(*untrainable)); + } + + { + const auto fill = generateFill(); + auto untrainable = generateUntrainableOperation(fill); + 
EXPECT_ANY_THROW(visitor.invoke(*untrainable)); + } + + { + const auto fc = generateFullyConnected(); + auto untrainable = generateUntrainableOperation(fc); + EXPECT_ANY_THROW(visitor.invoke(*untrainable)); + } + + { + const auto fused_batch_norm = generateFusedBatchNorm(); + auto untrainable = generateUntrainableOperation(fused_batch_norm); + EXPECT_ANY_THROW(visitor.invoke(*untrainable)); + } + + { + const auto gather = generateGather(); + auto untrainable = generateUntrainableOperation(gather); + EXPECT_ANY_THROW(visitor.invoke(*untrainable)); + } + + { + const auto hashtable = generateHashtableLookup(); + auto untrainable = generateUntrainableOperation(hashtable); + EXPECT_ANY_THROW(visitor.invoke(*untrainable)); + } + + { + const auto if_op = generateIf(); + auto untrainable = generateUntrainableOperation(if_op); + EXPECT_ANY_THROW(visitor.invoke(*untrainable)); + } + + { + const auto in_norm = generateInstanceNorm(); + auto untrainable = generateUntrainableOperation(in_norm); + EXPECT_ANY_THROW(visitor.invoke(*untrainable)); + } + + { + const auto l2_norm = generateL2Normalization(); + auto untrainable = generateUntrainableOperation(l2_norm); + EXPECT_ANY_THROW(visitor.invoke(*untrainable)); + } + + { + const auto local_norm = generateLocalResponseNormalization(); + auto untrainable = generateUntrainableOperation(local_norm); + EXPECT_ANY_THROW(visitor.invoke(*untrainable)); + } + + { + const auto log_softmax = generateLogSoftmax(); + auto untrainable = generateUntrainableOperation(log_softmax); + EXPECT_ANY_THROW(visitor.invoke(*untrainable)); + } + + { + const auto lstm = generateLSTM(); + auto untrainable = generateUntrainableOperation(lstm); + EXPECT_ANY_THROW(visitor.invoke(*untrainable)); + } + + { + const auto matrix_band_part = generateMatrixBandPart(); + auto untrainable = generateUntrainableOperation(matrix_band_part); + EXPECT_ANY_THROW(visitor.invoke(*untrainable)); + } + + { + const auto one_hot = generateOneHot(); + auto untrainable = 
generateUntrainableOperation(one_hot); + EXPECT_ANY_THROW(visitor.invoke(*untrainable)); + } + + { + const auto pack = generatePack(); + auto untrainable = generateUntrainableOperation(pack); + EXPECT_ANY_THROW(visitor.invoke(*untrainable)); + } + + { + const auto pad = generatePad(); + auto untrainable = generateUntrainableOperation(pad); + EXPECT_ANY_THROW(visitor.invoke(*untrainable)); + } + + { + const auto permute = generatePermute(); + auto untrainable = generateUntrainableOperation(permute); + EXPECT_ANY_THROW(visitor.invoke(*untrainable)); + } + + { + const auto pool2d = generatePool2D(); + auto untrainable = generateUntrainableOperation(pool2d); + EXPECT_ANY_THROW(visitor.invoke(*untrainable)); + } + + { + const auto pow = generatePow(); + auto untrainable = generateUntrainableOperation(pow); + EXPECT_ANY_THROW(visitor.invoke(*untrainable)); + } + + { + const auto prelu = generatePReLU(); + auto untrainable = generateUntrainableOperation(prelu); + EXPECT_ANY_THROW(visitor.invoke(*untrainable)); + } + + { + const auto range = generateRange(); + auto untrainable = generateUntrainableOperation(range); + EXPECT_ANY_THROW(visitor.invoke(*untrainable)); + } + + { + const auto rank = generateRank(); + auto untrainable = generateUntrainableOperation(rank); + EXPECT_ANY_THROW(visitor.invoke(*untrainable)); + } + + { + const auto reduce = generateReduce(); + auto untrainable = generateUntrainableOperation(reduce); + EXPECT_ANY_THROW(visitor.invoke(*untrainable)); + } + + { + const auto reshape = generateReshape(); + auto untrainable = generateUntrainableOperation(reshape); + EXPECT_ANY_THROW(visitor.invoke(*untrainable)); + } + + { + const auto resize_bilinear = generateResizeBilinear(); + auto untrainable = generateUntrainableOperation(resize_bilinear); + EXPECT_ANY_THROW(visitor.invoke(*untrainable)); + } + + { + const auto resize_nearest_neighbor = generateResizeNearestNeighbor(); + auto untrainable = generateUntrainableOperation(resize_nearest_neighbor); + 
EXPECT_ANY_THROW(visitor.invoke(*untrainable)); + } + + { + const auto reverse = generateReverse(); + auto untrainable = generateUntrainableOperation(reverse); + EXPECT_ANY_THROW(visitor.invoke(*untrainable)); + } + + { + const auto rnn = generateRNN(); + auto untrainable = generateUntrainableOperation(rnn); + EXPECT_ANY_THROW(visitor.invoke(*untrainable)); + } + + { + const auto select = generateSelect(); + auto untrainable = generateUntrainableOperation(select); + EXPECT_ANY_THROW(visitor.invoke(*untrainable)); + } + + { + const auto shape = generateShape(); + auto untrainable = generateUntrainableOperation(shape); + EXPECT_ANY_THROW(visitor.invoke(*untrainable)); + } + + { + const auto slice = generateSlice(); + auto untrainable = generateUntrainableOperation(slice); + EXPECT_ANY_THROW(visitor.invoke(*untrainable)); + } + + { + const auto softmax = generateSoftmax(); + auto untrainable = generateUntrainableOperation(softmax); + EXPECT_ANY_THROW(visitor.invoke(*untrainable)); + } + + { + const auto space_to_batchnd = generateSpaceToBatchND(); + auto untrainable = generateUntrainableOperation(space_to_batchnd); + EXPECT_ANY_THROW(visitor.invoke(*untrainable)); + } + + { + const auto space_to_depth = generateSpaceToDepth(); + auto untrainable = generateUntrainableOperation(space_to_depth); + EXPECT_ANY_THROW(visitor.invoke(*untrainable)); + } + + { + const auto split = generateSplit(); + auto untrainable = generateUntrainableOperation(split); + EXPECT_ANY_THROW(visitor.invoke(*untrainable)); + } + + { + const auto splitv = generateSplitV(); + auto untrainable = generateUntrainableOperation(splitv); + EXPECT_ANY_THROW(visitor.invoke(*untrainable)); + } + + { + const auto squared_diff = generateSquaredDifference(); + auto untrainable = generateUntrainableOperation(squared_diff); + EXPECT_ANY_THROW(visitor.invoke(*untrainable)); + } + + { + const auto squeeze = generateSqueeze(); + auto untrainable = generateUntrainableOperation(squeeze); + 
EXPECT_ANY_THROW(visitor.invoke(*untrainable)); + } + + { + const auto stateless_random_uniform = generateStatelessRandomUniform(); + auto untrainable = generateUntrainableOperation(stateless_random_uniform); + EXPECT_ANY_THROW(visitor.invoke(*untrainable)); + } + + { + const auto strided_slice = generateStridedSlice(); + auto untrainable = generateUntrainableOperation(strided_slice); + EXPECT_ANY_THROW(visitor.invoke(*untrainable)); + } + + { + const auto tile = generateTile(); + auto untrainable = generateUntrainableOperation(tile); + EXPECT_ANY_THROW(visitor.invoke(*untrainable)); + } + + { + const auto topkv2 = generateTopKV2(); + auto untrainable = generateUntrainableOperation(topkv2); + EXPECT_ANY_THROW(visitor.invoke(*untrainable)); + } + + { + const auto transpose = generateTranspose(); + auto untrainable = generateUntrainableOperation(transpose); + EXPECT_ANY_THROW(visitor.invoke(*untrainable)); + } + + { + const auto transpose_conv = generateTransposeConv(); + auto untrainable = generateUntrainableOperation(transpose_conv); + EXPECT_ANY_THROW(visitor.invoke(*untrainable)); + } + + { + const auto unpack = generateUnpack(); + auto untrainable = generateUntrainableOperation(unpack); + EXPECT_ANY_THROW(visitor.invoke(*untrainable)); + } + + { + const auto while_op = generateWhile(); + auto untrainable = generateUntrainableOperation(while_op); + EXPECT_ANY_THROW(visitor.invoke(*untrainable)); + } +} diff --git a/runtime/onert/core/src/ir/verifier/Verifier.cc b/runtime/onert/core/src/ir/verifier/Verifier.cc index 09cbdcf2f..bcded0c68 100644 --- a/runtime/onert/core/src/ir/verifier/Verifier.cc +++ b/runtime/onert/core/src/ir/verifier/Verifier.cc @@ -21,6 +21,69 @@ #include "util/logging.h" +namespace +{ + +using namespace onert::ir; + +std::set<train::TrainingOperationIndex> +extractOperations(const train::UseDefChains &training_usedefs) +{ + // Extract TrainingOperations from training_usedefs + std::set<train::TrainingOperationIndex> operations; + for (const 
auto &pair : training_usedefs) + { + const auto &output = pair.first; + const auto &usedefs = pair.second; + const auto &defs = usedefs.getTrainingDefs(); + for (const auto &node_index : defs) + if (node_index.valid() && output.valid()) + operations.insert(node_index); + } + + return operations; +} + +std::unordered_map<train::TrainingOperationIndex, std::vector<train::TrainingOperandIndex>> +extractNodeInputs(const train::UseDefChains &training_usedefs) +{ + // Extract inputs of TrainingOperations from training_usedefs + std::unordered_map<train::TrainingOperationIndex, std::vector<train::TrainingOperandIndex>> + node_inputs; + for (const auto &pair : training_usedefs) + { + const auto &input = pair.first; + const auto &usedefs = pair.second; + const auto &uses = usedefs.getTrainingUses(); + for (const auto &node_index : uses) + if (node_index.valid() && input.valid()) + node_inputs[node_index].emplace_back(input); + } + + return node_inputs; +} + +std::unordered_map<train::TrainingOperationIndex, std::vector<train::TrainingOperandIndex>> +extractNodeOutputs(const train::UseDefChains &training_usedefs) +{ + // Extract outputs of TrainingOperations from training_usedefs + std::unordered_map<train::TrainingOperationIndex, std::vector<train::TrainingOperandIndex>> + node_outputs; + for (const auto &pair : training_usedefs) + { + const auto &output = pair.first; + const auto &usedefs = pair.second; + const auto &defs = usedefs.getTrainingDefs(); + for (const auto &node_index : defs) + if (node_index.valid() && output.valid()) + node_outputs[node_index].emplace_back(output); + } + + return node_outputs; +} + +} // namespace + namespace onert { namespace ir @@ -39,11 +102,11 @@ bool DAGChecker::verify(const Graph &graph) const noexcept OperationIndexMap<bool> visited; operations.iterate( - [&](const OperationIndex &index, const Operation &) { visited[index] = false; }); + [&](const OperationIndex &index, const IOperation &) { visited[index] = false; }); 
OperationIndexMap<bool> on_stack = visited; // Copy from visited - std::function<void(const OperationIndex &index, const Operation &)> dfs_recursive = - [&](const OperationIndex &index, const Operation &node) -> void { + std::function<void(const OperationIndex &index, const IOperation &)> dfs_recursive = + [&](const OperationIndex &index, const IOperation &node) -> void { if (on_stack[index]) cyclic = true; if (visited[index]) @@ -51,7 +114,7 @@ bool DAGChecker::verify(const Graph &graph) const noexcept visited[index] = true; on_stack[index] = true; - for (auto output : node.getOutputs() | Remove::DUPLICATED) + for (auto &&output : node.getOutputs() | Remove::DUPLICATED | Remove::UNDEFINED) { const auto &operand = graph.operands().at(output); for (const auto &use : operand.getUses()) @@ -68,16 +131,56 @@ bool DAGChecker::verify(const Graph &graph) const noexcept return !cyclic; } +// TODO Merge with the above DAGChecker::verify(const Graph &) +bool DAGChecker::verify(const train::UseDefChains &training_usedefs) const noexcept +{ + bool cyclic = false; + const auto operations = extractOperations(training_usedefs); + auto outputs_map = extractNodeOutputs(training_usedefs); + + std::unordered_map<train::TrainingOperationIndex, bool> visited; + for (const auto &node_index : operations) + visited[node_index] = false; + auto on_stack = visited; // Copy from visited + + std::function<void(const train::TrainingOperationIndex &index)> dfs_recursive = + [&](const train::TrainingOperationIndex &index) -> void { + if (on_stack[index]) + cyclic = true; + if (visited[index]) + return; + visited[index] = true; + on_stack[index] = true; + + auto &node_outputs = outputs_map[index]; + for (const auto &output : node_outputs) + { + const auto &uses = training_usedefs.at(output).getTrainingUses(); + for (const auto &use : uses) + { + dfs_recursive(use); + } + } + + on_stack[index] = false; + }; + + for (const auto &node_index : operations) + dfs_recursive(node_index); + + return 
!cyclic; +} + // // EdgeConsistencyVerifier // -bool EdgeConsistencyChecker::verify(const Graph &graph) const noexcept +bool EdgeChecker::verify(const Graph &graph) const noexcept { auto &operations = graph.operations(); uint32_t errors = 0; - operations.iterate([&](const OperationIndex &index, const Operation &node) { - for (auto operand_index : node.getInputs() | ir::Remove::UNDEFINED) + operations.iterate([&](const OperationIndex &index, const IOperation &node) { + for (auto &&operand_index : node.getInputs() | ir::Remove::UNDEFINED) { try { @@ -85,44 +188,117 @@ bool EdgeConsistencyChecker::verify(const Graph &graph) const noexcept bool operand_has_use = operand.getUses().contains(index); if (!operand_has_use) { - VERBOSE(EdgeConsistencyChecker) << "[ERROR] EDGE MISMATCH : Missing USE edge - Operand " - << operand_index << " to Operation " << index - << std::endl; + VERBOSE(EdgeChecker) << "[ERROR] EDGE MISMATCH : Missing USE edge - Operand " + << operand_index << " to Operation " << index << std::endl; errors += 1; } } catch (const std::out_of_range &e) { - VERBOSE(EdgeConsistencyChecker) - << "[ERROR] OPEARAND NOT FOUND : Operation " << index << " has Operand " - << operand_index << ", but the operand object is not present in the graph" << std::endl; + VERBOSE(EdgeChecker) << "[ERROR] OPEARAND NOT FOUND : Operation " << index + << " has Operand " << operand_index + << ", but the operand object is not present in the graph" << std::endl; errors += 1; } } - for (auto operand_index : node.getOutputs()) + for (auto &&operand_index : node.getOutputs() | ir::Remove::UNDEFINED) { try { auto &operand = graph.operands().at(operand_index); if (operand.getDef() != index) { - VERBOSE(EdgeConsistencyChecker) << "[ERROR] EDGE MISMATCH : Missing DEF edge - Operand" - << operand_index << " to Operation " << index - << std::endl; + VERBOSE(EdgeChecker) << "[ERROR] EDGE MISMATCH : Missing DEF edge - Operand" + << operand_index << " to Operation " << index << std::endl; errors 
+= 1; } } catch (const std::out_of_range &e) { - VERBOSE(EdgeConsistencyChecker) - << "[ERROR] OPEARAND NOT FOUND : Operation " << index << " has Operand " - << operand_index << ", but the operand object is not present in the graph" << std::endl; + VERBOSE(EdgeChecker) << "[ERROR] OPEARAND NOT FOUND : Operation " << index + << " has Operand " << operand_index + << ", but the operand object is not present in the graph" << std::endl; errors += 1; } } }); - VERBOSE(EdgeConsistencyChecker) << "Total Number of errors : " << errors << std::endl; + VERBOSE(EdgeChecker) << "Total Number of errors : " << errors << std::endl; + + return errors == 0; +} + +bool InputOutputChecker::verify(const Graph &graph) const noexcept +{ + for (auto &&operand_ind : + (graph.getInputs() + graph.getOutputs()) | Remove::DUPLICATED | Remove::UNDEFINED) + { + if (!graph.operands().exist(operand_ind)) + { + VERBOSE(InputOutputChecker) << "Input or Output tensor " << operand_ind << " does not exist."; + return false; + } + } + return true; +} + +// TODO Merge with the above EdgeChecker::verify(const Graph &) +bool EdgeChecker::verify(const train::UseDefChains &training_usedefs) const noexcept +{ + const auto operations = extractOperations(training_usedefs); + auto inputs_map = extractNodeInputs(training_usedefs); + auto outputs_map = extractNodeOutputs(training_usedefs); + uint32_t errors = 0; + for (const auto &index : operations) + { + const auto &node_inputs = inputs_map[index]; + for (const auto &operand_index : node_inputs) + { + try + { + const auto &uses = training_usedefs.at(operand_index).getTrainingUses(); + bool operand_has_use = (uses.find(index) != uses.end()); + if (!operand_has_use) + { + VERBOSE(EdgeChecker) << "[ERROR] EDGE MISMATCH : Missing USE edge - Operand " + << operand_index << " to Operation " << index << std::endl; + errors += 1; + } + } + catch (const std::out_of_range &e) + { + VERBOSE(EdgeChecker) << "[ERROR] OPEARAND NOT FOUND : Operation " << index + << " has 
Operand " << operand_index + << ", but the operand object is not present in the graph" << std::endl; + errors += 1; + } + } + + const auto &node_outputs = outputs_map[index]; + for (const auto &operand_index : node_outputs) + { + try + { + const auto &defs = training_usedefs.at(operand_index).getTrainingDefs(); + bool operand_has_def = (defs.find(index) != defs.end()); + if (!operand_has_def) + { + VERBOSE(EdgeChecker) << "[ERROR] EDGE MISMATCH : Missing DEF edge - Operand" + << operand_index << " to Operation " << index << std::endl; + errors += 1; + } + } + catch (const std::out_of_range &e) + { + VERBOSE(EdgeChecker) << "[ERROR] OPEARAND NOT FOUND : Operation " << index + << " has Operand " << operand_index + << ", but the operand object is not present in the graph" << std::endl; + errors += 1; + } + } + } + + VERBOSE(EdgeChecker) << "Total Number of errors : " << errors << std::endl; return errors == 0; } diff --git a/runtime/onert/core/src/ir/verifier/Verifier.h b/runtime/onert/core/src/ir/verifier/Verifier.h index 0c7b57b04..9f1dd8e60 100644 --- a/runtime/onert/core/src/ir/verifier/Verifier.h +++ b/runtime/onert/core/src/ir/verifier/Verifier.h @@ -17,6 +17,8 @@ #ifndef __ONERT_GRAPH_VERIFIER_VERIFIER_H__ #define __ONERT_GRAPH_VERIFIER_VERIFIER_H__ +#include "ir/train/UseDefChains.h" + namespace onert { namespace ir @@ -53,9 +55,20 @@ class DAGChecker : public IVerifier { public: bool verify(const Graph &graph) const noexcept override; + bool verify(const train::UseDefChains &training_defuses) const noexcept; }; -class EdgeConsistencyChecker : public IVerifier +class EdgeChecker : public IVerifier +{ +public: + bool verify(const Graph &graph) const noexcept override; + bool verify(const train::UseDefChains &training_defuses) const noexcept; +}; + +/** + * @brief Check model input and output operands are really exist in the graph + */ +class InputOutputChecker : public IVerifier { public: bool verify(const Graph &graph) const noexcept override; diff --git 
a/runtime/onert/core/src/ir/verifier/Verifier.test.cc b/runtime/onert/core/src/ir/verifier/Verifier.test.cc new file mode 100644 index 000000000..1ec71cd55 --- /dev/null +++ b/runtime/onert/core/src/ir/verifier/Verifier.test.cc @@ -0,0 +1,93 @@ +/* + * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "Verifier.h" + +#include "../MockNode.h" + +#include "ir/Graph.h" + +#include <gtest/gtest.h> + +#include <memory> + +using IndexSet = onert::ir::OperandIndexSequence; +using Mock = onert_test::ir::SimpleMock; + +TEST(Verifier, dag_checker) +{ + onert::ir::Graph graph; + + onert::ir::Shape shape{3}; + onert::ir::TypeInfo type{onert::ir::DataType::INT32}; + + auto operand1 = graph.addOperand(shape, type); + auto operand2 = graph.addOperand(shape, type); + + graph.addInput(operand1); + graph.addOutput(operand2); + + graph.addOperation(std::make_unique<Mock>(IndexSet{operand1}, IndexSet{operand2})); + + onert::ir::verifier::DAGChecker verifier; + + ASSERT_TRUE(verifier.verify(graph)); +} + +TEST(Verifier, neg_edge_consistency_checker_1) +{ + onert::ir::Graph graph; + + onert::ir::Shape shape{3}; + onert::ir::TypeInfo type{onert::ir::DataType::INT32}; + + auto operand1 = graph.addOperand(shape, type); + auto operand2 = graph.addOperand(shape, type); + + graph.addInput(operand1); + graph.addOutput(operand2); + + auto mock_op = std::make_unique<Mock>(IndexSet{operand1}, 
IndexSet{operand2}); + auto op_ind = graph.addOperation(std::move(mock_op)); + + graph.operands().at(operand1).removeUse(op_ind); // Manipulate the operand alone + + onert::ir::verifier::EdgeChecker verifier; + ASSERT_FALSE(verifier.verify(graph)); +} + +TEST(Verifier, neg_edge_consistency_checker_2) +{ + onert::ir::Graph graph; + + onert::ir::Shape shape{3}; + onert::ir::TypeInfo type{onert::ir::DataType::INT32}; + + auto operand1 = graph.addOperand(shape, type); + auto operand2 = graph.addOperand(shape, type); + + graph.addInput(operand1); + graph.addOutput(operand2); + + auto mock_op = std::make_unique<Mock>(IndexSet{operand1}, IndexSet{operand2}); + auto mock_op_ptr = mock_op.get(); + auto op_ind = graph.addOperation(std::move(mock_op)); + + mock_op_ptr->setInputs({operand2}); // Manipulate the operation alone + + onert::ir::verifier::EdgeChecker verifier; + ASSERT_FALSE(verifier.verify(graph)); +} diff --git a/runtime/onert/core/src/loader/BaseLoader.h b/runtime/onert/core/src/loader/BaseLoader.h new file mode 100644 index 000000000..c3a50b0d8 --- /dev/null +++ b/runtime/onert/core/src/loader/BaseLoader.h @@ -0,0 +1,1794 @@ +/* + * Copyright 2017 The TensorFlow Authors. All Rights Reserved. + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef __ONERT_LOADER_BASE_LOADER_H__ +#define __ONERT_LOADER_BASE_LOADER_H__ + +#include "ir/Graph.h" +#include "ir/Shape.h" +#include "ir/Operations.Include.h" + +#include "flatbuffers/flexbuffers.h" + +#include <map> +#include <memory> +#include <fstream> +#include <limits> +#include <fcntl.h> +#include <sys/stat.h> +#include <sys/mman.h> +#include <unistd.h> +#include <util/logging.h> + +namespace onert +{ +namespace loader +{ + +template <typename LoaderDomain> class BaseLoader +{ +protected: + using Verifier = typename LoaderDomain::Verifier; + using ActivationFunctionType = typename LoaderDomain::ActivationFunctionType; + using Buffer = typename LoaderDomain::Buffer; + using BuiltinOperator = typename LoaderDomain::BuiltinOperator; + using CustomOptionsFormat = typename LoaderDomain::CustomOptionsFormat; + using Metadata = typename LoaderDomain::Metadata; + using Model = typename LoaderDomain::Model; + using Operator = typename LoaderDomain::Operator; + using Padding = typename LoaderDomain::Padding; + using Pool2DOptions = typename LoaderDomain::Pool2DOptions; + using SubGraph = typename LoaderDomain::SubGraph; + using Tensor = typename LoaderDomain::Tensor; + using TensorType = typename LoaderDomain::TensorType; + using DimensionType = typename LoaderDomain::DimensionType; + using SparseIndexVector = typename LoaderDomain::SparseIndexVector; + +protected: + bool isOptionalInputTensor(std::int32_t idx) { return idx == -1; } + virtual bool allowOptionalInputTensor(BuiltinOperator) = 0; + +public: + /** + * @brief Construct a new Loader object + * + * @param model reference to model + */ + explicit BaseLoader(std::unique_ptr<ir::Model> &model) + : _base{nullptr}, _pagesize(getpagesize()), _fd(-1), _model(model), _domain_model{nullptr} + { + _use_mmaped_data = util::getConfigBool(util::config::USE_MMAPED_DATA); + } + + /** + * @brief Load a model from file + * + * @param file_path + */ + void loadFromFile(const std::string &file_path); + /** + * 
@brief Load a model from a buffer + * + * @param buffer buffer pointer + * @param size buffer size + */ + void loadFromBuffer(uint8_t *buffer, size_t size); + +protected: + ~BaseLoader() = default; + void loadModel(); + + // Helper functions + ir::Activation convertActivation(ActivationFunctionType type); + ir::DataType tensorTypeToDataType(TensorType type); + ir::OperandIndex tensorIdxToOperandIdx(int32_t tensorIdx); + flexbuffers::Map getCustomOpAttrMap(const Operator *op); + + // Create operands form tflite::Tensor + ir::OperandIndex loadOperand(const Tensor *tensor, ir::Graph &subg); + void loadQuantization(const Tensor *tensor, ir::TypeInfo &typeInfo); + void loadSparsity(const Tensor *tensor, ir::TypeInfo &typeInfo); + void loadOperationIO(const Operator *op, ir::OperandIndexSequence &inputs, + ir::OperandIndexSequence &outputs); + // Create operations from Operator + void loadOperation(const Operator *op, ir::Graph &subg); + // Load Strides and Paddings from options to param + template <typename Param, typename OptionsType> + void loadStridesAndPaddings(Param ¶m, const OptionsType *options); + // Load Pool2D param + template <typename Param> void loadPool2DOptions(Param ¶m, const Pool2DOptions *options); + // Get BuiltinOperator + BuiltinOperator getBuiltinOperator(const Operator *op) + { + auto const builtin_opcode = _domain_model->operator_codes()->Get(op->opcode_index()); + auto builtin_op = builtin_opcode->builtin_code(); + if (builtin_op < BuiltinOperator::BuiltinOperator_PLACEHOLDER_FOR_GREATER_OP_CODES) + builtin_op = static_cast<BuiltinOperator>(builtin_opcode->deprecated_builtin_code()); + + return builtin_op; + } + +private: + std::unique_ptr<ir::Data> loadMetadata(const uint32_t buffer_idx); + virtual std::unique_ptr<ir::Graph> loadSubgraph(const SubGraph *subg) = 0; + // Operations + template <typename OpIR, typename... 
Args> + const OpIR *loadOperationTo(const Operator *op, ir::Graph &subg, Args &&...args); + + void loadAddV2(const Operator *op, ir::Graph &subg); + void loadArgMinMax(const Operator *op, ir::Graph &subg, bool is_argmax); + void loadBatchMatMul(const Operator *op, ir::Graph &subg); + void loadBinaryArithmetic(const Operator *op, ir::Graph &subg, + ir::operation::BinaryArithmetic::ArithmeticType op_type); + void loadComparison(const Operator *op, ir::Graph &subg); + void loadConcatenation(const Operator *op, ir::Graph &subg); + void loadConv2D(const Operator *op, ir::Graph &subg); + void loadCustom(const Operator *op, ir::Graph &subg); + void loadDepthToSpace(const Operator *op, ir::Graph &subg); + void loadDepthwiseConv2D(const Operator *op, ir::Graph &subg); + void loadEinsum(const Operator *op, ir::Graph &subg); + void loadElementwiseActivation(const Operator *op, ir::Graph &subg, + ir::operation::ElementwiseActivation::Type op_type, + float alpha = 0.f, float beta = 0.f); + void loadElementwiseBinary(const Operator *op, ir::Graph &subg, + ir::operation::ElementwiseBinary::ElementwiseBinaryType op_type); + void loadElementwiseUnary(const Operator *op, ir::Graph &subg, + ir::operation::ElementwiseUnary::Type op_type); + void loadFC(const Operator *op, ir::Graph &subg); + void loadFusedBatchNorm(const Operator *op, ir::Graph &subg); + void loadGather(const Operator *op, ir::Graph &subg); + void loadIf(const Operator *op, ir::Graph &subg); + void loadLeakyRelu(const Operator *op, ir::Graph &subg); + void loadLogSoftmax(const Operator *op, ir::Graph &subg); + void loadDetectionPostProcess(const Operator *op, ir::Graph &subg); + void loadOneHot(const Operator *op, ir::Graph &subg); + void loadPack(const Operator *op, ir::Graph &subg); + void loadPool2D(const Operator *op, ir::Graph &subg, ir::operation::Pool2D::PoolType op_type); + void loadReduce(const Operator *op, ir::Graph &subg, + ir::operation::Reduce::ReduceType reduce_type); + void loadReduceAll(const Operator 
*op, ir::Graph &subg); + void loadReshape(const Operator *op, ir::Graph &subg); + void loadResizeBilinear(const Operator *op, ir::Graph &subg); + void loadResizeNearestNeighbor(const Operator *op, ir::Graph &subg); + void loadSoftmax(const Operator *op, ir::Graph &subg); + void loadSpaceToDepth(const Operator *op, ir::Graph &subg); + void loadSplit(const Operator *op, ir::Graph &subg); + void loadSplitV(const Operator *op, ir::Graph &subg); + void loadSqueeze(const Operator *op, ir::Graph &subg); + void loadStridedSlice(const Operator *op, ir::Graph &subg); + void loadTransposeConv(const Operator *op, ir::Graph &subg); + void loadUnidirectionalSequenceLSTM(const Operator *op, ir::Graph &subg); + void loadUnpack(const Operator *op, ir::Graph &subg); + void loadWhile(const Operator *op, ir::Graph &subg); + + void verifySubgraphIndex(int subg_index) + { + const auto num_subgraphs = _domain_model->subgraphs()->size(); + if (subg_index < 0 || subg_index >= static_cast<int32_t>(num_subgraphs)) + throw std::runtime_error{std::string{"Invalid subgraph index - "} + + std::to_string(subg_index)}; + } + +protected: + // Base address for mapped region for loading (if needed) + uint8_t *_base; + // Memory page size + int32_t _pagesize; + // loaded file description + int _fd; + // Reference to ir::model (to be loaded from _domain_model) + std::unique_ptr<ir::Model> &_model; + const Model *_domain_model; + // Maps Tensor indices to onert Operands. 
+ std::vector<ir::OperandIndex> _tensor_to_operand; + std::unordered_map<ir::OperandIndex, std::string> _tensor_names; + // Verifier + std::unique_ptr<Verifier> _verifier; + // Boolean flag to use MMAPED_DATA + bool _use_mmaped_data = false; + + std::unordered_map<uint32_t /* Buffer Index in circle file */, std::shared_ptr<ir::Data>> + _buf_to_data; +}; + +template <typename LoaderDomain> +void BaseLoader<LoaderDomain>::BaseLoader::loadFromFile(const std::string &file_path) +{ + _fd = open(file_path.c_str(), O_RDONLY); + if (_fd < 0) + { + throw std::runtime_error("Failed to open file " + file_path); + } + + struct stat file_stat; + if (fstat(_fd, &file_stat) != 0) + { + throw std::runtime_error("Fstat failed or file " + file_path + " is not a regular file"); + } + int size = file_stat.st_size; + + // Map model file into memory region + _base = static_cast<uint8_t *>(mmap(NULL, size, PROT_READ, MAP_PRIVATE, _fd, 0)); + if (_base == MAP_FAILED) + { + close(_fd); + throw std::runtime_error("mmap failed - " + std::string(strerror(errno))); + } + + _verifier = std::make_unique<Verifier>(reinterpret_cast<const std::uint8_t *>(_base), size); + + loadModel(); + munmap(_base, size); + + close(_fd); +} + +template <typename LoaderDomain> +void BaseLoader<LoaderDomain>::BaseLoader::loadFromBuffer(uint8_t *buffer, size_t size) +{ + _base = buffer; + _verifier = std::make_unique<Verifier>(reinterpret_cast<const std::uint8_t *>(_base), size); + loadModel(); +} + +template <typename LoaderDomain> +std::unique_ptr<ir::Data> +BaseLoader<LoaderDomain>::BaseLoader::loadMetadata(const uint32_t buffer_idx) +{ + assert(_domain_model != nullptr); + const auto *data = _domain_model->buffers()->Get(buffer_idx)->data(); + if (data == nullptr) + throw std::runtime_error("Metadata buffer is not found"); + + if (_fd == -1) // Model is from memory + { + return std::make_unique<ir::ExternalData>(data->data(), data->size()); + } + else // Model is loaded(mmap'd) from a file + { + size_t 
data_size = data->size(); + ptrdiff_t offset_start = data->data() - _base; + ptrdiff_t offset_end = offset_start + data_size; + + ptrdiff_t page_start = (offset_start / _pagesize) * _pagesize; + size_t mapping_size = offset_end - page_start; + + // Since metadata is not access often in inference/training time, always use mmaped-data + // Ref : https://github.com/Samsung/ONE/issues/3961#issuecomment-681750231 + return std::make_unique<ir::MMapedData>(_fd, page_start, mapping_size, offset_start, data_size); + } +} + +template <typename LoaderDomain> +ir::Activation +BaseLoader<LoaderDomain>::BaseLoader::convertActivation(const ActivationFunctionType type) +{ + switch (type) + { + case ActivationFunctionType::ActivationFunctionType_NONE: + return ir::Activation::NONE; + case ActivationFunctionType::ActivationFunctionType_RELU: + return ir::Activation::RELU; + case ActivationFunctionType::ActivationFunctionType_RELU_N1_TO_1: + return ir::Activation::RELU1; + case ActivationFunctionType::ActivationFunctionType_RELU6: + return ir::Activation::RELU6; + case ActivationFunctionType::ActivationFunctionType_TANH: + return ir::Activation::TANH; + default: + throw std::runtime_error(std::string("Unsupported or invalid activation type: ") + + std::to_string(static_cast<int>(type))); + } +} + +template <typename LoaderDomain> +ir::DataType BaseLoader<LoaderDomain>::BaseLoader::tensorTypeToDataType(const TensorType type) +{ + switch (type) + { + case TensorType::TensorType_FLOAT32: + return ir::DataType::FLOAT32; + case TensorType::TensorType_FLOAT16: + return ir::DataType::FLOAT16; + case TensorType::TensorType_INT32: + return ir::DataType::INT32; + case TensorType::TensorType_UINT8: + return ir::DataType::QUANT_UINT8_ASYMM; + case TensorType::TensorType_INT64: + return ir::DataType::INT64; + // case TensorType::TensorType_STRING: + case TensorType::TensorType_BOOL: + return ir::DataType::BOOL8; + case TensorType::TensorType_INT16: + return ir::DataType::QUANT_INT16_ASYMM; + // 
case TensorType::TensorType_COMPLEX64 + case TensorType::TensorType_INT8: + return ir::DataType::QUANT_INT8_ASYMM; + // case TensorType::TensorType_FLOAT64 + case TensorType::TensorType_UINT32: + return ir::DataType::UINT32; + default: + throw std::runtime_error( + std::string("Unsupported tensor type: ").append(EnumNameTensorType(type))); + } +} + +template <typename LoaderDomain> +ir::OperandIndex BaseLoader<LoaderDomain>::BaseLoader::tensorIdxToOperandIdx(int32_t tensorIdx) +{ + return isOptionalInputTensor(tensorIdx) ? ir::OperandIndex() : _tensor_to_operand[tensorIdx]; +} + +template <typename LoaderDomain> +flexbuffers::Map BaseLoader<LoaderDomain>::BaseLoader::getCustomOpAttrMap(const Operator *op) +{ + size_t custom_op_data_size = op->custom_options()->size(); + auto custom_op_data = op->custom_options()->Data(); + auto data_root = flexbuffers::GetRoot(custom_op_data, custom_op_data_size); + return data_root.AsMap(); +} + +/* Copy is copied from tensorflow lite */ +template <typename T> bool Copy(const T *data_ptr, std::vector<uint16_t> &arr) +{ + if (data_ptr->values() == nullptr) + { + return false; + } + + int size = data_ptr->values()->size(); + arr.reserve(size); + for (int i = 0; i < size; i++) + { + arr.emplace_back(static_cast<uint16_t>(data_ptr->values()->Get(i))); + } + return true; +} + +template <typename LoaderDomain> +ir::OperandIndex BaseLoader<LoaderDomain>::loadOperand(const Tensor *tensor, ir::Graph &subg) +{ + ir::Shape shape; + // Shape + const auto *tensor_shape = tensor->shape(); + if (tensor_shape != nullptr) + { + for (const auto &dim : *tensor_shape) + { + shape.append(dim); + } + } + + // Note for tensor->shape_signature() + // We don't handle shape signature + // How we handle: + // If shape_signature[k] == -1, we will use tensor->shape()[k] == 1 + // If app wants to change the input shape, call nnfw_apply_input_tensorinfo() can + // be used. 
+ + // TypeInfo + ir::TypeInfo type_info(tensorTypeToDataType(tensor->type())); + loadQuantization(tensor, type_info); + loadSparsity(tensor, type_info); + + // Create operand + const auto operand_index = subg.addOperand(shape, type_info); + + // Constant tensors are indicated by non-empty data. + const auto *data = _domain_model->buffers()->Get(tensor->buffer())->data(); + if (data != nullptr) + { + using std::ptrdiff_t; + std::shared_ptr<ir::Data> data_obj; + + if (_fd == -1) // Model is from memory + { + data_obj = std::make_shared<ir::ExternalData>(data->data(), data->size()); + } + else // Model is loaded(mmap'd) from a file + { + size_t data_size = data->size(); + ptrdiff_t unaligned_offset_start = data->data() - _base; + ptrdiff_t offset_end = unaligned_offset_start + data_size; + + // Calculated aligned offset from base address of mapped region + // munmap accepts memory address which is a multiple of the pagesize + ptrdiff_t aligned_offset_start = (unaligned_offset_start / _pagesize) * _pagesize; + size_t mmap_size = offset_end - aligned_offset_start; + + uint32_t buf_idx = tensor->buffer(); + auto buffer_found = _buf_to_data.find(buf_idx); + + if (buffer_found != _buf_to_data.end()) + { + // Another tensor points this buffer and its matching Data(either CachedData or MMapedData) + // was already created. 
Let's reuse the Data + data_obj = buffer_found->second; + } + else if (_use_mmaped_data) + { + data_obj = std::make_shared<ir::MMapedData>(_fd, aligned_offset_start, mmap_size, + unaligned_offset_start, data_size); + _buf_to_data[buf_idx] = data_obj; + } + else + { + size_t offset = unaligned_offset_start - aligned_offset_start; + uint8_t *mmap_base = static_cast<uint8_t *>( + mmap(NULL, mmap_size, PROT_READ, MAP_PRIVATE, _fd, aligned_offset_start)); + + data_obj = std::make_shared<ir::CachedData>(mmap_base + offset, data_size); + _buf_to_data[buf_idx] = data_obj; + + munmap(mmap_base, mmap_size); + } + } + subg.setOperandValue(operand_index, std::move(data_obj)); + } + + _tensor_names.emplace(operand_index, tensor->name()->str()); + + // Variable + if (tensor->is_variable()) + { + if (data != nullptr) + throw std::runtime_error("Variable tensor with buffer is not supported!"); + + subg.operands().at(operand_index).info().setAsVariable(); + } + + return operand_index; +} + +template <typename LoaderDomain> +void BaseLoader<LoaderDomain>::loadQuantization(const Tensor *tensor, ir::TypeInfo &typeInfo) +{ + auto q_params = tensor->quantization(); + if (q_params == nullptr || q_params->scale() == nullptr || q_params->scale()->size() == 0) + { + typeInfo.quantization(0., 0); + return; + } + if (q_params->zero_point() == nullptr) + { + throw std::runtime_error("Quantization params: scale is not null, but zero_point is null."); + } + const size_t num_scales = q_params->scale()->size(); + if (num_scales != q_params->zero_point()->size()) + { + throw std::runtime_error("Quantization params: scale size != zero_point size"); + } + std::vector<float> scales; + std::vector<int32_t> zero_points; + scales.resize(num_scales); + zero_points.resize(num_scales); + for (size_t i = 0; i < num_scales; ++i) + { + scales[i] = q_params->scale()->Get(i); + // zero_point is defined as long (i64) in schema while TypeInfo's zero_point is int32_t. 
+ // int64_t is used instead of long because long is 4 byte in most 32bit architecture. + int64_t zero_point = q_params->zero_point()->Get(i); + if (zero_point < std::numeric_limits<int32_t>::min() || + zero_point > std::numeric_limits<int32_t>::max()) + throw std::runtime_error("Zero_point is out of int32 range."); + zero_points[i] = static_cast<int32_t>(zero_point); + } + auto details = q_params->details_as_CustomQuantization(); + if (details != nullptr) + throw std::runtime_error("Custom Quantization is not supported"); + typeInfo.quantization(std::move(scales), std::move(zero_points)); +} + +template <typename LoaderDomain> +void BaseLoader<LoaderDomain>::loadSparsity(const Tensor *tensor, ir::TypeInfo &typeInfo) +{ + auto src_sparsity = tensor->sparsity(); + if (src_sparsity != nullptr) + { + std::vector<uint16_t> w1_segments; + std::vector<uint16_t> w1_indices; + // check traversal_order + if (src_sparsity->traversal_order()) + { + const int traversal_order_size = src_sparsity->traversal_order()->size(); + for (int i = 0; i < traversal_order_size; ++i) + { + if (i != src_sparsity->traversal_order()->Get(i)) + throw std::runtime_error("traversal_order [0, 1, ..., n-1] is only supported."); + } + } + // check block_map + int block_rank = 0; + if (src_sparsity->block_map()) + { + block_rank = src_sparsity->block_map()->size(); + for (int i = 0; i < block_rank; ++i) + { + if (i != src_sparsity->block_map()->Get(i)) + throw std::runtime_error("block_map [0, 1, ..., n-1] is only supported."); + } + } + // load metadata + const auto dim_metadata_size = src_sparsity->dim_metadata()->size(); + const auto dense_rank = tensor->shape() ? 
tensor->shape()->size() : 0; + if (dense_rank + block_rank != dim_metadata_size) + throw std::runtime_error("sparsity dim_metadata length is wrong."); + bool random_sparsity = dim_metadata_size == 2 && block_rank == 0; + bool block2D_sparsity = dim_metadata_size == 4 && block_rank == 2; + if (dim_metadata_size != !random_sparsity && !block2D_sparsity) + throw std::runtime_error( + "sparsity is supported only for 2D tensor with random or 16x1 block sparsity."); + + const auto *src_metadata = src_sparsity->dim_metadata()->Get(0); + if (src_metadata->format() != DimensionType::DimensionType_DENSE) + throw std::runtime_error("sparse tensor dim[0] is not DENSE"); + src_metadata = src_sparsity->dim_metadata()->Get(1); + if (src_metadata->format() != DimensionType::DimensionType_SPARSE_CSR) + throw std::runtime_error("sparse tensor dim[0] is not SPARSE_CSR"); + auto ParseSparseIndexVector = [src_metadata, &w1_segments, &w1_indices]() { + if (src_metadata->array_segments() == nullptr || src_metadata->array_indices() == nullptr) + return false; + bool status = true; + /* `onert` inernally uses uint16 type regardless of the value of + the array_segments_type and array_indices_type */ + switch (src_metadata->array_segments_type()) + { + case SparseIndexVector::SparseIndexVector_Int32Vector: + throw std::runtime_error("sparse tensor with int32 segment type is not supported"); + case SparseIndexVector::SparseIndexVector_Uint16Vector: + status = Copy(src_metadata->array_segments_as_Uint16Vector(), w1_segments); + break; + case SparseIndexVector::SparseIndexVector_Uint8Vector: + status = Copy(src_metadata->array_segments_as_Uint8Vector(), w1_segments); + break; + default: + return false; + } + if (status != true) + return false; + switch (src_metadata->array_indices_type()) + { + case SparseIndexVector::SparseIndexVector_Int32Vector: + throw std::runtime_error("sparse tensor with int32 indices type is not supported"); + case SparseIndexVector::SparseIndexVector_Uint16Vector: + 
return Copy(src_metadata->array_indices_as_Uint16Vector(), w1_indices); + case SparseIndexVector::SparseIndexVector_Uint8Vector: + return Copy(src_metadata->array_indices_as_Uint8Vector(), w1_indices); + default: + break; + } + return false; + }; + if (ParseSparseIndexVector() == false) + throw std::runtime_error("Error during parsing sparsity index information"); + // Get block size + std::vector<int32_t> block_size; + for (int i = 0; i < block_rank; ++i) + { + auto block_metadata = src_sparsity->dim_metadata()->Get(dense_rank + i); + if (block_metadata->format() != DimensionType::DimensionType_DENSE) + throw std::runtime_error("block dimension must be DENSE."); + block_size.push_back(block_metadata->dense_size()); + } + typeInfo.sparsity(std::make_shared<ir::Sparsity>(std::move(w1_segments), std::move(w1_indices), + std::move(block_size))); + } +} + +template <typename LoaderDomain> +void BaseLoader<LoaderDomain>::loadOperationIO(const Operator *op, ir::OperandIndexSequence &inputs, + ir::OperandIndexSequence &outputs) +{ + for (const std::int32_t idx : *op->inputs()) + { + // Optional tensors are not supported yet except for FULLY_CONNECTED and BCQ_FULLY_CONNECTED + auto check_optional_input = [&]() { + auto builtin_code = getBuiltinOperator(op); + if (isOptionalInputTensor(idx) && !allowOptionalInputTensor(builtin_code)) + throw std::runtime_error( + std::string("loader doesn't support optional input tensor yet for ") + .append(EnumNameBuiltinOperator(builtin_code))); + }; + check_optional_input(); + inputs.append(tensorIdxToOperandIdx(idx)); + } + + for (const std::int32_t idx : *op->outputs()) + { + outputs.append(tensorIdxToOperandIdx(idx)); + } +} + +template <typename LoaderDomain> +template <typename Param, typename OptionsType> +void BaseLoader<LoaderDomain>::loadStridesAndPaddings(Param ¶m, const OptionsType *options) +{ + // Strides + param.stride.vertical = options->stride_h(); + param.stride.horizontal = options->stride_w(); + // Paddings + switch 
(options->padding()) + { + case Padding::Padding_SAME: + param.padding.type = ir::PaddingType::SAME; + break; + case Padding::Padding_VALID: + param.padding.type = ir::PaddingType::VALID; + break; + default: + throw std::runtime_error{"Invalid padding type"}; + } + // param paddings indexes unused +} + +template <typename LoaderDomain> +template <typename Param> +void BaseLoader<LoaderDomain>::loadPool2DOptions(Param ¶m, const Pool2DOptions *options) +{ + // Strides and Paddings + if (options->stride_h() <= 0 || options->stride_w() <= 0) + throw std::runtime_error{"Invalid stride vertical or horizontal - both must be bigger than 0"}; + loadStridesAndPaddings(param, options); + // Filter width and height + // Strides + if (options->filter_width() <= 0 || options->filter_height() <= 0) + throw std::runtime_error{"Invalid filter width or height - both must be bigger than 0"}; + param.kw = options->filter_width(); + param.kh = options->filter_height(); + // Activation + param.activation = convertActivation(options->fused_activation_function()); +} + +template <typename LoaderDomain> +template <typename OpIR, typename... 
Args> +const OpIR *BaseLoader<LoaderDomain>::loadOperationTo(const Operator *op, ir::Graph &subg, + Args &&...args) +{ + static_assert(sizeof...(args) <= 1, "You can't have more than 1 arguments!"); + ir::OperandIndexSequence inputs; + ir::OperandIndexSequence outputs; + + loadOperationIO(op, inputs, outputs); + + std::unique_ptr<OpIR> new_op(new OpIR(inputs, outputs, std::forward<Args>(args)...)); + auto ret = new_op.get(); + subg.addOperation(std::move(new_op)); + + return ret; +} + +template <typename LoaderDomain> +void BaseLoader<LoaderDomain>::loadConv2D(const Operator *op, ir::Graph &subg) +{ + ir::operation::Conv2D::Param param; + const auto *options = op->builtin_options_as_Conv2DOptions(); + param.activation = convertActivation(options->fused_activation_function()); + loadStridesAndPaddings(param, options); + param.dilation.width_factor = options->dilation_w_factor(); + param.dilation.height_factor = options->dilation_h_factor(); + + const auto conv = loadOperationTo<ir::operation::Conv2D>(op, subg, param); + + // TFLite support old hybrid quantization (float input/output, uint8 kernel) + // but it interprets weight type as init8 internally + const auto &input_operand = + subg.operands().at(conv->getInputs().at(ir::operation::Conv2D::INPUT)); + auto &weights_operand = subg.operands().at(conv->getInputs().at(ir::operation::Conv2D::KERNEL)); + if (input_operand.typeInfo().type() == ir::DataType::FLOAT32 && + ((weights_operand.typeInfo().type() == ir::DataType::QUANT_UINT8_ASYMM) || + weights_operand.typeInfo().type() == ir::DataType::QUANT_INT8_ASYMM)) + { + weights_operand.type(ir::DataType::QUANT_INT8_SYMM); + } +} + +template <typename LoaderDomain> +void BaseLoader<LoaderDomain>::loadDepthwiseConv2D(const Operator *op, ir::Graph &subg) +{ + ir::operation::DepthwiseConv2D::Param param; + const auto *options = op->builtin_options_as_DepthwiseConv2DOptions(); + param.activation = convertActivation(options->fused_activation_function()); + 
loadStridesAndPaddings(param, options); + param.multiplier = options->depth_multiplier(); + // Dilation h/w factor unused + param.dilation.width_factor = options->dilation_w_factor(); + param.dilation.height_factor = options->dilation_h_factor(); + + const auto dconv = loadOperationTo<ir::operation::DepthwiseConv2D>(op, subg, param); + + // TFLite does not support old hybrid quantization (float input/output, uint8 kernel) + // for depthwise convolution. + // But for consistency with Conv2D and FC, we interpret weight type as init8 internally + const auto &input_operand = + subg.operands().at(dconv->getInputs().at(ir::operation::DepthwiseConv2D::INPUT)); + auto &weights_operand = + subg.operands().at(dconv->getInputs().at(ir::operation::DepthwiseConv2D::KERNEL)); + if (input_operand.typeInfo().type() == ir::DataType::FLOAT32 && + ((weights_operand.typeInfo().type() == ir::DataType::QUANT_UINT8_ASYMM) || + weights_operand.typeInfo().type() == ir::DataType::QUANT_INT8_ASYMM)) + { + weights_operand.type(ir::DataType::QUANT_INT8_SYMM); + } +} + +template <typename LoaderDomain> +void BaseLoader<LoaderDomain>::loadTransposeConv(const Operator *op, ir::Graph &subg) +{ + ir::operation::TransposeConv::Param param; + const auto *options = op->builtin_options_as_TransposeConvOptions(); + loadStridesAndPaddings(param, options); + + loadOperationTo<ir::operation::TransposeConv>(op, subg, param); +} + +template <typename LoaderDomain> +void BaseLoader<LoaderDomain>::loadPool2D(const Operator *op, ir::Graph &subg, + ir::operation::Pool2D::PoolType op_type) +{ + ir::operation::Pool2D::Param param; + param.op_type = op_type; + const auto *options = op->builtin_options_as_Pool2DOptions(); + + loadPool2DOptions(param, options); + + loadOperationTo<ir::operation::Pool2D>(op, subg, param); +} + +template <typename LoaderDomain> +void BaseLoader<LoaderDomain>::loadReshape(const Operator *op, ir::Graph &subg) +{ + ir::operation::Reshape::Param param{}; + const auto *options = 
op->builtin_options_as_ReshapeOptions(); + if (options != nullptr) + { + const auto *new_shape = options->new_shape(); + if (new_shape) + { + for (uint i = 0; i < new_shape->size(); ++i) + { + param.new_shape.push_back(new_shape->Get(i)); + } + } + } + + loadOperationTo<ir::operation::Reshape>(op, subg, param); +} + +template <typename LoaderDomain> +void BaseLoader<LoaderDomain>::loadSoftmax(const Operator *op, ir::Graph &subg) +{ + ir::operation::Softmax::Param param; + const auto *options = op->builtin_options_as_SoftmaxOptions(); + // Beta + param.beta = options->beta(); + + loadOperationTo<ir::operation::Softmax>(op, subg, param); +} + +template <typename LoaderDomain> +void BaseLoader<LoaderDomain>::loadConcatenation(const Operator *op, ir::Graph &subg) +{ + ir::operation::Concat::Param param; + const auto *options = op->builtin_options_as_ConcatenationOptions(); + // Axis + param.axis = options->axis(); + // activation unused + + loadOperationTo<ir::operation::Concat>(op, subg, param); +} + +template <typename LoaderDomain> +void BaseLoader<LoaderDomain>::loadFC(const Operator *op, ir::Graph &subg) +{ + ir::operation::FullyConnected::Param param; + const auto *options = op->builtin_options_as_FullyConnectedOptions(); + + param.activation = convertActivation(options->fused_activation_function()); + param.weights_format = static_cast<ir::FullyConnectedWeightsFormat>(options->weights_format()); + + const auto fc = loadOperationTo<ir::operation::FullyConnected>(op, subg, param); + + // TFLite supports old hybrid quantization (float input/output, uint8 kernel) + // but it interprets weight type as init8 internally + const auto &input_operand = + subg.operands().at(fc->getInputs().at(ir::operation::FullyConnected::INPUT)); + auto &weights_operand = + subg.operands().at(fc->getInputs().at(ir::operation::FullyConnected::WEIGHT)); + if (input_operand.typeInfo().type() == ir::DataType::FLOAT32 && + ((weights_operand.typeInfo().type() == 
ir::DataType::QUANT_UINT8_ASYMM) || + weights_operand.typeInfo().type() == ir::DataType::QUANT_INT8_ASYMM)) + { + weights_operand.type(ir::DataType::QUANT_INT8_SYMM); + } +} + +template <typename LoaderDomain> +void BaseLoader<LoaderDomain>::loadAddV2(const Operator *op, ir::Graph &subg) +{ + ir::operation::BinaryArithmetic::Param param; + param.arithmetic_type = ir::operation::BinaryArithmetic::ArithmeticType::ADD; + + if (op->custom_options() == nullptr) + { + param.activation = ir::Activation::NONE; + } + else + { + const auto attr_map = getCustomOpAttrMap(op); + const auto fused_activation_func = static_cast<typename LoaderDomain::ActivationFunctionType>( + attr_map["fused_activation_function"].AsInt8()); + param.activation = convertActivation(fused_activation_func); + } + + loadOperationTo<ir::operation::BinaryArithmetic>(op, subg, param); +} + +template <typename LoaderDomain> +void BaseLoader<LoaderDomain>::loadDepthToSpace(const Operator *op, ir::Graph &subg) +{ + ir::operation::DepthToSpace::Param param; + const auto *options = op->builtin_options_as_DepthToSpaceOptions(); + param.block_size = options->block_size(); + + loadOperationTo<ir::operation::DepthToSpace>(op, subg, param); +} + +template <typename LoaderDomain> +void BaseLoader<LoaderDomain>::loadBinaryArithmetic( + const Operator *op, ir::Graph &subg, ir::operation::BinaryArithmetic::ArithmeticType op_type) +{ + ir::operation::BinaryArithmetic::Param param; + param.arithmetic_type = op_type; + switch (op_type) + { + case ir::operation::BinaryArithmetic::ArithmeticType::ADD: + { + const auto *add_options = op->builtin_options_as_AddOptions(); + param.activation = convertActivation(add_options->fused_activation_function()); + break; + } + case ir::operation::BinaryArithmetic::ArithmeticType::SUB: + { + const auto *sub_options = op->builtin_options_as_SubOptions(); + param.activation = convertActivation(sub_options->fused_activation_function()); + break; + } + case 
ir::operation::BinaryArithmetic::ArithmeticType::MUL: + { + const auto *mul_options = op->builtin_options_as_MulOptions(); + param.activation = convertActivation(mul_options->fused_activation_function()); + break; + } + case ir::operation::BinaryArithmetic::ArithmeticType::DIV: + { + const auto *div_options = op->builtin_options_as_DivOptions(); + param.activation = convertActivation(div_options->fused_activation_function()); + break; + } + default: + assert(false && + "The function 'loadBinaryArithmetic' supports only BinaryArithmetic operations"); + break; + } + + loadOperationTo<ir::operation::BinaryArithmetic>(op, subg, param); +} + +template <typename LoaderDomain> +void BaseLoader<LoaderDomain>::loadPack(const Operator *op, ir::Graph &subg) +{ + ir::operation::Pack::Param param; + const auto *options = op->builtin_options_as_PackOptions(); + param.num = options->values_count(); + param.axis = options->axis(); + + loadOperationTo<ir::operation::Pack>(op, subg, param); +} + +template <typename LoaderDomain> +void BaseLoader<LoaderDomain>::loadElementwiseActivation( + const Operator *op, ir::Graph &subg, ir::operation::ElementwiseActivation::Type op_type, + float alpha, float beta) +{ + ir::operation::ElementwiseActivation::Param param; + param.op_type = op_type; + param.alpha = alpha; + param.beta = beta; + + loadOperationTo<ir::operation::ElementwiseActivation>(op, subg, param); +} + +template <typename LoaderDomain> +void BaseLoader<LoaderDomain>::loadResizeBilinear(const Operator *op, ir::Graph &subg) +{ + ir::operation::ResizeBilinear::Param param; + // heigh_out and width_out is used on NNAPI only + assert(op->inputs()->size() == 2); + param.height_out = 0; + param.width_out = 0; + param.align_corners = op->builtin_options_as_ResizeBilinearOptions()->align_corners(); + param.half_pixel_centers = op->builtin_options_as_ResizeBilinearOptions()->half_pixel_centers(); + + loadOperationTo<ir::operation::ResizeBilinear>(op, subg, param); +} + +template <typename 
LoaderDomain> +void BaseLoader<LoaderDomain>::loadResizeNearestNeighbor(const Operator *op, ir::Graph &subg) +{ + ir::operation::ResizeNearestNeighbor::Param param; + // heigh_out and width_out is used on NNAPI only + assert(op->inputs()->size() == 2); + param.height_out = 0; + param.width_out = 0; + param.align_corners = op->builtin_options_as_ResizeNearestNeighborOptions()->align_corners(); + + loadOperationTo<ir::operation::ResizeNearestNeighbor>(op, subg, param); +} + +template <typename LoaderDomain> +void BaseLoader<LoaderDomain>::loadReduce(const Operator *op, ir::Graph &subg, + ir::operation::Reduce::ReduceType reduce_type) +{ + ir::operation::Reduce::Param param; + param.reduce_type = reduce_type; + param.keep_dims = op->builtin_options_as_ReducerOptions()->keep_dims(); + + loadOperationTo<ir::operation::Reduce>(op, subg, param); +} + +template <typename LoaderDomain> +void BaseLoader<LoaderDomain>::loadReduceAll(const Operator *op, ir::Graph &subg) +{ + ir::operation::Reduce::Param param; + param.reduce_type = ir::operation::Reduce::ReduceType::ALL; + if (op->custom_options() == nullptr) + { + param.keep_dims = false; + } + else + { + const auto attr_map = getCustomOpAttrMap(op); + param.keep_dims = attr_map["keep_dims"].AsBool(); + } + + loadOperationTo<ir::operation::Reduce>(op, subg, param); +} + +template <typename LoaderDomain> +void BaseLoader<LoaderDomain>::loadElementwiseBinary( + const Operator *op, ir::Graph &subg, + ir::operation::ElementwiseBinary::ElementwiseBinaryType op_type) +{ + ir::operation::ElementwiseBinary::Param param; + param.op_type = op_type; + + loadOperationTo<ir::operation::ElementwiseBinary>(op, subg, param); +} + +template <typename LoaderDomain> +void BaseLoader<LoaderDomain>::loadElementwiseUnary(const Operator *op, ir::Graph &subg, + ir::operation::ElementwiseUnary::Type op_type) +{ + ir::operation::ElementwiseUnary::Param param; + param.op_type = op_type; + + const auto eu = 
loadOperationTo<ir::operation::ElementwiseUnary>(op, subg, param); + if (op_type == ir::operation::ElementwiseUnary::Type::CAST) + { + auto qasymm8ToUint8 = [](ir::Operand &operand) { + if (operand.typeInfo().type() == ir::DataType::QUANT_UINT8_ASYMM) + { + operand.type(ir::DataType::UINT8); + } + }; + qasymm8ToUint8( + subg.operands().at(eu->getInputs().at(ir::operation::ElementwiseUnary::Input::INPUT))); + qasymm8ToUint8(subg.operands().at(eu->getOutputs().at(0))); + } +} + +template <typename LoaderDomain> +void BaseLoader<LoaderDomain>::loadGather(const Operator *op, ir::Graph &subg) +{ + ir::operation::Gather::Param param; + param.axis = op->builtin_options_as_GatherOptions()->axis(); + + loadOperationTo<ir::operation::Gather>(op, subg, param); +} + +template <typename LoaderDomain> +void BaseLoader<LoaderDomain>::loadDetectionPostProcess(const Operator *op, ir::Graph &subg) +{ + const auto &m = getCustomOpAttrMap(op); + + ir::operation::DetectionPostProcess::Param param; + + param.max_detections = m["max_detections"].AsInt32(); + + // TODO fixme + param.max_classes_per_detection = m["max_classes_per_detection"].AsInt32(); + if (m["detections_per_class"].IsNull()) + param.max_boxes_per_class = 100; + else + param.max_boxes_per_class = m["detections_per_class"].AsInt32(); + + if (m["use_regular_nms"].IsNull()) + param.do_fast_eval = true; + else + param.do_fast_eval = !m["use_regular_nms"].AsBool(); + + param.score_threshold = m["nms_score_threshold"].AsFloat(); + param.iou_threshold = m["nms_iou_threshold"].AsFloat(); + + // TODO add num classes support + param.num_classes = m["num_classes"].AsInt32(); + + param.scale.y_scale = m["y_scale"].AsFloat(); + param.scale.x_scale = m["x_scale"].AsFloat(); + param.scale.h_scale = m["h_scale"].AsFloat(); + param.scale.w_scale = m["w_scale"].AsFloat(); + + // TODO depends on input model framework + param.center_size_boxes = true; + + loadOperationTo<ir::operation::DetectionPostProcess>(op, subg, param); +} + +template 
<typename LoaderDomain> +void BaseLoader<LoaderDomain>::loadBatchMatMul(const Operator *op, ir::Graph &subg) +{ + ir::operation::BatchMatMul::Param param; + + const auto builtin_op = getBuiltinOperator(op); + + switch (builtin_op) + { + case BuiltinOperator::BuiltinOperator_BATCH_MATMUL: + // Handled on each loader: different option name + // Circle: adjoint_lhs, adjoint_rhs + // TFLite: adj_x, adj_y + throw std::runtime_error( + std::string("Cannot handle here: ").append(EnumNameBuiltinOperator(builtin_op)) + " as " + + EnumNameBuiltinOperator(BuiltinOperator::BuiltinOperator_BATCH_MATMUL)); + case BuiltinOperator::BuiltinOperator_CUSTOM: + if (op->custom_options() == nullptr) + { + param.adj_x = false; + param.adj_y = false; + } + else + { + const auto attr_map = getCustomOpAttrMap(op); + param.adj_x = attr_map["adj_x"].AsBool(); + param.adj_y = attr_map["adj_y"].AsBool(); + } + break; + default: + throw std::runtime_error( + std::string("Wrong loaded operation: ").append(EnumNameBuiltinOperator(builtin_op)) + + " as " + EnumNameBuiltinOperator(BuiltinOperator::BuiltinOperator_BATCH_MATMUL)); + } + + loadOperationTo<ir::operation::BatchMatMul>(op, subg, param); +} + +template <typename LoaderDomain> +void BaseLoader<LoaderDomain>::loadSpaceToDepth(const Operator *op, ir::Graph &subg) +{ + ir::operation::SpaceToDepth::Param param; + const auto *options = op->builtin_options_as_SpaceToDepthOptions(); + param.block_size = options->block_size(); + + loadOperationTo<ir::operation::SpaceToDepth>(op, subg, param); +} + +template <typename LoaderDomain> +void BaseLoader<LoaderDomain>::loadCustom(const Operator *op, ir::Graph &subg) +{ + ir::OperandIndexSequence inputs; + ir::OperandIndexSequence outputs; + + assert(op->custom_options_format() == CustomOptionsFormat::CustomOptionsFormat_FLEXBUFFERS && + "Unsupported custom operation options format"); + + auto *op_code = _domain_model->operator_codes()->Get(op->opcode_index()); + auto custom_op_name = 
op_code->custom_code()->str(); + + enum class BuiltinOP + { + AddV2, + ReduceAll, + MatrixBandPart, + BatchMatMul, + Einsum, + BroadcastTo, + FusedBatchNorm, + StatelessRandomUniform, + Erf, + DetectionPostProcess + }; + + // Mapping from custom op name string to BuiltinOP enum + std::map<std::string, BuiltinOP> builtin_map = { + {"AddV2", BuiltinOP::AddV2}, + {"All", BuiltinOP::ReduceAll}, + {"MatrixBandPart", BuiltinOP::MatrixBandPart}, + {"BatchMatMulV2", BuiltinOP::BatchMatMul}, + {"Einsum", BuiltinOP::Einsum}, + {"FusedBatchNormV3", BuiltinOP::FusedBatchNorm}, + {"BroadcastTo", BuiltinOP::BroadcastTo}, + {"StatelessRandomUniform", BuiltinOP::StatelessRandomUniform}, + {"Erf", BuiltinOP::Erf}, + {"TFLite_Detection_PostProcess", BuiltinOP::DetectionPostProcess}, + }; + + try + { + // Throw out_of_range if it is unknown custom op + auto custom_op_id = builtin_map.at(custom_op_name); + switch (custom_op_id) + { + case BuiltinOP::AddV2: + loadAddV2(op, subg); + break; + case BuiltinOP::ReduceAll: + loadReduceAll(op, subg); + break; + case BuiltinOP::MatrixBandPart: + loadOperationTo<ir::operation::MatrixBandPart>(op, subg); + break; + case BuiltinOP::BatchMatMul: + loadBatchMatMul(op, subg); + break; + case BuiltinOP::Einsum: + loadEinsum(op, subg); + break; + case BuiltinOP::BroadcastTo: + loadOperationTo<ir::operation::BroadcastTo>(op, subg); + break; + case BuiltinOP::FusedBatchNorm: + loadFusedBatchNorm(op, subg); + break; + case BuiltinOP::StatelessRandomUniform: + loadOperationTo<ir::operation::StatelessRandomUniform>(op, subg); + break; + case BuiltinOP::Erf: + loadElementwiseUnary(op, subg, ir::operation::ElementwiseUnary::Type::ERF); + break; + case BuiltinOP::DetectionPostProcess: + loadDetectionPostProcess(op, subg); + break; + default: + throw std::runtime_error{ + "Loader: Custom OP map is defined but operation loader function is not defined"}; + } + + return; + } + catch (...) 
+ { + loadOperationIO(op, inputs, outputs); + + auto constraint = ir::OperandConstraint::createExact(inputs.size()); + + size_t custom_op_data_size = op->custom_options()->size(); + auto custom_op_data = new char[custom_op_data_size]; + std::copy(op->custom_options()->begin(), op->custom_options()->end(), custom_op_data); + + ir::operation::Custom::Userdata userdata{}; + userdata.data = custom_op_data; + userdata.size = custom_op_data_size; + + auto new_op = std::make_unique<ir::operation::Custom>(constraint, inputs, outputs, + custom_op_name, userdata); + + subg.addOperation(std::move(new_op)); + } +} + +template <typename LoaderDomain> +void BaseLoader<LoaderDomain>::loadSqueeze(const Operator *op, ir::Graph &subg) +{ + ir::operation::Squeeze::Param param; + const auto *options = op->builtin_options_as_SqueezeOptions(); + const auto *dims = options->squeeze_dims(); + if (dims) + { + if (dims->size() > sizeof(param.dims) / sizeof(param.dims[0])) + throw std::runtime_error("Squeeze: 'param.ndims' is out of range."); + param.ndim = dims->size(); + for (int i = 0; i < param.ndim; ++i) + param.dims[i] = dims->Get(i); + } + else + param.ndim = 0; + + loadOperationTo<ir::operation::Squeeze>(op, subg, param); +} + +template <typename LoaderDomain> +void BaseLoader<LoaderDomain>::loadSplit(const Operator *op, ir::Graph &subg) +{ + ir::operation::Split::Param param; + const auto *options = op->builtin_options_as_SplitOptions(); + param.num_splits = options->num_splits(); + + loadOperationTo<ir::operation::Split>(op, subg, param); +} + +template <typename LoaderDomain> +void BaseLoader<LoaderDomain>::loadSplitV(const Operator *op, ir::Graph &subg) +{ + ir::operation::SplitV::Param param; + const auto *options = op->builtin_options_as_SplitVOptions(); + param.num_splits = options->num_splits(); + + loadOperationTo<ir::operation::SplitV>(op, subg, param); +} + +template <typename LoaderDomain> +void BaseLoader<LoaderDomain>::loadStridedSlice(const Operator *op, ir::Graph 
&subg) +{ + ir::operation::StridedSlice::Param param; + const auto *options = op->builtin_options_as_StridedSliceOptions(); + param.begin_mask = options->begin_mask(); + param.end_mask = options->end_mask(); + param.shrink_axis_mask = options->shrink_axis_mask(); + + loadOperationTo<ir::operation::StridedSlice>(op, subg, param); +} + +template <typename LoaderDomain> +void BaseLoader<LoaderDomain>::loadUnpack(const Operator *op, ir::Graph &subg) +{ + ir::operation::Unpack::Param param; + const auto *options = op->builtin_options_as_UnpackOptions(); + param.num = options->num(); + param.axis = options->axis(); + + loadOperationTo<ir::operation::Unpack>(op, subg, param); +} + +template <typename LoaderDomain> +void BaseLoader<LoaderDomain>::loadComparison(const Operator *op, ir::Graph &subg) +{ + ir::operation::Comparison::Param param; + const auto builtin_op = getBuiltinOperator(op); + + switch (builtin_op) + { + case BuiltinOperator::BuiltinOperator_EQUAL: + param.comparison_type = ir::operation::Comparison::ComparisonType::Equal; + break; + case BuiltinOperator::BuiltinOperator_NOT_EQUAL: + param.comparison_type = ir::operation::Comparison::ComparisonType::NotEqual; + break; + case BuiltinOperator::BuiltinOperator_GREATER_EQUAL: + param.comparison_type = ir::operation::Comparison::ComparisonType::GreaterEqual; + break; + case BuiltinOperator::BuiltinOperator_GREATER: + param.comparison_type = ir::operation::Comparison::ComparisonType::Greater; + break; + case BuiltinOperator::BuiltinOperator_LESS_EQUAL: + param.comparison_type = ir::operation::Comparison::ComparisonType::LessEqual; + break; + case BuiltinOperator::BuiltinOperator_LESS: + param.comparison_type = ir::operation::Comparison::ComparisonType::Less; + break; + default: + throw std::runtime_error( + std::string("Unsupported operation: ").append(EnumNameBuiltinOperator(builtin_op))); + } + + loadOperationTo<ir::operation::Comparison>(op, subg, param); +} + +template <typename LoaderDomain> +void 
BaseLoader<LoaderDomain>::loadEinsum(const Operator *op, ir::Graph &subg) +{ + ir::operation::Einsum::Param param; + if (op->custom_options() == nullptr) + { + throw std::runtime_error{"Einsum: empty equation"}; + } + else + { + const auto attr_map = getCustomOpAttrMap(op); + param.equation = attr_map["equation"].ToString(); + } + + const auto es = loadOperationTo<ir::operation::Einsum>(op, subg, param); + if (es->getInputs().size() != 2) + { + throw std::runtime_error{"Einsum: NYI input - only support two inputs"}; + } +} +template <typename LoaderDomain> +void BaseLoader<LoaderDomain>::loadFusedBatchNorm(const Operator *op, ir::Graph &subg) +{ + ir::operation::FusedBatchNorm::Param param; + if (op->custom_options() == nullptr) + { + throw std::runtime_error{"FusedBatchNorm: empty option"}; + } + else + { + const auto attr_map = getCustomOpAttrMap(op); + param.is_training = attr_map["is_training"].AsBool(); + param.epsilon = attr_map["epsilon"].AsFloat(); + param.data_format = attr_map["data_format"].ToString(); + } + + const auto fbn = loadOperationTo<ir::operation::FusedBatchNorm>(op, subg, param); + + if (fbn->getInputs().size() != 5) + { + throw std::runtime_error{"FusedBatchNorm: NYI input - only support five inputs"}; + } +} + +template <typename LoaderDomain> +void BaseLoader<LoaderDomain>::loadOneHot(const Operator *op, ir::Graph &subg) +{ + if (op->inputs()->size() != 4 || op->outputs()->size() != 1) + throw std::runtime_error("OneHot Op has wrong number of input or output tensors."); + + // Set parameter + ir::operation::OneHot::Param param; + param.axis = op->builtin_options_as_OneHotOptions()->axis(); + + loadOperationTo<ir::operation::OneHot>(op, subg, param); +} + +template <typename LoaderDomain> +void BaseLoader<LoaderDomain>::loadIf(const Operator *op, ir::Graph &subg) +{ + const auto *options = op->builtin_options_as_IfOptions(); + const int32_t then_index = options->then_subgraph_index(); + const int32_t else_index = 
options->else_subgraph_index(); + + verifySubgraphIndex(then_index); + verifySubgraphIndex(else_index); + + ir::operation::If::Param param; + param.then_subg_index = ir::SubgraphIndex{static_cast<uint16_t>(then_index)}; + param.else_subg_index = ir::SubgraphIndex{static_cast<uint16_t>(else_index)}; + + loadOperationTo<ir::operation::If>(op, subg, param); +} + +template <typename LoaderDomain> +void BaseLoader<LoaderDomain>::loadWhile(const Operator *op, ir::Graph &subg) +{ + const auto *options = op->builtin_options_as_WhileOptions(); + const int32_t cond_index = options->cond_subgraph_index(); + const int32_t body_index = options->body_subgraph_index(); + + verifySubgraphIndex(cond_index); + verifySubgraphIndex(body_index); + + ir::operation::While::Param param; + param.cond_subg_index = ir::SubgraphIndex{static_cast<uint16_t>(cond_index)}; + param.body_subg_index = ir::SubgraphIndex{static_cast<uint16_t>(body_index)}; + + loadOperationTo<ir::operation::While>(op, subg, param); +} + +template <typename LoaderDomain> +void BaseLoader<LoaderDomain>::loadArgMinMax(const Operator *op, ir::Graph &subg, bool is_argmax) +{ + ir::operation::ArgMinMax::Param param; + const auto output_type = is_argmax ? op->builtin_options_as_ArgMaxOptions()->output_type() + : op->builtin_options_as_ArgMinOptions()->output_type(); + param.output_type = tensorTypeToDataType(output_type); + param.is_arg_max = is_argmax; + + loadOperationTo<ir::operation::ArgMinMax>(op, subg, param); +} + +template <typename LoaderDomain> +void BaseLoader<LoaderDomain>::loadLogSoftmax(const Operator *op, ir::Graph &subg) +{ + ir::operation::LogSoftmax::Param param; + // In tflite, beta is fixed to 1.0 and axis is fixed to -1. 
+ param.beta = 1.0f; + param.axis = -1; + + loadOperationTo<ir::operation::LogSoftmax>(op, subg, param); +} + +template <typename LoaderDomain> +void BaseLoader<LoaderDomain>::loadLeakyRelu(const Operator *op, ir::Graph &subg) +{ + float alpha = op->builtin_options_as_LeakyReluOptions()->alpha(); + loadElementwiseActivation(op, subg, ir::operation::ElementwiseActivation::Type::LEAKY_RELU, alpha, + 1.f); +} + +template <typename LoaderDomain> +void BaseLoader<LoaderDomain>::loadUnidirectionalSequenceLSTM(const Operator *op, ir::Graph &subg) +{ + ir::operation::LSTM::Param param; + const auto *options = op->builtin_options_as_UnidirectionalSequenceLSTMOptions(); + param.activation = convertActivation(options->fused_activation_function()); + param.cell_threshold = options->cell_clip(); + param.projection_threshold = options->proj_clip(); + param.time_major = options->time_major(); + // The asymmetric_quantize_inputs option is unused yet + + ir::OperandIndexSequence inputs; + for (const std::int32_t idx : *op->inputs()) + { + inputs.append(tensorIdxToOperandIdx(idx)); + } + + ir::OperandIndexSequence outputs; + // loader doesn't support optional output tensor yet + if (op->outputs()->size() != 1) + { + auto builtin_code = getBuiltinOperator(op); + throw std::runtime_error(std::string("loader doesn't support optional output tensor yet for ") + .append(EnumNameBuiltinOperator(builtin_code))); + } + for (size_t i = 0; i < ir::operation::LSTM::Output::OUTPUT; ++i) + { + // Add optional outputs + outputs.append(ir::OperandIndex()); + } + outputs.append(tensorIdxToOperandIdx(op->outputs()->Get(0))); + + std::unique_ptr<ir::operation::LSTM> new_op(new ir::operation::LSTM(inputs, outputs, param)); + subg.addOperation(std::move(new_op)); +} + +template <typename LoaderDomain> +void BaseLoader<LoaderDomain>::loadOperation(const Operator *op, ir::Graph &subg) +{ + auto const builtin_op = getBuiltinOperator(op); + + switch (builtin_op) + { + case 
BuiltinOperator::BuiltinOperator_ADD_N: + loadOperationTo<ir::operation::AddN>(op, subg); + return; + case BuiltinOperator::BuiltinOperator_CONV_2D: + loadConv2D(op, subg); + return; + case BuiltinOperator::BuiltinOperator_AVERAGE_POOL_2D: + loadPool2D(op, subg, ir::operation::Pool2D::PoolType::AVG); + return; + case BuiltinOperator::BuiltinOperator_DEPTHWISE_CONV_2D: + loadDepthwiseConv2D(op, subg); + return; + case BuiltinOperator::BuiltinOperator_TRANSPOSE_CONV: + loadTransposeConv(op, subg); + return; + case BuiltinOperator::BuiltinOperator_RESHAPE: + loadReshape(op, subg); + return; + case BuiltinOperator::BuiltinOperator_SOFTMAX: + loadSoftmax(op, subg); + return; + case BuiltinOperator::BuiltinOperator_MAX_POOL_2D: + loadPool2D(op, subg, ir::operation::Pool2D::PoolType::MAX); + return; + case BuiltinOperator::BuiltinOperator_CONCATENATION: + loadConcatenation(op, subg); + return; + case BuiltinOperator::BuiltinOperator_FLOOR: + loadElementwiseUnary(op, subg, ir::operation::ElementwiseUnary::Type::FLOOR); + return; + case BuiltinOperator::BuiltinOperator_FULLY_CONNECTED: + loadFC(op, subg); + return; + case BuiltinOperator::BuiltinOperator_ADD: + loadBinaryArithmetic(op, subg, ir::operation::BinaryArithmetic::ArithmeticType::ADD); + return; + case BuiltinOperator::BuiltinOperator_SUB: + loadBinaryArithmetic(op, subg, ir::operation::BinaryArithmetic::ArithmeticType::SUB); + return; + case BuiltinOperator::BuiltinOperator_MUL: + loadBinaryArithmetic(op, subg, ir::operation::BinaryArithmetic::ArithmeticType::MUL); + return; + case BuiltinOperator::BuiltinOperator_DIV: + loadBinaryArithmetic(op, subg, ir::operation::BinaryArithmetic::ArithmeticType::DIV); + return; + case BuiltinOperator::BuiltinOperator_PACK: + loadPack(op, subg); + return; + case BuiltinOperator::BuiltinOperator_ELU: + loadElementwiseActivation(op, subg, ir::operation::ElementwiseActivation::Type::ELU); + return; + case BuiltinOperator::BuiltinOperator_RELU: + loadElementwiseActivation(op, 
subg, ir::operation::ElementwiseActivation::Type::RELU, + ir::operation::ElementwiseActivation::infinity, 0.f); + return; + case BuiltinOperator::BuiltinOperator_RELU_N1_TO_1: + loadElementwiseActivation(op, subg, ir::operation::ElementwiseActivation::Type::RELU, 1.f, + -1.f); + return; + case BuiltinOperator::BuiltinOperator_RELU6: + loadElementwiseActivation(op, subg, ir::operation::ElementwiseActivation::Type::RELU, 6.f, + 0.f); + return; + case BuiltinOperator::BuiltinOperator_RESIZE_BILINEAR: + loadResizeBilinear(op, subg); + return; + case BuiltinOperator::BuiltinOperator_RESIZE_NEAREST_NEIGHBOR: + loadResizeNearestNeighbor(op, subg); + return; + case BuiltinOperator::BuiltinOperator_RSQRT: + loadElementwiseUnary(op, subg, ir::operation::ElementwiseUnary::Type::RSQRT); + return; + case BuiltinOperator::BuiltinOperator_SELECT: + case BuiltinOperator::BuiltinOperator_SELECT_V2: + loadOperationTo<ir::operation::Select>(op, subg); + return; + case BuiltinOperator::BuiltinOperator_SQRT: + loadElementwiseUnary(op, subg, ir::operation::ElementwiseUnary::Type::SQRT); + return; + case BuiltinOperator::BuiltinOperator_SQUARE: + loadElementwiseUnary(op, subg, ir::operation::ElementwiseUnary::Type::SQUARE); + return; + case BuiltinOperator::BuiltinOperator_SQUARED_DIFFERENCE: + loadOperationTo<ir::operation::SquaredDifference>(op, subg); + return; + case BuiltinOperator::BuiltinOperator_TANH: + loadElementwiseActivation(op, subg, ir::operation::ElementwiseActivation::Type::TANH, 1.f, + 1.f); + return; + case BuiltinOperator::BuiltinOperator_TRANSPOSE: + loadOperationTo<ir::operation::Transpose>(op, subg); + return; + case BuiltinOperator::BuiltinOperator_MEAN: + loadReduce(op, subg, ir::operation::Reduce::ReduceType::MEAN); + return; + case BuiltinOperator::BuiltinOperator_REDUCE_ANY: + loadReduce(op, subg, ir::operation::Reduce::ReduceType::ANY); + return; + case BuiltinOperator::BuiltinOperator_REDUCE_MAX: + loadReduce(op, subg, ir::operation::Reduce::ReduceType::MAX); 
+ return; + case BuiltinOperator::BuiltinOperator_REVERSE_V2: + loadOperationTo<ir::operation::Reverse>(op, subg); + return; + case BuiltinOperator::BuiltinOperator_PAD: + case BuiltinOperator::BuiltinOperator_PADV2: + loadOperationTo<ir::operation::Pad>(op, subg); + return; + case BuiltinOperator::BuiltinOperator_LOGISTIC: + loadElementwiseActivation(op, subg, ir::operation::ElementwiseActivation::Type::LOGISTIC); + return; + case BuiltinOperator::BuiltinOperator_EXP: + loadElementwiseUnary(op, subg, ir::operation::ElementwiseUnary::Type::EXP); + return; + case BuiltinOperator::BuiltinOperator_EXPAND_DIMS: + loadOperationTo<ir::operation::ExpandDims>(op, subg); + return; + case BuiltinOperator::BuiltinOperator_GATHER: + loadGather(op, subg); + return; + case BuiltinOperator::BuiltinOperator_SPACE_TO_BATCH_ND: + loadOperationTo<ir::operation::SpaceToBatchND>(op, subg); + return; + case BuiltinOperator::BuiltinOperator_BATCH_TO_SPACE_ND: + loadOperationTo<ir::operation::BatchToSpaceND>(op, subg); + return; + case BuiltinOperator::BuiltinOperator_SUM: + loadReduce(op, subg, ir::operation::Reduce::ReduceType::SUM); + return; + case BuiltinOperator::BuiltinOperator_CUSTOM: + loadCustom(op, subg); + return; + case BuiltinOperator::BuiltinOperator_SQUEEZE: + loadSqueeze(op, subg); + return; + case BuiltinOperator::BuiltinOperator_PRELU: + loadOperationTo<ir::operation::PReLU>(op, subg); + return; + case BuiltinOperator::BuiltinOperator_SPLIT: + loadSplit(op, subg); + return; + case BuiltinOperator::BuiltinOperator_SPLIT_V: + loadSplitV(op, subg); + return; + case BuiltinOperator::BuiltinOperator_SLICE: + loadOperationTo<ir::operation::Slice>(op, subg); + return; + case BuiltinOperator::BuiltinOperator_STRIDED_SLICE: + loadStridedSlice(op, subg); + return; + case BuiltinOperator::BuiltinOperator_UNPACK: + loadUnpack(op, subg); + return; + case BuiltinOperator::BuiltinOperator_FLOOR_DIV: + loadElementwiseBinary(op, subg, + 
ir::operation::ElementwiseBinary::ElementwiseBinaryType::FLOOR_DIV); + return; + case BuiltinOperator::BuiltinOperator_FLOOR_MOD: + loadElementwiseBinary(op, subg, + ir::operation::ElementwiseBinary::ElementwiseBinaryType::FLOOR_MOD); + return; + case BuiltinOperator::BuiltinOperator_MINIMUM: + loadElementwiseBinary(op, subg, ir::operation::ElementwiseBinary::ElementwiseBinaryType::MIN); + return; + case BuiltinOperator::BuiltinOperator_MAXIMUM: + loadElementwiseBinary(op, subg, ir::operation::ElementwiseBinary::ElementwiseBinaryType::MAX); + return; + case BuiltinOperator::BuiltinOperator_CAST: + loadElementwiseUnary(op, subg, ir::operation::ElementwiseUnary::Type::CAST); + return; + case BuiltinOperator::BuiltinOperator_EQUAL: + case BuiltinOperator::BuiltinOperator_NOT_EQUAL: + case BuiltinOperator::BuiltinOperator_GREATER_EQUAL: + case BuiltinOperator::BuiltinOperator_GREATER: + case BuiltinOperator::BuiltinOperator_LESS_EQUAL: + case BuiltinOperator::BuiltinOperator_LESS: + loadComparison(op, subg); + return; + case BuiltinOperator::BuiltinOperator_ONE_HOT: + loadOneHot(op, subg); + return; + case BuiltinOperator::BuiltinOperator_ABS: + loadElementwiseUnary(op, subg, ir::operation::ElementwiseUnary::Type::ABS); + return; + case BuiltinOperator::BuiltinOperator_COS: + loadElementwiseUnary(op, subg, ir::operation::ElementwiseUnary::Type::COS); + return; + case BuiltinOperator::BuiltinOperator_SIN: + loadElementwiseUnary(op, subg, ir::operation::ElementwiseUnary::Type::SIN); + return; + case BuiltinOperator::BuiltinOperator_SHAPE: + loadOperationTo<ir::operation::Shape>(op, subg); + return; + case BuiltinOperator::BuiltinOperator_REDUCE_PROD: + loadReduce(op, subg, ir::operation::Reduce::ReduceType::PROD); + return; + case BuiltinOperator::BuiltinOperator_IF: + loadIf(op, subg); + return; + case BuiltinOperator::BuiltinOperator_WHILE: + loadWhile(op, subg); + return; + case BuiltinOperator::BuiltinOperator_NEG: + loadElementwiseUnary(op, subg, 
ir::operation::ElementwiseUnary::Type::NEG); + return; + case BuiltinOperator::BuiltinOperator_ARG_MAX: + loadArgMinMax(op, subg, true); + return; + case BuiltinOperator::BuiltinOperator_ARG_MIN: + loadArgMinMax(op, subg, false); + return; + case BuiltinOperator::BuiltinOperator_LOG: + loadElementwiseUnary(op, subg, ir::operation::ElementwiseUnary::Type::LOG); + return; + case BuiltinOperator::BuiltinOperator_ROUND: + loadElementwiseUnary(op, subg, ir::operation::ElementwiseUnary::Type::ROUND); + return; + case BuiltinOperator::BuiltinOperator_POW: + loadOperationTo<ir::operation::Pow>(op, subg); + return; + case BuiltinOperator::BuiltinOperator_LOGICAL_NOT: + loadElementwiseUnary(op, subg, ir::operation::ElementwiseUnary::Type::LOGICAL_NOT); + return; + case BuiltinOperator::BuiltinOperator_LOGICAL_AND: + loadElementwiseBinary(op, subg, + ir::operation::ElementwiseBinary::ElementwiseBinaryType::LOGICAL_AND); + return; + case BuiltinOperator::BuiltinOperator_LOGICAL_OR: + loadElementwiseBinary(op, subg, + ir::operation::ElementwiseBinary::ElementwiseBinaryType::LOGICAL_OR); + return; + case BuiltinOperator::BuiltinOperator_FILL: + loadOperationTo<ir::operation::Fill>(op, subg); + return; + case BuiltinOperator::BuiltinOperator_ZEROS_LIKE: + loadElementwiseUnary(op, subg, ir::operation::ElementwiseUnary::Type::ZEROS_LIKE); + return; + case BuiltinOperator::BuiltinOperator_TILE: + loadOperationTo<ir::operation::Tile>(op, subg); + return; + case BuiltinOperator::BuiltinOperator_RANGE: + loadOperationTo<ir::operation::Range>(op, subg); + return; + case BuiltinOperator::BuiltinOperator_BATCH_MATMUL: + loadBatchMatMul(op, subg); + return; + case BuiltinOperator::BuiltinOperator_LOG_SOFTMAX: + loadLogSoftmax(op, subg); + return; + case BuiltinOperator::BuiltinOperator_QUANTIZE: + loadElementwiseUnary(op, subg, ir::operation::ElementwiseUnary::Type::QUANTIZE); + return; + case BuiltinOperator::BuiltinOperator_DEQUANTIZE: + loadElementwiseUnary(op, subg, 
ir::operation::ElementwiseUnary::Type::DEQUANTIZE); + return; + case BuiltinOperator::BuiltinOperator_SPACE_TO_DEPTH: + loadSpaceToDepth(op, subg); + return; + case BuiltinOperator::BuiltinOperator_L2_NORMALIZATION: + loadOperationTo<ir::operation::L2Normalization>(op, subg); + break; + case BuiltinOperator::BuiltinOperator_LEAKY_RELU: + loadLeakyRelu(op, subg); + return; + case BuiltinOperator::BuiltinOperator_RANK: + loadOperationTo<ir::operation::Rank>(op, subg); + return; + case BuiltinOperator::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM: + loadUnidirectionalSequenceLSTM(op, subg); + return; + case BuiltinOperator::BuiltinOperator_DEPTH_TO_SPACE: + loadDepthToSpace(op, subg); + return; + case BuiltinOperator::BuiltinOperator_EMBEDDING_LOOKUP: + loadOperationTo<ir::operation::EmbeddingLookup>(op, subg); + return; + case BuiltinOperator::BuiltinOperator_HASHTABLE_LOOKUP: + loadOperationTo<ir::operation::HashtableLookup>(op, subg); + return; + default: + throw std::runtime_error( + std::string("Unsupported operation: ").append(EnumNameBuiltinOperator(builtin_op))); + } +} + +template <typename LoaderDomain> void BaseLoader<LoaderDomain>::loadModel() +{ + LoaderDomain::VerifyModelBuffer(*_verifier.get()); + _domain_model = LoaderDomain::GetModel(_base); + + auto model = std::make_unique<ir::Model>(); + // Version unused + // const auto version = _model->version(); + // Description unused + + // Load Metadata + auto const metadata_list = _domain_model->metadata(); + if (metadata_list != nullptr) + { + for (uint32_t i = 0; i < metadata_list->size(); ++i) + { + const auto metadata = metadata_list->Get(i); + if (metadata->name() == nullptr) + continue; // metadata should have name + + std::unique_ptr<const ir::Data> data = loadMetadata(metadata->buffer()); + model->add_metadata(metadata->name()->str(), std::move(data)); + } + } + + // const auto *description = _model->description(); + // Load subgraphs and map operations on subgraph + const auto subgraphs = 
_domain_model->subgraphs(); + if (subgraphs->size() - 1 > ir::SubgraphIndex::max()) + throw std::runtime_error{"The number of subgraphs cannot exceed " + + std::to_string(ir::SubgraphIndex::max() + 1)}; + for (uint16_t subgraph_index = 0; subgraph_index < subgraphs->size(); ++subgraph_index) + { + auto subg = loadSubgraph((*_domain_model->subgraphs())[subgraph_index]); + // NOTE: Used () instead of {}, which does not check narrowing. + // It is okay since overflow is checked the above if-statement. + model->push(ir::SubgraphIndex(subgraph_index), std::move(subg)); + } + _model = std::move(model); +} + +} // namespace loader +} // namespace onert + +#endif //__ONERT_LOADER_BASE_LOADER_H__ diff --git a/runtime/onert/core/src/loader/CircleLoader.cc b/runtime/onert/core/src/loader/CircleLoader.cc new file mode 100644 index 000000000..442a0f518 --- /dev/null +++ b/runtime/onert/core/src/loader/CircleLoader.cc @@ -0,0 +1,239 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "loader/CircleLoader.h" + +#include "BaseLoader.h" +#include "circle_schema_generated.h" + +namespace onert +{ +namespace loader +{ + +namespace +{ + +struct LoaderDomain +{ + using Verifier = flatbuffers::Verifier; + using ActivationFunctionType = circle::ActivationFunctionType; + using Buffer = circle::Buffer; + using BuiltinOperator = circle::BuiltinOperator; + using CustomOptionsFormat = circle::CustomOptionsFormat; + using Metadata = circle::Metadata; + using Model = circle::Model; + using Operator = circle::Operator; + using Padding = circle::Padding; + using Pool2DOptions = circle::Pool2DOptions; + using Tensor = circle::Tensor; + using TensorType = circle::TensorType; + using SubGraph = circle::SubGraph; + using DimensionType = circle::DimensionType; + using SparseIndexVector = circle::SparseIndexVector; + + static const char *EnumNameBuiltinOperator(BuiltinOperator e) + { + return circle::EnumNameBuiltinOperator(e); + } + static const char *EnumNameActivationFunctionType(ActivationFunctionType e) + { + return circle::EnumNameActivationFunctionType(e); + } + static const char *EnumNameTensorType(TensorType e) { return circle::EnumNameTensorType(e); } + static const Model *GetModel(const void *buf) { return circle::GetModel(buf); } + static bool VerifyModelBuffer(Verifier &verifier) { return circle::VerifyModelBuffer(verifier); } +}; + +class CircleLoader final : public loader::BaseLoader<LoaderDomain> +{ +protected: + // Different option name + // Circle: adjoint_lhs, adjoint_rhs + // TFLite: adj_x, adj_y + void loadBatchMatMul(const Operator *op, ir::Graph &subg); + + // Only circle operations + void loadInstanceNorm(const Operator *op, ir::Graph &subg); + void loadBCQFullyConnected(const Operator *op, ir::Graph &subg); + void loadBCQGather(const Operator *op, ir::Graph &subg); + +public: + using BaseLoader::BaseLoader; + + bool allowOptionalInputTensor(BuiltinOperator op) override + { + switch (op) + { + case 
BuiltinOperator::BuiltinOperator_FULLY_CONNECTED: + case BuiltinOperator::BuiltinOperator_BCQ_FULLY_CONNECTED: + case BuiltinOperator::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM: + return true; + default: + return false; + } + } + +private: + std::unique_ptr<ir::Graph> loadSubgraph(const circle::SubGraph *circle_subg) override + { + auto subg = std::make_unique<ir::Graph>(); + // Load tensors + _tensor_to_operand.resize(circle_subg->tensors()->size()); + for (flatbuffers::uoffset_t i = 0; i < circle_subg->tensors()->size(); ++i) + { + _tensor_to_operand[i] = loadOperand(circle_subg->tensors()->Get(i), *subg); + subg->operands().at(_tensor_to_operand[i]).setOriginIndex(ir::OriginIndex(i)); + } + // Set inputs + for (const std::int32_t input_ind : *circle_subg->inputs()) + { + subg->addInput(tensorIdxToOperandIdx(input_ind), + _tensor_names.at(_tensor_to_operand[input_ind])); + } + // Set outputs + for (const std::int32_t output_ind : *circle_subg->outputs()) + { + subg->addOutput(tensorIdxToOperandIdx(output_ind), + _tensor_names.at(_tensor_to_operand[output_ind])); + } + // Create operations + for (const auto *op : *circle_subg->operators()) + { + CircleLoader::loadOperation(op, *subg); + } + + // TODO Remove frontend layout feature + subg->setLayout(ir::Layout::NHWC); + + subg->verify(); + + return subg; + } + + void loadOperation(const circle::Operator *op, ir::Graph &subg) + { + auto const builtin_op = getBuiltinOperator(op); + + switch (builtin_op) + { + case circle::BuiltinOperator::BuiltinOperator_BATCH_MATMUL: + loadBatchMatMul(op, subg); + return; + case circle::BuiltinOperator::BuiltinOperator_INSTANCE_NORM: + loadInstanceNorm(op, subg); + return; + case circle::BuiltinOperator::BuiltinOperator_BCQ_FULLY_CONNECTED: + loadBCQFullyConnected(op, subg); + return; + case circle::BuiltinOperator::BuiltinOperator_BCQ_GATHER: + loadBCQGather(op, subg); + return; + default: + BaseLoader::loadOperation(op, subg); + return; + } + } +}; + +void 
CircleLoader::loadBatchMatMul(const Operator *op, ir::Graph &subg) +{ + ir::OperandIndexSequence inputs; + ir::OperandIndexSequence outputs; + + loadOperationIO(op, inputs, outputs); + + ir::operation::BatchMatMul::Param param; + const auto *options = op->builtin_options_as_BatchMatMulOptions(); + + param.adj_x = options->adjoint_lhs(); + param.adj_y = options->adjoint_rhs(); + + std::unique_ptr<ir::Operation> new_op(new ir::operation::BatchMatMul(inputs, outputs, param)); + subg.addOperation(std::move(new_op)); +} + +void CircleLoader::loadInstanceNorm(const Operator *op, ir::Graph &subg) +{ + ir::OperandIndexSequence inputs; + ir::OperandIndexSequence outputs; + + loadOperationIO(op, inputs, outputs); + + ir::operation::InstanceNorm::Param param; + const auto *options = op->builtin_options_as_InstanceNormOptions(); + + param.activation = convertActivation(options->fused_activation_function()); + // Use default value 1e-5 if value of epsilon is zero + param.epsilon = options->epsilon() == 0.f ? 
1e-5 : options->epsilon(); + + std::unique_ptr<ir::Operation> new_op(new ir::operation::InstanceNorm(inputs, outputs, param)); + subg.addOperation(std::move(new_op)); +} + +void CircleLoader::loadBCQGather(const Operator *op, ir::Graph &subg) +{ + ir::OperandIndexSequence inputs; + ir::OperandIndexSequence outputs; + + loadOperationIO(op, inputs, outputs); + + ir::operation::BCQGather::Param param; + const auto *options = op->builtin_options_as_BCQGatherOptions(); + param.input_hidden_size = options->input_hidden_size(); + param.axis = options->axis(); + + std::unique_ptr<ir::Operation> new_op(new ir::operation::BCQGather(inputs, outputs, param)); + subg.addOperation(std::move(new_op)); +} + +void CircleLoader::loadBCQFullyConnected(const Operator *op, ir::Graph &subg) +{ + ir::OperandIndexSequence inputs; + ir::OperandIndexSequence outputs; + + loadOperationIO(op, inputs, outputs); + + ir::operation::BCQFullyConnected::Param param; + const auto *options = op->builtin_options_as_BCQFullyConnectedOptions(); + param.weights_hidden_size = options->weights_hidden_size(); + param.activation = convertActivation(options->fused_activation_function()); + + std::unique_ptr<ir::Operation> new_op( + new ir::operation::BCQFullyConnected(inputs, outputs, param)); + subg.addOperation(std::move(new_op)); +} + +} // namespace + +std::unique_ptr<ir::Model> loadCircleModel(const std::string &filename) +{ + auto model = std::make_unique<ir::Model>(); + CircleLoader loader(model); + loader.loadFromFile(filename); + return model; +} + +std::unique_ptr<ir::Model> loadCircleModel(uint8_t *buffer, size_t size) +{ + auto model = std::make_unique<ir::Model>(); + CircleLoader loader(model); + loader.loadFromBuffer(buffer, size); + return model; +} + +} // namespace loader +} // namespace onert diff --git a/runtime/onert/core/src/loader/ModelLoader.cc b/runtime/onert/core/src/loader/ModelLoader.cc new file mode 100644 index 000000000..1f3b4673c --- /dev/null +++ 
b/runtime/onert/core/src/loader/ModelLoader.cc @@ -0,0 +1,85 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "loader/ModelLoader.h" + +#include "loader/ILoader.h" + +#include <dlfcn.h> + +namespace onert +{ +namespace loader +{ + +std::unique_ptr<ir::Model> loadModel(const std::string &filename, const std::string &type) +{ + // Custom loader library name should be lib<type>_loader.so + std::string libname = "lib" + type + "_loader.so"; + + // Open custom loader library + void *handle = dlopen(libname.c_str(), RTLD_LAZY); + if (!handle) + throw std::runtime_error("Failed to open " + type + " loader"); + + // Get custom loader create function + using create_func_t = ILoader *(*)(); + auto create_fn = reinterpret_cast<create_func_t>(dlsym(handle, "onert_loader_create")); + if (!create_fn) + { + dlclose(handle); + throw std::runtime_error("Failed to find loader create function"); + } + + // Get custom loader destroy function + using destroy_func_t = void (*)(ILoader *); + auto destroy_fn = reinterpret_cast<destroy_func_t>(dlsym(handle, "onert_loader_destroy")); + if (!destroy_fn) + { + dlclose(handle); + throw std::runtime_error("Failed to find loader destroy function"); + } + + // Create custom loader + auto loader = create_fn(); + if (!loader) + { + dlclose(handle); + throw std::runtime_error("Failed to find loader create function"); + } + + // Load model + auto 
model = loader->loadFromFile(filename); + + // Destroy custom loader + destroy_fn(loader); + + // Close custom loader library + // + // NOTE: + // It assumes that custom loader will not be used frequently on runtime session. + // If custom loader is used frequently, it should not close custom loader library and + // save handler to reuse it. + dlclose(handle); + + if (model) + return model; + + throw std::runtime_error("Failed to load model " + filename); +} + +} // namespace loader +} // namespace onert diff --git a/runtime/onert/core/src/loader/TFLiteLoader.cc b/runtime/onert/core/src/loader/TFLiteLoader.cc new file mode 100644 index 000000000..745f39006 --- /dev/null +++ b/runtime/onert/core/src/loader/TFLiteLoader.cc @@ -0,0 +1,167 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "loader/TFLiteLoader.h" + +#include "BaseLoader.h" +#include "tflite_schema_generated.h" + +namespace onert +{ +namespace loader +{ + +namespace +{ + +struct LoaderDomain +{ + using Verifier = flatbuffers::Verifier; + using ActivationFunctionType = onert_tflite::ActivationFunctionType; + using Buffer = onert_tflite::Buffer; + using BuiltinOperator = onert_tflite::BuiltinOperator; + using CustomOptionsFormat = onert_tflite::CustomOptionsFormat; + using Model = onert_tflite::Model; + using Metadata = onert_tflite::Metadata; + using Operator = onert_tflite::Operator; + using Padding = onert_tflite::Padding; + using Pool2DOptions = onert_tflite::Pool2DOptions; + using Tensor = onert_tflite::Tensor; + using TensorType = onert_tflite::TensorType; + using SubGraph = onert_tflite::SubGraph; + using DimensionType = onert_tflite::DimensionType; + using SparseIndexVector = onert_tflite::SparseIndexVector; + + static const char *EnumNameBuiltinOperator(BuiltinOperator e) + { + return onert_tflite::EnumNameBuiltinOperator(e); + } + static const char *EnumNameActivationFunctionType(ActivationFunctionType e) + { + return onert_tflite::EnumNameActivationFunctionType(e); + } + static const char *EnumNameTensorType(TensorType e) + { + return onert_tflite::EnumNameTensorType(e); + } + static const Model *GetModel(const void *buf) { return onert_tflite::GetModel(buf); } + static bool VerifyModelBuffer(Verifier &verifier) + { + return onert_tflite::VerifyModelBuffer(verifier); + } +}; + +class TFLiteLoader final : public loader::BaseLoader<LoaderDomain> +{ +protected: + // Different option name + // Circle: adjoint_lhs, adjoint_rhs + // TFLite: adj_x, adj_y + void loadBatchMatMul(const Operator *op, ir::Graph &subg); + +public: + using BaseLoader::BaseLoader; + + bool allowOptionalInputTensor(BuiltinOperator op) override + { + switch (op) + { + case BuiltinOperator::BuiltinOperator_FULLY_CONNECTED: + case BuiltinOperator::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM: + 
return true; + default: + return false; + } + } + +private: + std::unique_ptr<ir::Graph> loadSubgraph(const onert_tflite::SubGraph *tflite_subg) override + { + auto subg = std::make_unique<ir::Graph>(); + // Load tensors + _tensor_to_operand.resize(tflite_subg->tensors()->size()); + for (flatbuffers::uoffset_t i = 0; i < tflite_subg->tensors()->size(); ++i) + { + _tensor_to_operand[i] = loadOperand(tflite_subg->tensors()->Get(i), *subg); + } + // Set inputs + for (const std::int32_t input_ind : *tflite_subg->inputs()) + { + subg->addInput(tensorIdxToOperandIdx(input_ind), + _tensor_names.at(_tensor_to_operand[input_ind])); + } + // Set outputs + for (const std::int32_t output_ind : *tflite_subg->outputs()) + { + subg->addOutput(tensorIdxToOperandIdx(output_ind), + _tensor_names.at(_tensor_to_operand[output_ind])); + } + // Create operations + for (const auto *op : *tflite_subg->operators()) + { + loadOperation(op, *subg); + } + + subg->verify(); + + return subg; + } + + void loadOperation(const onert_tflite::Operator *op, ir::Graph &subg) + { + auto const builtin_op = getBuiltinOperator(op); + + switch (builtin_op) + { + case onert_tflite::BuiltinOperator::BuiltinOperator_BATCH_MATMUL: + loadBatchMatMul(op, subg); + return; + default: + BaseLoader::loadOperation(op, subg); + return; + } + } +}; + +void TFLiteLoader::loadBatchMatMul(const Operator *op, ir::Graph &subg) +{ + ir::OperandIndexSequence inputs; + ir::OperandIndexSequence outputs; + + loadOperationIO(op, inputs, outputs); + + ir::operation::BatchMatMul::Param param; + const auto *options = op->builtin_options_as_BatchMatMulOptions(); + + param.adj_x = options->adj_x(); + param.adj_y = options->adj_y(); + + std::unique_ptr<ir::Operation> new_op(new ir::operation::BatchMatMul(inputs, outputs, param)); + subg.addOperation(std::move(new_op)); +} + +} // namespace + +std::unique_ptr<ir::Model> loadTFLiteModel(const std::string &filename) +{ + auto model = std::make_unique<ir::Model>(); + TFLiteLoader 
loader(model); + loader.loadFromFile(filename); + return model; +} + +} // namespace loader +} // namespace onert diff --git a/runtime/onert/core/src/loader/TrainInfoLoader.cc b/runtime/onert/core/src/loader/TrainInfoLoader.cc new file mode 100644 index 000000000..bb75daa6f --- /dev/null +++ b/runtime/onert/core/src/loader/TrainInfoLoader.cc @@ -0,0 +1,139 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "loader/TrainInfoLoader.h" + +#include "circle_traininfo_generated.h" +#include "flatbuffers/flatbuffers.h" + +namespace onert +{ +namespace loader +{ + +const char *const TRAININFO_METADATA_NAME = "CIRCLE_TRAINING"; + +namespace +{ + +ir::train::OptimizerInfo loadOptimizerInfo(const circle::ModelTraining *circle_model) +{ + assert(circle_model != nullptr); + + // fill ir_opt from cirlce_opt + ir::train::OptimizerInfo ir_opt; + const circle::Optimizer circle_opt = circle_model->optimizer(); + + switch (circle_opt) + { + case circle::Optimizer_SGD: + ir_opt.optim_code = ir::train::OptimizerCode::SGD; + ir_opt.learning_rate = circle_model->optimizer_opt_as_SGDOptions()->learning_rate(); + break; + case circle::Optimizer_ADAM: + ir_opt.optim_code = ir::train::OptimizerCode::Adam; + ir_opt.learning_rate = circle_model->optimizer_opt_as_AdamOptions()->learning_rate(); + break; + default: + throw std::runtime_error("unknown optimzer"); + } + return ir_opt; +} + +ir::train::LossInfo loadLossInfo(const circle::ModelTraining *circle_model) +{ + assert(circle_model != nullptr); + + // fill ir_loss from circle_loss + ir::train::LossInfo ir_loss; + const circle::LossFn circle_loss = circle_model->lossfn(); + const circle::LossReductionType circle_loss_rdt = circle_model->loss_reduction_type(); + + switch (circle_loss) + { + case circle::LossFn::LossFn_CATEGORICAL_CROSSENTROPY: + ir_loss.loss_code = ir::train::LossCode::CategoricalCrossentropy; + break; + case circle::LossFn::LossFn_MEAN_SQUARED_ERROR: + ir_loss.loss_code = ir::train::LossCode::MeanSquaredError; + break; + case circle::LossFn::LossFn_SPARSE_CATEGORICAL_CROSSENTROPY: + // TODO enable this conversion after core support sparse_categorial_crossentropy + throw std::runtime_error{"'sparse_categorical_crossentropy' is not supported yet"}; + default: + throw std::runtime_error{"unknown loss function"}; + } + + switch (circle_loss_rdt) + { + case 
circle::LossReductionType::LossReductionType_SumOverBatchSize: + ir_loss.reduction_type = ir::train::LossReductionType::SumOverBatchSize; + break; + case circle::LossReductionType::LossReductionType_Sum: + ir_loss.reduction_type = ir::train::LossReductionType::Sum; + break; + default: + throw std::runtime_error{"unknown loss reduction type"}; + } + + return ir_loss; +} + +std::set<ir::OperationIndex> loadTrainableOps(const circle::ModelTraining *circle_model) +{ + assert(circle_model != nullptr); + + std::set<ir::OperationIndex> ir_trainable_ops; + const auto lists = circle_model->trainable_ops(); + if (lists != nullptr) + { + for (::flatbuffers::uoffset_t i = 0; i < lists->size(); ++i) + { + const uint32_t op_index = lists->Get(i); + ir_trainable_ops.emplace(ir::OperationIndex{op_index}); + } + } + return ir_trainable_ops; +} +} // namespace + +std::unique_ptr<ir::train::TrainingInfo> loadTrainingInfo(const uint8_t *buffer, const size_t size) +{ + assert(buffer != nullptr); + + flatbuffers::Verifier v(buffer, size); + bool verified = circle::VerifyModelTrainingBuffer(v); + if (not verified) + throw std::runtime_error{"TrainingInfo buffer is not accessible"}; + + const circle::ModelTraining *circle_model = + circle::GetModelTraining(static_cast<const void *>(buffer)); + + assert(circle_model != nullptr); + + auto tinfo = std::make_unique<ir::train::TrainingInfo>(); + { + tinfo->setVersion(circle_model->version()); + tinfo->setBatchSize(circle_model->batch_size()); + tinfo->setOptimizerInfo(loadOptimizerInfo(circle_model)); + tinfo->setLossInfo(loadLossInfo(circle_model)); + tinfo->setTrainableOps(loadTrainableOps(circle_model)); + } + return tinfo; +} + +} // namespace loader +} // namespace onert diff --git a/runtime/onert/core/src/loader/tflite_schema.fbs b/runtime/onert/core/src/loader/tflite_schema.fbs new file mode 100644 index 000000000..f7997528e --- /dev/null +++ b/runtime/onert/core/src/loader/tflite_schema.fbs @@ -0,0 +1,1308 @@ +// Copyright (c) 2019 
Samsung Electronics Co., Ltd. All Rights Reserved +// Copyright 2017 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Revision History +// Version 0: Initial version. +// Version 1: Add subgraphs to schema. +// Version 2: Rename operators to conform to NN API. +// Version 3: Move buffer data from Model.Subgraph.Tensors to Model.Buffers. +// Version 3a: Add new builtin op code field. Has backward compatibility with +// version 3. +// Version 3b: Rename fields in SignatureDef. Has backward compatibility with +// version 3 and 3a. + +// Change namespace to onert_tflite +namespace onert_tflite; + +// This corresponds to the version. +file_identifier "TFL3"; +// File extension of any written files. +file_extension "tflite"; + +// IMPORTANT: All new members of tables, enums and unions must be added at the +// end to ensure backwards compatibility. + +// The type of data stored in a tensor. +enum TensorType : byte { + FLOAT32 = 0, + FLOAT16 = 1, + INT32 = 2, + UINT8 = 3, + INT64 = 4, + STRING = 5, + BOOL = 6, + INT16 = 7, + COMPLEX64 = 8, + INT8 = 9, + FLOAT64 = 10, + COMPLEX128 = 11, + UINT64 = 12, + // Experimental: Resource and variant types are experimental, that are subject + // to change. Do not implement custom kernels using resource & variant types + // now. 
+ RESOURCE = 13, + VARIANT = 14, + UINT32 = 15, + UINT16 = 16 +} + +// Custom quantization parameters for experimenting with new quantization +// techniques. +table CustomQuantization { + custom:[ubyte] (force_align: 16); +} + +// Represents a specific quantization technique's parameters. +union QuantizationDetails { + CustomQuantization, +} + +// Parameters for converting a quantized tensor back to float. +table QuantizationParameters { + // These four parameters are the asymmetric linear quantization parameters. + // Given a quantized value q, the corresponding float value f should be: + // f = scale * (q - zero_point) + // For other quantization types, the QuantizationDetails below is used. + min:[float]; // For importing back into tensorflow. + max:[float]; // For importing back into tensorflow. + scale:[float]; // For dequantizing the tensor's values. + zero_point:[long]; + + // If this is not none, the other quantization parameters (i.e. min, max, + // scale, zero_point fields above) are ignored and the value of the + // QuantizationDetails union should be used. + details:QuantizationDetails; + + // Specifies the dimension of the Tensor's shape that the scales and + // zero_points correspond to. For example, a tensor t, with dims=[4, 3, 2, 1] + // with quantization params: + // scale=[1.0, 2.0, 3.0], zero_point=[1, 2, 3], quantization_dimension=1 + // will be quantized across the second dimension of t. + // t[:, 0, :, :] will have scale[0]=1.0, zero_point[0]=1 + // t[:, 1, :, :] will have scale[1]=2.0, zero_point[0]=2 + // t[:, 2, :, :] will have scale[2]=3.0, zero_point[0]=3 + quantized_dimension:int; +} + +// Sparse tensors. +// We use a modification of the TACO format. +// Reference: http://tensor-compiler.org/kjolstad-oopsla17-tensor-compiler.pdf +// +// To encode a conceptual n-dimensional dense tensor with dims (d0, ..., dn-1), +// potentially with a k-dimensional block (0 <= k <= n) with dims +// (dn, ..., dn+k-1), the format needs to specify: +// 1. 
In what order to traverse these dimensions. For example, to store a 2-D +// matrix in row major order, the traversal order would be (d0, d1), +// whereas to store it in column major order, the traversal order would be +// (d1, d0). If the 2-D matrix has a 2-D inner block, the traversal order +// could be (d0, d1, d2, d3). +// 2. How each block dimension in (dn, ..., dn+k-1) maps to the original +// tensor dimension in (d0, ..., dn-1). +// 3. In the traversal order defined above, the format (dense vs. sparse) and +// index metadata for each dimension. For a dense dimension, this is just +// the size of that dimension. For a sparse dimension, it's the same as +// the compressed index defined in the Compressed Sparse Row (CSR) format. +// (http://scipy-lectures.org/advanced/scipy_sparse/csr_matrix.html) + +// The storage type for a dimension. Currently we support: +// 1. DENSE: each coordinate in this dimension is stored implicitly. +// 2. SPARSE_CSR: only the coordinates with non-zero elements are stored. The +// compression technique is the same what CSR uses. +// More types like a sparse dimension with a different compression technique +// could be added to the list in the future. +enum DimensionType : byte { + DENSE = 0, + SPARSE_CSR = 1, +} + +table Int32Vector { + values:[int]; +} + +table Uint16Vector { + values:[ushort] (force_align: 4); +} + +table Uint8Vector { + values:[ubyte] (force_align: 4); +} + +// Variable-typed buffer to store the index metadata for a sparse dimension. +// The widest type is Int32 instead of UInt32 because tensor's shape is a int32 +// vector. We don't want the per-dimensional index to overflow that range. +union SparseIndexVector { + Int32Vector, + Uint16Vector, + Uint8Vector +} + +table DimensionMetadata { + // Whether a dimension is dense or sparse. + format:DimensionType; + // Index metadata used for a dimension. + // - If format is DimensionType.DENSE then we use the dense_size field to + // store the size of that dimension. 
Each index in that dimension is + // stored implicitly. + // - If format is DimensionType.SPARSE_CSR then we use array_segments and + // array_indices to encode that dimension. array_segments represents how + // to segment the indices array, each segment corresponds to one element + // in the previous dimension. array_indices represents the index of the + // non-zero elements within this dimension (as those in the CSR matrix + // format, where the first array is row pointers and the second array is + // column indices). + dense_size:int; + array_segments:SparseIndexVector; + array_indices:SparseIndexVector; +} + +// Parameters to encode a sparse TfLite tensor. +table SparsityParameters { + // The traversal order of the dimensions defined in the `shape` field of the + // conceptual dense tensor. For a n-dimensional tensors with dims (d0, d1, + // ..., dn-1), + // - if not block sparse, the traversal_order is just a permutation of (d0, + // ..., dn-1). For example, a 2-D matrix stored in row-major order would + // have traversal_order = (d0, d1). + // - if block sparse with a k-dimensional block (0 <= k <= n), the + // traversal_order has n + k elements. The first n elements are still a + // permutation of (d0, ..., dn-1). The lask k elements are a permutation + // of (dn, ..., dn+k-1), defining how to traverse a block internally. For + // example, a 2-D matrix with 2-D blocks, both stored in row-major order + // would have traversal_order = (d0, d1, d2, d3). + traversal_order:[int]; + // For an n-dimensional tensor with a k-dimensional block (0 <= k <= n), + // stores how a block dimension in (dn, ..., dn+k-1) maps to the original + // tensor dimension in (d0, ..., dn). + // It's stored in the order of (dn, ..., dn+k-1). + // If not block-sparse, this field is NULL. + block_map:[int]; + // In the traversal order defined above, the metadata needed for + // each dimension to locate the non-zero values in the original dense tensor. 
+ // The size of the dim_metadata array = the size of the traversal_order array + // = n + k. + dim_metadata:[DimensionMetadata]; +} + +table Tensor { + // The tensor shape. The meaning of each entry is operator-specific but + // builtin ops use: [batch size, height, width, number of channels] (That's + // Tensorflow's NHWC). + shape:[int]; + type:TensorType; + // An index that refers to the buffers table at the root of the model. Or, + // if there is no data buffer associated (i.e. intermediate results), then + // this is 0 (which refers to an always existent empty buffer). + // + // The data_buffer itself is an opaque container, with the assumption that the + // target device is little-endian. In addition, all builtin operators assume + // the memory is ordered such that if `shape` is [4, 3, 2], then index + // [i, j, k] maps to data_buffer[i*3*2 + j*2 + k]. + buffer:uint; + name:string; // For debugging and importing back into tensorflow. + quantization:QuantizationParameters; // Optional. + + is_variable:bool = false; + + // Parameters to encode a sparse tensor. See the example in + // tensorflow/lite/testdata/sparse_tensor.json. + sparsity:SparsityParameters; // Optional. + + // Encodes `shape` with unknown dimensions. Unknown dimensions are + // represented with -1. + shape_signature:[int]; // Optional. + + // If false, the rank or the number of tensor dimensions is unknown. + // If false, "shape" must be []. + has_rank: bool = false; +} + +// A list of builtin operators. Builtin operators are slightly faster than custom +// ones, but not by much. Moreover, while custom operators accept an opaque +// object containing configuration parameters, builtins have a predetermined +// set of acceptable options. 
+// LINT.IfChange +enum BuiltinOperator : int32 { + ADD = 0, + AVERAGE_POOL_2D = 1, + CONCATENATION = 2, + CONV_2D = 3, + DEPTHWISE_CONV_2D = 4, + DEPTH_TO_SPACE = 5, + DEQUANTIZE = 6, + EMBEDDING_LOOKUP = 7, + FLOOR = 8, + FULLY_CONNECTED = 9, + HASHTABLE_LOOKUP = 10, + L2_NORMALIZATION = 11, + L2_POOL_2D = 12, + LOCAL_RESPONSE_NORMALIZATION = 13, + LOGISTIC = 14, + LSH_PROJECTION = 15, + LSTM = 16, + MAX_POOL_2D = 17, + MUL = 18, + RELU = 19, + // NOTE(aselle): RELU_N1_TO_1 used to be called RELU1, but it was renamed + // since different model developers use RELU1 in different ways. Never + // create another op called RELU1. + RELU_N1_TO_1 = 20, + RELU6 = 21, + RESHAPE = 22, + RESIZE_BILINEAR = 23, + RNN = 24, + SOFTMAX = 25, + SPACE_TO_DEPTH = 26, + SVDF = 27, + TANH = 28, + CONCAT_EMBEDDINGS = 29, + SKIP_GRAM = 30, + CALL = 31, + CUSTOM = 32, + EMBEDDING_LOOKUP_SPARSE = 33, + PAD = 34, + UNIDIRECTIONAL_SEQUENCE_RNN = 35, + GATHER = 36, + BATCH_TO_SPACE_ND = 37, + SPACE_TO_BATCH_ND = 38, + TRANSPOSE = 39, + MEAN = 40, + SUB = 41, + DIV = 42, + SQUEEZE = 43, + UNIDIRECTIONAL_SEQUENCE_LSTM = 44, + STRIDED_SLICE = 45, + BIDIRECTIONAL_SEQUENCE_RNN = 46, + EXP = 47, + TOPK_V2 = 48, + SPLIT = 49, + LOG_SOFTMAX = 50, + // DELEGATE is a special op type for the operations which are delegated to + // other backends. 
+ // WARNING: Experimental interface, subject to change + DELEGATE = 51, + BIDIRECTIONAL_SEQUENCE_LSTM = 52, + CAST = 53, + PRELU = 54, + MAXIMUM = 55, + ARG_MAX = 56, + MINIMUM = 57, + LESS = 58, + NEG = 59, + PADV2 = 60, + GREATER = 61, + GREATER_EQUAL = 62, + LESS_EQUAL = 63, + SELECT = 64, + SLICE = 65, + SIN = 66, + TRANSPOSE_CONV = 67, + SPARSE_TO_DENSE = 68, + TILE = 69, + EXPAND_DIMS = 70, + EQUAL = 71, + NOT_EQUAL = 72, + LOG = 73, + SUM = 74, + SQRT = 75, + RSQRT = 76, + SHAPE = 77, + POW = 78, + ARG_MIN = 79, + FAKE_QUANT = 80, + REDUCE_PROD = 81, + REDUCE_MAX = 82, + PACK = 83, + LOGICAL_OR = 84, + ONE_HOT = 85, + LOGICAL_AND = 86, + LOGICAL_NOT = 87, + UNPACK = 88, + REDUCE_MIN = 89, + FLOOR_DIV = 90, + REDUCE_ANY = 91, + SQUARE = 92, + ZEROS_LIKE = 93, + FILL = 94, + FLOOR_MOD = 95, + RANGE = 96, + RESIZE_NEAREST_NEIGHBOR = 97, + LEAKY_RELU = 98, + SQUARED_DIFFERENCE = 99, + MIRROR_PAD = 100, + ABS = 101, + SPLIT_V = 102, + UNIQUE = 103, + CEIL = 104, + REVERSE_V2 = 105, + ADD_N = 106, + GATHER_ND = 107, + COS = 108, + WHERE = 109, + RANK = 110, + ELU = 111, + REVERSE_SEQUENCE = 112, + MATRIX_DIAG = 113, + QUANTIZE = 114, + MATRIX_SET_DIAG = 115, + ROUND = 116, + HARD_SWISH = 117, + IF = 118, + WHILE = 119, + NON_MAX_SUPPRESSION_V4 = 120, + NON_MAX_SUPPRESSION_V5 = 121, + SCATTER_ND = 122, + SELECT_V2 = 123, + DENSIFY = 124, + SEGMENT_SUM = 125, + BATCH_MATMUL = 126, + PLACEHOLDER_FOR_GREATER_OP_CODES = 127, + CUMSUM = 128, + CALL_ONCE = 129, + BROADCAST_TO = 130, + RFFT2D = 131, + CONV_3D = 132, + IMAG=133, + REAL=134, + COMPLEX_ABS=135, + HASHTABLE = 136, + HASHTABLE_FIND = 137, + HASHTABLE_IMPORT = 138, + HASHTABLE_SIZE = 139, + REDUCE_ALL = 140, + CONV_3D_TRANSPOSE = 141, + VAR_HANDLE = 142, + READ_VARIABLE = 143, + ASSIGN_VARIABLE = 144, + BROADCAST_ARGS = 145, + RANDOM_STANDARD_NORMAL = 146, + BUCKETIZE = 147, + RANDOM_UNIFORM = 148, + MULTINOMIAL = 149, + GELU = 150, + DYNAMIC_UPDATE_SLICE = 151, + RELU_0_TO_1 = 152, + UNSORTED_SEGMENT_PROD = 
153, + UNSORTED_SEGMENT_MAX = 154, + UNSORTED_SEGMENT_SUM = 155, + ATAN2 = 156 +} +// LINT.ThenChange(nnapi_linter/linter.proto) + +// Options for the builtin operators. +union BuiltinOptions { + Conv2DOptions, + DepthwiseConv2DOptions, + ConcatEmbeddingsOptions, + LSHProjectionOptions, + Pool2DOptions, + SVDFOptions, + RNNOptions, + FullyConnectedOptions, + SoftmaxOptions, + ConcatenationOptions, + AddOptions, + L2NormOptions, + LocalResponseNormalizationOptions, + LSTMOptions, + ResizeBilinearOptions, + CallOptions, + ReshapeOptions, + SkipGramOptions, + SpaceToDepthOptions, + EmbeddingLookupSparseOptions, + MulOptions, + PadOptions, + GatherOptions, + BatchToSpaceNDOptions, + SpaceToBatchNDOptions, + TransposeOptions, + ReducerOptions, + SubOptions, + DivOptions, + SqueezeOptions, + SequenceRNNOptions, + StridedSliceOptions, + ExpOptions, + TopKV2Options, + SplitOptions, + LogSoftmaxOptions, + CastOptions, + DequantizeOptions, + MaximumMinimumOptions, + ArgMaxOptions, + LessOptions, + NegOptions, + PadV2Options, + GreaterOptions, + GreaterEqualOptions, + LessEqualOptions, + SelectOptions, + SliceOptions, + TransposeConvOptions, + SparseToDenseOptions, + TileOptions, + ExpandDimsOptions, + EqualOptions, + NotEqualOptions, + ShapeOptions, + PowOptions, + ArgMinOptions, + FakeQuantOptions, + PackOptions, + LogicalOrOptions, + OneHotOptions, + LogicalAndOptions, + LogicalNotOptions, + UnpackOptions, + FloorDivOptions, + SquareOptions, + ZerosLikeOptions, + FillOptions, + BidirectionalSequenceLSTMOptions, + BidirectionalSequenceRNNOptions, + UnidirectionalSequenceLSTMOptions, + FloorModOptions, + RangeOptions, + ResizeNearestNeighborOptions, + LeakyReluOptions, + SquaredDifferenceOptions, + MirrorPadOptions, + AbsOptions, + SplitVOptions, + UniqueOptions, + ReverseV2Options, + AddNOptions, + GatherNdOptions, + CosOptions, + WhereOptions, + RankOptions, + ReverseSequenceOptions, + MatrixDiagOptions, + QuantizeOptions, + MatrixSetDiagOptions, + HardSwishOptions, + 
IfOptions, + WhileOptions, + DepthToSpaceOptions, + NonMaxSuppressionV4Options, + NonMaxSuppressionV5Options, + ScatterNdOptions, + SelectV2Options, + DensifyOptions, + SegmentSumOptions, + BatchMatMulOptions, + CumsumOptions, + CallOnceOptions, + BroadcastToOptions, + Rfft2dOptions, + Conv3DOptions, + HashtableOptions, + HashtableFindOptions, + HashtableImportOptions, + HashtableSizeOptions, + VarHandleOptions, + ReadVariableOptions, + AssignVariableOptions, + RandomOptions, + BucketizeOptions, + GeluOptions, + DynamicUpdateSliceOptions, + UnsortedSegmentProdOptions, + UnsortedSegmentMaxOptions, + UnsortedSegmentSumOptions, + ATan2Options +} + +// LINT.IfChange +enum Padding : byte { SAME, VALID } +// LINT.ThenChange(//tensorflow/compiler/mlir/lite/ir/tfl_op_enums.td) + +// LINT.IfChange +enum ActivationFunctionType : byte { + NONE = 0, + RELU = 1, + RELU_N1_TO_1 = 2, + RELU6 = 3, + TANH = 4, + SIGN_BIT = 5, +} +// LINT.ThenChange(//tensorflow/compiler/mlir/lite/ir/tfl_op_enums.td) + +table Conv2DOptions { + padding:Padding; + stride_w:int; + stride_h:int; + fused_activation_function:ActivationFunctionType; + dilation_w_factor:int = 1; + dilation_h_factor:int = 1; +} + +// Options for both Conv3D and Conv3DTranspose. +table Conv3DOptions { + padding:Padding; + stride_d:int; + stride_w:int; + stride_h:int; + fused_activation_function:ActivationFunctionType; + dilation_d_factor:int = 1; + dilation_w_factor:int = 1; + dilation_h_factor:int = 1; +} + +table Pool2DOptions { + padding:Padding; + stride_w:int; + stride_h:int; + filter_width:int; + filter_height:int; + fused_activation_function:ActivationFunctionType; +} + +table DepthwiseConv2DOptions { + // Parameters for DepthwiseConv version 1 or above. + padding:Padding; + stride_w:int; + stride_h:int; + // `depth_multiplier` is redundant. It's used by CPU kernels in + // TensorFlow 2.0 or below, but ignored in versions above. + // See comments in lite/c/builtin_op_data.h for more details. 
+ depth_multiplier:int; + fused_activation_function:ActivationFunctionType; + // Parameters for DepthwiseConv version 2 or above. + dilation_w_factor:int = 1; + dilation_h_factor:int = 1; +} + +table ConcatEmbeddingsOptions { + num_channels:int; + num_columns_per_channel:[int]; + embedding_dim_per_channel:[int]; // This could be inferred from parameters. +} + +enum LSHProjectionType: byte { + UNKNOWN = 0, + SPARSE = 1, + DENSE = 2, +} + +table LSHProjectionOptions { + type: LSHProjectionType; +} + +table SVDFOptions { + rank:int; + fused_activation_function:ActivationFunctionType; + // For weights-only quantization, use asymmetric quantization for non + // constant inputs at evaluation time. + asymmetric_quantize_inputs:bool; +} + +// An implementation of TensorFlow RNNCell. +table RNNOptions { + fused_activation_function:ActivationFunctionType; + asymmetric_quantize_inputs:bool; +} + +// An implementation of TensorFlow dynamic_rnn with RNNCell. +table SequenceRNNOptions { + time_major:bool; + fused_activation_function:ActivationFunctionType; + asymmetric_quantize_inputs:bool; +} + +// An implementation of TensorFlow bidrectional_dynamic_rnn with RNNCell. +table BidirectionalSequenceRNNOptions { + time_major:bool; + fused_activation_function:ActivationFunctionType; + merge_outputs: bool; + asymmetric_quantize_inputs:bool; +} + +// LINT.IfChange +enum FullyConnectedOptionsWeightsFormat: byte { + DEFAULT = 0, + SHUFFLED4x16INT8 = 1, +} +// LINT.ThenChange(//tensorflow/compiler/mlir/lite/ir/tfl_op_enums.td) + +// An implementation of TensorFlow fully_connected (a.k.a Dense) layer. +table FullyConnectedOptions { + // Parameters for FullyConnected version 1 or above. + fused_activation_function:ActivationFunctionType; + + // Parameters for FullyConnected version 2 or above. + weights_format:FullyConnectedOptionsWeightsFormat = DEFAULT; + + // Parameters for FullyConnected version 5 or above. + // If set to true, then the number of dimension is preserved. 
Furthermore, + // all but the last dimension of the input and output shapes will be equal. + keep_num_dims: bool; + + // Parameters for FullyConnected version 7 or above. + // If set to true, then weights-only op will use asymmetric quantization for + // inputs. + asymmetric_quantize_inputs: bool; +} + +table SoftmaxOptions { + beta: float; +} + +// An implementation of TensorFlow concat. +table ConcatenationOptions { + axis:int; + fused_activation_function:ActivationFunctionType; +} + +table AddOptions { + fused_activation_function:ActivationFunctionType; + // Parameters supported by version 3. + pot_scale_int16:bool = true; +} + +table MulOptions { + fused_activation_function:ActivationFunctionType; +} + +table L2NormOptions { + // This field is currently ignored in the L2 Norm Op. + fused_activation_function:ActivationFunctionType; +} + +table LocalResponseNormalizationOptions { + radius:int; + bias:float; + alpha:float; + beta:float; +} + +// LINT.IfChange +enum LSTMKernelType : byte { + // Full LSTM kernel which supports peephole and projection. + FULL = 0, + // Basic LSTM kernels. Equivalent to TensorFlow BasicLSTMCell. + BASIC = 1, +} +// LINT.ThenChange(//tensorflow/compiler/mlir/lite/ir/tfl_op_enums.td) + +// An implementation of TensorFlow LSTMCell and CoupledInputForgetGateLSTMCell +table LSTMOptions { + // Parameters for LSTM version 1 or above. + fused_activation_function:ActivationFunctionType; + cell_clip: float; // Optional, 0.0 means no clipping + proj_clip: float; // Optional, 0.0 means no clipping + + // Parameters for LSTM version 2 or above. + // Basic kernel is only supported in version 2 or above. + kernel_type: LSTMKernelType = FULL; + + // Parameters for LSTM version 4 or above. + asymmetric_quantize_inputs: bool; +} + +// An implementation of TensorFlow dynamic_rnn with LSTMCell. 
+table UnidirectionalSequenceLSTMOptions { + fused_activation_function:ActivationFunctionType; + cell_clip: float; // Optional, 0.0 means no clipping + proj_clip: float; // Optional, 0.0 means no clipping + + // If true then first dimension is sequence, otherwise batch. + time_major:bool; + + // Parameter for Unidirectional Sequence LSTM version 4. + asymmetric_quantize_inputs:bool; +} + +table BidirectionalSequenceLSTMOptions { + // Parameters supported by version 1: + fused_activation_function:ActivationFunctionType; + cell_clip: float; // Optional, 0.0 means no clipping + proj_clip: float; // Optional, 0.0 means no clipping + + // If true, store the outputs of both directions into the first output. + merge_outputs: bool; + + // Parameters supported by version 2: + // If true then first dimension is sequence, otherwise batch. + // Version 1 implementations assumed time_major to be true, so this default + // value should never change. + time_major: bool = true; + + // Parameters for version 3 or above. + asymmetric_quantize_inputs:bool; +} + +table ResizeBilinearOptions { + new_height: int (deprecated); + new_width: int (deprecated); + align_corners: bool; + half_pixel_centers: bool; +} + +table ResizeNearestNeighborOptions { + align_corners: bool; + half_pixel_centers: bool; +} + +// A call operation options +table CallOptions { + // The subgraph index that needs to be called. 
+ subgraph:uint; +} + +table PadOptions { +} + +table PadV2Options { +} + +table ReshapeOptions { + new_shape:[int]; +} + +table SpaceToBatchNDOptions { +} + +table BatchToSpaceNDOptions { +} + +table SkipGramOptions { + ngram_size: int; + max_skip_size: int; + include_all_ngrams: bool; +} + +table SpaceToDepthOptions { + block_size: int; +} + +table DepthToSpaceOptions { + block_size: int; +} + +table SubOptions { + fused_activation_function:ActivationFunctionType; + // Parameters supported by version 5 + pot_scale_int16:bool = true; +} + +table DivOptions { + fused_activation_function:ActivationFunctionType; +} + +table TopKV2Options { +} + +enum CombinerType : byte { + SUM = 0, + MEAN = 1, + SQRTN = 2, +} + +table EmbeddingLookupSparseOptions { + combiner:CombinerType; +} + +table GatherOptions { + axis: int; + // Parameters for Gather version 5 or above. + batch_dims: int = 0; +} + +table TransposeOptions { +} + +table ExpOptions { +} + +table CosOptions { +} + +table ReducerOptions { + keep_dims: bool; +} + +table SqueezeOptions { + squeeze_dims:[int]; +} + +table SplitOptions { + num_splits: int; +} + +table SplitVOptions { + num_splits: int; +} + +table StridedSliceOptions { + begin_mask: int; + end_mask: int; + ellipsis_mask: int; + new_axis_mask: int; + shrink_axis_mask: int; +} + +table LogSoftmaxOptions { +} + +table CastOptions { + in_data_type: TensorType; + out_data_type: TensorType; +} + +table DequantizeOptions { +} + +table MaximumMinimumOptions { +} + +table TileOptions { +} + +table ArgMaxOptions { + output_type : TensorType; +} + +table ArgMinOptions { + output_type : TensorType; +} + +table GreaterOptions { +} + +table GreaterEqualOptions { +} + +table LessOptions { +} + +table LessEqualOptions { +} + +table NegOptions { +} + +table SelectOptions { +} + +table SliceOptions { +} + +table TransposeConvOptions { + padding:Padding; + stride_w:int; + stride_h:int; +} + +table ExpandDimsOptions { +} + +table SparseToDenseOptions { + 
validate_indices:bool; +} + +table EqualOptions { +} + +table NotEqualOptions { +} + +table ShapeOptions { + // Optional output type of the operation (int32 or int64). Defaults to int32. + out_type : TensorType; +} + +table RankOptions { +} + +table PowOptions { +} + +table FakeQuantOptions { + // Parameters supported by version 1: + min:float; + max:float; + num_bits:int; + + // Parameters supported by version 2: + narrow_range:bool; +} + +table PackOptions { + values_count:int; + axis:int; +} + +table LogicalOrOptions { +} + +table OneHotOptions { + axis:int; +} + +table AbsOptions { +} + + +table HardSwishOptions { +} + +table LogicalAndOptions { +} + +table LogicalNotOptions { +} + +table UnpackOptions { + num:int; + axis:int; +} + +table FloorDivOptions { +} + +table SquareOptions { +} + +table ZerosLikeOptions { +} + +table FillOptions { +} + +table FloorModOptions { +} + +table RangeOptions { +} + +table LeakyReluOptions { + alpha:float; +} + +table SquaredDifferenceOptions { +} + +// LINT.IfChange +enum MirrorPadMode : byte { + // Doesn't include borders. + REFLECT = 0, + // Includes borders. 
+ SYMMETRIC = 1, +} +// LINT.ThenChange(//tensorflow/compiler/mlir/lite/ir/tfl_op_enums.td) + +table MirrorPadOptions { + mode:MirrorPadMode; +} + +table UniqueOptions { + idx_out_type:TensorType = INT32; +} + +table ReverseV2Options { +} + +table AddNOptions { +} + +table GatherNdOptions { +} + +table WhereOptions { +} + +table ReverseSequenceOptions { + seq_dim:int; + batch_dim:int = 0; +} + +table MatrixDiagOptions { +} + +table QuantizeOptions { +} + +table MatrixSetDiagOptions { +} + +table IfOptions { + then_subgraph_index:int; + else_subgraph_index:int; +} + +table CallOnceOptions { + init_subgraph_index:int; +} + +table WhileOptions { + cond_subgraph_index:int; + body_subgraph_index:int; +} + +table NonMaxSuppressionV4Options { +} + +table NonMaxSuppressionV5Options { +} + +table ScatterNdOptions { +} + +table SelectV2Options { +} + +table DensifyOptions { +} + +table SegmentSumOptions { +} + +table BatchMatMulOptions { + adj_x:bool; + adj_y:bool; + // Parameters for BatchMatMul version 4 or above. + // If set to true, then weights-only op will use asymmetric quantization for + // inputs. + asymmetric_quantize_inputs: bool; +} + +table CumsumOptions { + exclusive:bool; + reverse:bool; +} + +table BroadcastToOptions { +} + +table Rfft2dOptions { +} + +table HashtableOptions { + // The identity of hash tables. This identity will be used across different + // subgraphs in the same interpreter instance. + table_id:int; + key_dtype:TensorType; + value_dtype:TensorType; +} + +table HashtableFindOptions { +} + +table HashtableImportOptions { +} + +table HashtableSizeOptions { +} + +table VarHandleOptions { + container:string; + shared_name:string; +} + +table ReadVariableOptions { +} + +table AssignVariableOptions { +} + +table RandomOptions { + seed: long; + seed2: long; +} + +table BucketizeOptions { + boundaries: [float]; // The bucket boundaries. 
+} + +table GeluOptions { + approximate: bool; +} + +table DynamicUpdateSliceOptions { +} + +table UnsortedSegmentProdOptions { +} + +table UnsortedSegmentMaxOptions { +} + +table UnsortedSegmentSumOptions { +} + +table ATan2Options { +} + + +// An OperatorCode can be an enum value (BuiltinOperator) if the operator is a +// builtin, or a string if the operator is custom. +table OperatorCode { + // This field is for backward compatibility. This field will be used when + // the value of the extended builtin_code field has less than + // BulitinOperator_PLACEHOLDER_FOR_GREATER_OP_CODES. + deprecated_builtin_code:byte; + custom_code:string; + + // The version of the operator. The version need to be bumped whenever new + // parameters are introduced into an op. + version:int = 1; + + // This field is introduced for resolving op builtin code shortage problem + // (the original BuiltinOperator enum field was represented as a byte). + // This field will be used when the value of the extended builtin_code field + // has greater than BulitinOperator_PLACEHOLDER_FOR_GREATER_OP_CODES. + builtin_code:BuiltinOperator; +} + +enum CustomOptionsFormat : byte { + FLEXBUFFERS = 0, +} + +// An operator takes tensors as inputs and outputs. The type of operation being +// performed is determined by an index into the list of valid OperatorCodes, +// while the specifics of each operations is configured using builtin_options +// or custom_options. +table Operator { + // Index into the operator_codes array. Using an integer here avoids + // complicate map lookups. + opcode_index:uint; + + // Optional input are indicated by -1. + inputs:[int]; + outputs:[int]; + + builtin_options:BuiltinOptions; + custom_options:[ubyte]; + custom_options_format:CustomOptionsFormat; + + // A list of booleans indicating the input tensors which are being mutated by + // this operator.(e.g. used by RNN and LSTM). 
+ // For example, if the "inputs" array refers to 5 tensors and the second and + // fifth are mutable variables, then this list will contain + // [false, true, false, false, true]. + // + // If the list is empty, no variable is mutated in this operator. + // The list either has the same length as `inputs`, or is empty. + mutating_variable_inputs:[bool]; + + // A list of indices to the subgraph's "tensors" that are internal to an Op. + // Internal tensors are those that do not flow in or out of the operation, + // but instead are part of internal computation. As such, the operation's + // implementation may manage its memory more efficiently. They are needed + // however (i.e. not just an implementation detail) since they are part of the + // computation, which may require relevant metadata such as quantization + // parameters. + intermediates:[int]; +} + +// The root type, defining a subgraph, which typically represents an entire +// model. +table SubGraph { + // A list of all tensors used in this subgraph. + tensors:[Tensor]; + + // Indices of the tensors that are inputs into this subgraph. Note this is + // the list of non-static tensors that feed into the subgraph for inference. + inputs:[int]; + + // Indices of the tensors that are outputs out of this subgraph. Note this is + // the list of output tensors that are considered the product of the + // subgraph's inference. + outputs:[int]; + + // All operators, in execution order. + operators:[Operator]; + + // Name of this subgraph (used for debugging). + name:string; +} + +// Table of raw data buffers (used for constant tensors). Referenced by tensors +// by index. The generous alignment accommodates mmap-friendly data structures. +table Buffer { + data:[ubyte] (force_align: 16); +} + +table Metadata { + // A human readable string to uniquely identify a Metadata. + name:string; + // An index to the buffers table. + buffer:uint; +} + +// Map from an alias name of tensor to tensor index in the graph. 
+// This is used in Signature def. +table TensorMap { + // Represents the alias to use for this tensor. + name:string; + + // The actual tensor index in the primary graph, that 'name' corresponds to. + tensor_index:uint; +} + +// This corresponds to SignatureDef in Tensorflow SavedModel. +// The SignatureDef will be part of the SavedModel provided for conversion. +table SignatureDef { + // Named inputs for this signature. + inputs:[TensorMap]; + + // Named outputs for this signature. + outputs:[TensorMap]; + + // Key value which was in the Tensorflow SavedModel SignatureDef map. + signature_key:string; + + // Model tag, deprecated. + deprecated_tag:string (deprecated); + + // Index of subgraphs that corresponds to the exported method. + subgraph_index:uint; +} + +table Model { + // Version of the schema. + version:uint; + + // A list of all operator codes used in this model. This is + // kept in order because operators carry an index into this + // vector. + operator_codes:[OperatorCode]; + + // All the subgraphs of the model. The 0th is assumed to be the main + // model. + subgraphs:[SubGraph]; + + // A description of the model. + description:string; + + // Buffers of the model. + // Note the 0th entry of this array must be an empty buffer (sentinel). + // This is a convention so that tensors without a buffer can provide 0 as + // their buffer. + buffers:[Buffer]; + + // Metadata about the model. Indirects into the existings buffers list. + // Deprecated, prefer to use metadata field. + metadata_buffer:[int]; + + // Metadata about the model. + metadata:[Metadata]; + + // Optional SignatureDefs for the model. 
+ signature_defs:[SignatureDef]; +} + +root_type Model; diff --git a/runtime/onert/core/src/loader/tflite_schema_generated.h b/runtime/onert/core/src/loader/tflite_schema_generated.h new file mode 100644 index 000000000..9d891841a --- /dev/null +++ b/runtime/onert/core/src/loader/tflite_schema_generated.h @@ -0,0 +1,11989 @@ +/* + * Copyright (c) 2019-2024 Samsung Electronics Co., Ltd. All Rights Reserved + * Copyright 2018 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +// automatically generated by the FlatBuffers compiler, do not modify + +#ifndef FLATBUFFERS_GENERATED_TFLITESCHEMA_ONERT_TFLITE_H_ +#define FLATBUFFERS_GENERATED_TFLITESCHEMA_ONERT_TFLITE_H_ + +#include "flatbuffers/flatbuffers.h" + +// Ensure the included flatbuffers.h is the same version as when this file was +// generated, otherwise it may not be compatible. 
+static_assert(FLATBUFFERS_VERSION_MAJOR == 23 && FLATBUFFERS_VERSION_MINOR == 5 && + FLATBUFFERS_VERSION_REVISION == 26, + "Non-compatible flatbuffers version included"); + +namespace onert_tflite +{ + +struct CustomQuantization; +struct CustomQuantizationBuilder; + +struct QuantizationParameters; +struct QuantizationParametersBuilder; + +struct Int32Vector; +struct Int32VectorBuilder; + +struct Uint16Vector; +struct Uint16VectorBuilder; + +struct Uint8Vector; +struct Uint8VectorBuilder; + +struct DimensionMetadata; +struct DimensionMetadataBuilder; + +struct SparsityParameters; +struct SparsityParametersBuilder; + +struct Tensor; +struct TensorBuilder; + +struct Conv2DOptions; +struct Conv2DOptionsBuilder; + +struct Conv3DOptions; +struct Conv3DOptionsBuilder; + +struct Pool2DOptions; +struct Pool2DOptionsBuilder; + +struct DepthwiseConv2DOptions; +struct DepthwiseConv2DOptionsBuilder; + +struct ConcatEmbeddingsOptions; +struct ConcatEmbeddingsOptionsBuilder; + +struct LSHProjectionOptions; +struct LSHProjectionOptionsBuilder; + +struct SVDFOptions; +struct SVDFOptionsBuilder; + +struct RNNOptions; +struct RNNOptionsBuilder; + +struct SequenceRNNOptions; +struct SequenceRNNOptionsBuilder; + +struct BidirectionalSequenceRNNOptions; +struct BidirectionalSequenceRNNOptionsBuilder; + +struct FullyConnectedOptions; +struct FullyConnectedOptionsBuilder; + +struct SoftmaxOptions; +struct SoftmaxOptionsBuilder; + +struct ConcatenationOptions; +struct ConcatenationOptionsBuilder; + +struct AddOptions; +struct AddOptionsBuilder; + +struct MulOptions; +struct MulOptionsBuilder; + +struct L2NormOptions; +struct L2NormOptionsBuilder; + +struct LocalResponseNormalizationOptions; +struct LocalResponseNormalizationOptionsBuilder; + +struct LSTMOptions; +struct LSTMOptionsBuilder; + +struct UnidirectionalSequenceLSTMOptions; +struct UnidirectionalSequenceLSTMOptionsBuilder; + +struct BidirectionalSequenceLSTMOptions; +struct BidirectionalSequenceLSTMOptionsBuilder; + +struct 
ResizeBilinearOptions; +struct ResizeBilinearOptionsBuilder; + +struct ResizeNearestNeighborOptions; +struct ResizeNearestNeighborOptionsBuilder; + +struct CallOptions; +struct CallOptionsBuilder; + +struct PadOptions; +struct PadOptionsBuilder; + +struct PadV2Options; +struct PadV2OptionsBuilder; + +struct ReshapeOptions; +struct ReshapeOptionsBuilder; + +struct SpaceToBatchNDOptions; +struct SpaceToBatchNDOptionsBuilder; + +struct BatchToSpaceNDOptions; +struct BatchToSpaceNDOptionsBuilder; + +struct SkipGramOptions; +struct SkipGramOptionsBuilder; + +struct SpaceToDepthOptions; +struct SpaceToDepthOptionsBuilder; + +struct DepthToSpaceOptions; +struct DepthToSpaceOptionsBuilder; + +struct SubOptions; +struct SubOptionsBuilder; + +struct DivOptions; +struct DivOptionsBuilder; + +struct TopKV2Options; +struct TopKV2OptionsBuilder; + +struct EmbeddingLookupSparseOptions; +struct EmbeddingLookupSparseOptionsBuilder; + +struct GatherOptions; +struct GatherOptionsBuilder; + +struct TransposeOptions; +struct TransposeOptionsBuilder; + +struct ExpOptions; +struct ExpOptionsBuilder; + +struct CosOptions; +struct CosOptionsBuilder; + +struct ReducerOptions; +struct ReducerOptionsBuilder; + +struct SqueezeOptions; +struct SqueezeOptionsBuilder; + +struct SplitOptions; +struct SplitOptionsBuilder; + +struct SplitVOptions; +struct SplitVOptionsBuilder; + +struct StridedSliceOptions; +struct StridedSliceOptionsBuilder; + +struct LogSoftmaxOptions; +struct LogSoftmaxOptionsBuilder; + +struct CastOptions; +struct CastOptionsBuilder; + +struct DequantizeOptions; +struct DequantizeOptionsBuilder; + +struct MaximumMinimumOptions; +struct MaximumMinimumOptionsBuilder; + +struct TileOptions; +struct TileOptionsBuilder; + +struct ArgMaxOptions; +struct ArgMaxOptionsBuilder; + +struct ArgMinOptions; +struct ArgMinOptionsBuilder; + +struct GreaterOptions; +struct GreaterOptionsBuilder; + +struct GreaterEqualOptions; +struct GreaterEqualOptionsBuilder; + +struct LessOptions; +struct 
LessOptionsBuilder; + +struct LessEqualOptions; +struct LessEqualOptionsBuilder; + +struct NegOptions; +struct NegOptionsBuilder; + +struct SelectOptions; +struct SelectOptionsBuilder; + +struct SliceOptions; +struct SliceOptionsBuilder; + +struct TransposeConvOptions; +struct TransposeConvOptionsBuilder; + +struct ExpandDimsOptions; +struct ExpandDimsOptionsBuilder; + +struct SparseToDenseOptions; +struct SparseToDenseOptionsBuilder; + +struct EqualOptions; +struct EqualOptionsBuilder; + +struct NotEqualOptions; +struct NotEqualOptionsBuilder; + +struct ShapeOptions; +struct ShapeOptionsBuilder; + +struct RankOptions; +struct RankOptionsBuilder; + +struct PowOptions; +struct PowOptionsBuilder; + +struct FakeQuantOptions; +struct FakeQuantOptionsBuilder; + +struct PackOptions; +struct PackOptionsBuilder; + +struct LogicalOrOptions; +struct LogicalOrOptionsBuilder; + +struct OneHotOptions; +struct OneHotOptionsBuilder; + +struct AbsOptions; +struct AbsOptionsBuilder; + +struct HardSwishOptions; +struct HardSwishOptionsBuilder; + +struct LogicalAndOptions; +struct LogicalAndOptionsBuilder; + +struct LogicalNotOptions; +struct LogicalNotOptionsBuilder; + +struct UnpackOptions; +struct UnpackOptionsBuilder; + +struct FloorDivOptions; +struct FloorDivOptionsBuilder; + +struct SquareOptions; +struct SquareOptionsBuilder; + +struct ZerosLikeOptions; +struct ZerosLikeOptionsBuilder; + +struct FillOptions; +struct FillOptionsBuilder; + +struct FloorModOptions; +struct FloorModOptionsBuilder; + +struct RangeOptions; +struct RangeOptionsBuilder; + +struct LeakyReluOptions; +struct LeakyReluOptionsBuilder; + +struct SquaredDifferenceOptions; +struct SquaredDifferenceOptionsBuilder; + +struct MirrorPadOptions; +struct MirrorPadOptionsBuilder; + +struct UniqueOptions; +struct UniqueOptionsBuilder; + +struct ReverseV2Options; +struct ReverseV2OptionsBuilder; + +struct AddNOptions; +struct AddNOptionsBuilder; + +struct GatherNdOptions; +struct GatherNdOptionsBuilder; + +struct 
WhereOptions; +struct WhereOptionsBuilder; + +struct ReverseSequenceOptions; +struct ReverseSequenceOptionsBuilder; + +struct MatrixDiagOptions; +struct MatrixDiagOptionsBuilder; + +struct QuantizeOptions; +struct QuantizeOptionsBuilder; + +struct MatrixSetDiagOptions; +struct MatrixSetDiagOptionsBuilder; + +struct IfOptions; +struct IfOptionsBuilder; + +struct CallOnceOptions; +struct CallOnceOptionsBuilder; + +struct WhileOptions; +struct WhileOptionsBuilder; + +struct NonMaxSuppressionV4Options; +struct NonMaxSuppressionV4OptionsBuilder; + +struct NonMaxSuppressionV5Options; +struct NonMaxSuppressionV5OptionsBuilder; + +struct ScatterNdOptions; +struct ScatterNdOptionsBuilder; + +struct SelectV2Options; +struct SelectV2OptionsBuilder; + +struct DensifyOptions; +struct DensifyOptionsBuilder; + +struct SegmentSumOptions; +struct SegmentSumOptionsBuilder; + +struct BatchMatMulOptions; +struct BatchMatMulOptionsBuilder; + +struct CumsumOptions; +struct CumsumOptionsBuilder; + +struct BroadcastToOptions; +struct BroadcastToOptionsBuilder; + +struct Rfft2dOptions; +struct Rfft2dOptionsBuilder; + +struct HashtableOptions; +struct HashtableOptionsBuilder; + +struct HashtableFindOptions; +struct HashtableFindOptionsBuilder; + +struct HashtableImportOptions; +struct HashtableImportOptionsBuilder; + +struct HashtableSizeOptions; +struct HashtableSizeOptionsBuilder; + +struct VarHandleOptions; +struct VarHandleOptionsBuilder; + +struct ReadVariableOptions; +struct ReadVariableOptionsBuilder; + +struct AssignVariableOptions; +struct AssignVariableOptionsBuilder; + +struct RandomOptions; +struct RandomOptionsBuilder; + +struct BucketizeOptions; +struct BucketizeOptionsBuilder; + +struct GeluOptions; +struct GeluOptionsBuilder; + +struct DynamicUpdateSliceOptions; +struct DynamicUpdateSliceOptionsBuilder; + +struct UnsortedSegmentProdOptions; +struct UnsortedSegmentProdOptionsBuilder; + +struct UnsortedSegmentMaxOptions; +struct UnsortedSegmentMaxOptionsBuilder; + +struct 
UnsortedSegmentSumOptions; +struct UnsortedSegmentSumOptionsBuilder; + +struct ATan2Options; +struct ATan2OptionsBuilder; + +struct OperatorCode; +struct OperatorCodeBuilder; + +struct Operator; +struct OperatorBuilder; + +struct SubGraph; +struct SubGraphBuilder; + +struct Buffer; +struct BufferBuilder; + +struct Metadata; +struct MetadataBuilder; + +struct TensorMap; +struct TensorMapBuilder; + +struct SignatureDef; +struct SignatureDefBuilder; + +struct Model; +struct ModelBuilder; + +enum TensorType : int8_t +{ + TensorType_FLOAT32 = 0, + TensorType_FLOAT16 = 1, + TensorType_INT32 = 2, + TensorType_UINT8 = 3, + TensorType_INT64 = 4, + TensorType_STRING = 5, + TensorType_BOOL = 6, + TensorType_INT16 = 7, + TensorType_COMPLEX64 = 8, + TensorType_INT8 = 9, + TensorType_FLOAT64 = 10, + TensorType_COMPLEX128 = 11, + TensorType_UINT64 = 12, + TensorType_RESOURCE = 13, + TensorType_VARIANT = 14, + TensorType_UINT32 = 15, + TensorType_UINT16 = 16, + TensorType_MIN = TensorType_FLOAT32, + TensorType_MAX = TensorType_UINT16 +}; + +inline const TensorType (&EnumValuesTensorType())[17] +{ + static const TensorType values[] = { + TensorType_FLOAT32, TensorType_FLOAT16, TensorType_INT32, TensorType_UINT8, + TensorType_INT64, TensorType_STRING, TensorType_BOOL, TensorType_INT16, + TensorType_COMPLEX64, TensorType_INT8, TensorType_FLOAT64, TensorType_COMPLEX128, + TensorType_UINT64, TensorType_RESOURCE, TensorType_VARIANT, TensorType_UINT32, + TensorType_UINT16}; + return values; +} + +inline const char *const *EnumNamesTensorType() +{ + static const char *const names[18] = {"FLOAT32", "FLOAT16", "INT32", "UINT8", "INT64", + "STRING", "BOOL", "INT16", "COMPLEX64", "INT8", + "FLOAT64", "COMPLEX128", "UINT64", "RESOURCE", "VARIANT", + "UINT32", "UINT16", nullptr}; + return names; +} + +inline const char *EnumNameTensorType(TensorType e) +{ + if (::flatbuffers::IsOutRange(e, TensorType_FLOAT32, TensorType_UINT16)) + return ""; + const size_t index = static_cast<size_t>(e); + 
return EnumNamesTensorType()[index]; +} + +enum QuantizationDetails : uint8_t +{ + QuantizationDetails_NONE = 0, + QuantizationDetails_CustomQuantization = 1, + QuantizationDetails_MIN = QuantizationDetails_NONE, + QuantizationDetails_MAX = QuantizationDetails_CustomQuantization +}; + +inline const QuantizationDetails (&EnumValuesQuantizationDetails())[2] +{ + static const QuantizationDetails values[] = {QuantizationDetails_NONE, + QuantizationDetails_CustomQuantization}; + return values; +} + +inline const char *const *EnumNamesQuantizationDetails() +{ + static const char *const names[3] = {"NONE", "CustomQuantization", nullptr}; + return names; +} + +inline const char *EnumNameQuantizationDetails(QuantizationDetails e) +{ + if (::flatbuffers::IsOutRange(e, QuantizationDetails_NONE, + QuantizationDetails_CustomQuantization)) + return ""; + const size_t index = static_cast<size_t>(e); + return EnumNamesQuantizationDetails()[index]; +} + +template <typename T> struct QuantizationDetailsTraits +{ + static const QuantizationDetails enum_value = QuantizationDetails_NONE; +}; + +template <> struct QuantizationDetailsTraits<onert_tflite::CustomQuantization> +{ + static const QuantizationDetails enum_value = QuantizationDetails_CustomQuantization; +}; + +bool VerifyQuantizationDetails(::flatbuffers::Verifier &verifier, const void *obj, + QuantizationDetails type); +bool VerifyQuantizationDetailsVector( + ::flatbuffers::Verifier &verifier, + const ::flatbuffers::Vector<::flatbuffers::Offset<void>> *values, + const ::flatbuffers::Vector<uint8_t> *types); + +enum DimensionType : int8_t +{ + DimensionType_DENSE = 0, + DimensionType_SPARSE_CSR = 1, + DimensionType_MIN = DimensionType_DENSE, + DimensionType_MAX = DimensionType_SPARSE_CSR +}; + +inline const DimensionType (&EnumValuesDimensionType())[2] +{ + static const DimensionType values[] = {DimensionType_DENSE, DimensionType_SPARSE_CSR}; + return values; +} + +inline const char *const *EnumNamesDimensionType() +{ + static 
const char *const names[3] = {"DENSE", "SPARSE_CSR", nullptr}; + return names; +} + +inline const char *EnumNameDimensionType(DimensionType e) +{ + if (::flatbuffers::IsOutRange(e, DimensionType_DENSE, DimensionType_SPARSE_CSR)) + return ""; + const size_t index = static_cast<size_t>(e); + return EnumNamesDimensionType()[index]; +} + +enum SparseIndexVector : uint8_t +{ + SparseIndexVector_NONE = 0, + SparseIndexVector_Int32Vector = 1, + SparseIndexVector_Uint16Vector = 2, + SparseIndexVector_Uint8Vector = 3, + SparseIndexVector_MIN = SparseIndexVector_NONE, + SparseIndexVector_MAX = SparseIndexVector_Uint8Vector +}; + +inline const SparseIndexVector (&EnumValuesSparseIndexVector())[4] +{ + static const SparseIndexVector values[] = {SparseIndexVector_NONE, SparseIndexVector_Int32Vector, + SparseIndexVector_Uint16Vector, + SparseIndexVector_Uint8Vector}; + return values; +} + +inline const char *const *EnumNamesSparseIndexVector() +{ + static const char *const names[5] = {"NONE", "Int32Vector", "Uint16Vector", "Uint8Vector", + nullptr}; + return names; +} + +inline const char *EnumNameSparseIndexVector(SparseIndexVector e) +{ + if (::flatbuffers::IsOutRange(e, SparseIndexVector_NONE, SparseIndexVector_Uint8Vector)) + return ""; + const size_t index = static_cast<size_t>(e); + return EnumNamesSparseIndexVector()[index]; +} + +template <typename T> struct SparseIndexVectorTraits +{ + static const SparseIndexVector enum_value = SparseIndexVector_NONE; +}; + +template <> struct SparseIndexVectorTraits<onert_tflite::Int32Vector> +{ + static const SparseIndexVector enum_value = SparseIndexVector_Int32Vector; +}; + +template <> struct SparseIndexVectorTraits<onert_tflite::Uint16Vector> +{ + static const SparseIndexVector enum_value = SparseIndexVector_Uint16Vector; +}; + +template <> struct SparseIndexVectorTraits<onert_tflite::Uint8Vector> +{ + static const SparseIndexVector enum_value = SparseIndexVector_Uint8Vector; +}; + +bool 
VerifySparseIndexVector(::flatbuffers::Verifier &verifier, const void *obj, + SparseIndexVector type); +bool VerifySparseIndexVectorVector(::flatbuffers::Verifier &verifier, + const ::flatbuffers::Vector<::flatbuffers::Offset<void>> *values, + const ::flatbuffers::Vector<uint8_t> *types); + +enum BuiltinOperator : int32_t +{ + BuiltinOperator_ADD = 0, + BuiltinOperator_AVERAGE_POOL_2D = 1, + BuiltinOperator_CONCATENATION = 2, + BuiltinOperator_CONV_2D = 3, + BuiltinOperator_DEPTHWISE_CONV_2D = 4, + BuiltinOperator_DEPTH_TO_SPACE = 5, + BuiltinOperator_DEQUANTIZE = 6, + BuiltinOperator_EMBEDDING_LOOKUP = 7, + BuiltinOperator_FLOOR = 8, + BuiltinOperator_FULLY_CONNECTED = 9, + BuiltinOperator_HASHTABLE_LOOKUP = 10, + BuiltinOperator_L2_NORMALIZATION = 11, + BuiltinOperator_L2_POOL_2D = 12, + BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION = 13, + BuiltinOperator_LOGISTIC = 14, + BuiltinOperator_LSH_PROJECTION = 15, + BuiltinOperator_LSTM = 16, + BuiltinOperator_MAX_POOL_2D = 17, + BuiltinOperator_MUL = 18, + BuiltinOperator_RELU = 19, + BuiltinOperator_RELU_N1_TO_1 = 20, + BuiltinOperator_RELU6 = 21, + BuiltinOperator_RESHAPE = 22, + BuiltinOperator_RESIZE_BILINEAR = 23, + BuiltinOperator_RNN = 24, + BuiltinOperator_SOFTMAX = 25, + BuiltinOperator_SPACE_TO_DEPTH = 26, + BuiltinOperator_SVDF = 27, + BuiltinOperator_TANH = 28, + BuiltinOperator_CONCAT_EMBEDDINGS = 29, + BuiltinOperator_SKIP_GRAM = 30, + BuiltinOperator_CALL = 31, + BuiltinOperator_CUSTOM = 32, + BuiltinOperator_EMBEDDING_LOOKUP_SPARSE = 33, + BuiltinOperator_PAD = 34, + BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN = 35, + BuiltinOperator_GATHER = 36, + BuiltinOperator_BATCH_TO_SPACE_ND = 37, + BuiltinOperator_SPACE_TO_BATCH_ND = 38, + BuiltinOperator_TRANSPOSE = 39, + BuiltinOperator_MEAN = 40, + BuiltinOperator_SUB = 41, + BuiltinOperator_DIV = 42, + BuiltinOperator_SQUEEZE = 43, + BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM = 44, + BuiltinOperator_STRIDED_SLICE = 45, + 
BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN = 46, + BuiltinOperator_EXP = 47, + BuiltinOperator_TOPK_V2 = 48, + BuiltinOperator_SPLIT = 49, + BuiltinOperator_LOG_SOFTMAX = 50, + BuiltinOperator_DELEGATE = 51, + BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM = 52, + BuiltinOperator_CAST = 53, + BuiltinOperator_PRELU = 54, + BuiltinOperator_MAXIMUM = 55, + BuiltinOperator_ARG_MAX = 56, + BuiltinOperator_MINIMUM = 57, + BuiltinOperator_LESS = 58, + BuiltinOperator_NEG = 59, + BuiltinOperator_PADV2 = 60, + BuiltinOperator_GREATER = 61, + BuiltinOperator_GREATER_EQUAL = 62, + BuiltinOperator_LESS_EQUAL = 63, + BuiltinOperator_SELECT = 64, + BuiltinOperator_SLICE = 65, + BuiltinOperator_SIN = 66, + BuiltinOperator_TRANSPOSE_CONV = 67, + BuiltinOperator_SPARSE_TO_DENSE = 68, + BuiltinOperator_TILE = 69, + BuiltinOperator_EXPAND_DIMS = 70, + BuiltinOperator_EQUAL = 71, + BuiltinOperator_NOT_EQUAL = 72, + BuiltinOperator_LOG = 73, + BuiltinOperator_SUM = 74, + BuiltinOperator_SQRT = 75, + BuiltinOperator_RSQRT = 76, + BuiltinOperator_SHAPE = 77, + BuiltinOperator_POW = 78, + BuiltinOperator_ARG_MIN = 79, + BuiltinOperator_FAKE_QUANT = 80, + BuiltinOperator_REDUCE_PROD = 81, + BuiltinOperator_REDUCE_MAX = 82, + BuiltinOperator_PACK = 83, + BuiltinOperator_LOGICAL_OR = 84, + BuiltinOperator_ONE_HOT = 85, + BuiltinOperator_LOGICAL_AND = 86, + BuiltinOperator_LOGICAL_NOT = 87, + BuiltinOperator_UNPACK = 88, + BuiltinOperator_REDUCE_MIN = 89, + BuiltinOperator_FLOOR_DIV = 90, + BuiltinOperator_REDUCE_ANY = 91, + BuiltinOperator_SQUARE = 92, + BuiltinOperator_ZEROS_LIKE = 93, + BuiltinOperator_FILL = 94, + BuiltinOperator_FLOOR_MOD = 95, + BuiltinOperator_RANGE = 96, + BuiltinOperator_RESIZE_NEAREST_NEIGHBOR = 97, + BuiltinOperator_LEAKY_RELU = 98, + BuiltinOperator_SQUARED_DIFFERENCE = 99, + BuiltinOperator_MIRROR_PAD = 100, + BuiltinOperator_ABS = 101, + BuiltinOperator_SPLIT_V = 102, + BuiltinOperator_UNIQUE = 103, + BuiltinOperator_CEIL = 104, + BuiltinOperator_REVERSE_V2 = 105, 
+ BuiltinOperator_ADD_N = 106, + BuiltinOperator_GATHER_ND = 107, + BuiltinOperator_COS = 108, + BuiltinOperator_WHERE = 109, + BuiltinOperator_RANK = 110, + BuiltinOperator_ELU = 111, + BuiltinOperator_REVERSE_SEQUENCE = 112, + BuiltinOperator_MATRIX_DIAG = 113, + BuiltinOperator_QUANTIZE = 114, + BuiltinOperator_MATRIX_SET_DIAG = 115, + BuiltinOperator_ROUND = 116, + BuiltinOperator_HARD_SWISH = 117, + BuiltinOperator_IF = 118, + BuiltinOperator_WHILE = 119, + BuiltinOperator_NON_MAX_SUPPRESSION_V4 = 120, + BuiltinOperator_NON_MAX_SUPPRESSION_V5 = 121, + BuiltinOperator_SCATTER_ND = 122, + BuiltinOperator_SELECT_V2 = 123, + BuiltinOperator_DENSIFY = 124, + BuiltinOperator_SEGMENT_SUM = 125, + BuiltinOperator_BATCH_MATMUL = 126, + BuiltinOperator_PLACEHOLDER_FOR_GREATER_OP_CODES = 127, + BuiltinOperator_CUMSUM = 128, + BuiltinOperator_CALL_ONCE = 129, + BuiltinOperator_BROADCAST_TO = 130, + BuiltinOperator_RFFT2D = 131, + BuiltinOperator_CONV_3D = 132, + BuiltinOperator_IMAG = 133, + BuiltinOperator_REAL = 134, + BuiltinOperator_COMPLEX_ABS = 135, + BuiltinOperator_HASHTABLE = 136, + BuiltinOperator_HASHTABLE_FIND = 137, + BuiltinOperator_HASHTABLE_IMPORT = 138, + BuiltinOperator_HASHTABLE_SIZE = 139, + BuiltinOperator_REDUCE_ALL = 140, + BuiltinOperator_CONV_3D_TRANSPOSE = 141, + BuiltinOperator_VAR_HANDLE = 142, + BuiltinOperator_READ_VARIABLE = 143, + BuiltinOperator_ASSIGN_VARIABLE = 144, + BuiltinOperator_BROADCAST_ARGS = 145, + BuiltinOperator_RANDOM_STANDARD_NORMAL = 146, + BuiltinOperator_BUCKETIZE = 147, + BuiltinOperator_RANDOM_UNIFORM = 148, + BuiltinOperator_MULTINOMIAL = 149, + BuiltinOperator_GELU = 150, + BuiltinOperator_DYNAMIC_UPDATE_SLICE = 151, + BuiltinOperator_RELU_0_TO_1 = 152, + BuiltinOperator_UNSORTED_SEGMENT_PROD = 153, + BuiltinOperator_UNSORTED_SEGMENT_MAX = 154, + BuiltinOperator_UNSORTED_SEGMENT_SUM = 155, + BuiltinOperator_ATAN2 = 156, + BuiltinOperator_MIN = BuiltinOperator_ADD, + BuiltinOperator_MAX = BuiltinOperator_ATAN2 +}; + 
+inline const BuiltinOperator (&EnumValuesBuiltinOperator())[157] +{ + static const BuiltinOperator values[] = {BuiltinOperator_ADD, + BuiltinOperator_AVERAGE_POOL_2D, + BuiltinOperator_CONCATENATION, + BuiltinOperator_CONV_2D, + BuiltinOperator_DEPTHWISE_CONV_2D, + BuiltinOperator_DEPTH_TO_SPACE, + BuiltinOperator_DEQUANTIZE, + BuiltinOperator_EMBEDDING_LOOKUP, + BuiltinOperator_FLOOR, + BuiltinOperator_FULLY_CONNECTED, + BuiltinOperator_HASHTABLE_LOOKUP, + BuiltinOperator_L2_NORMALIZATION, + BuiltinOperator_L2_POOL_2D, + BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION, + BuiltinOperator_LOGISTIC, + BuiltinOperator_LSH_PROJECTION, + BuiltinOperator_LSTM, + BuiltinOperator_MAX_POOL_2D, + BuiltinOperator_MUL, + BuiltinOperator_RELU, + BuiltinOperator_RELU_N1_TO_1, + BuiltinOperator_RELU6, + BuiltinOperator_RESHAPE, + BuiltinOperator_RESIZE_BILINEAR, + BuiltinOperator_RNN, + BuiltinOperator_SOFTMAX, + BuiltinOperator_SPACE_TO_DEPTH, + BuiltinOperator_SVDF, + BuiltinOperator_TANH, + BuiltinOperator_CONCAT_EMBEDDINGS, + BuiltinOperator_SKIP_GRAM, + BuiltinOperator_CALL, + BuiltinOperator_CUSTOM, + BuiltinOperator_EMBEDDING_LOOKUP_SPARSE, + BuiltinOperator_PAD, + BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN, + BuiltinOperator_GATHER, + BuiltinOperator_BATCH_TO_SPACE_ND, + BuiltinOperator_SPACE_TO_BATCH_ND, + BuiltinOperator_TRANSPOSE, + BuiltinOperator_MEAN, + BuiltinOperator_SUB, + BuiltinOperator_DIV, + BuiltinOperator_SQUEEZE, + BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM, + BuiltinOperator_STRIDED_SLICE, + BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN, + BuiltinOperator_EXP, + BuiltinOperator_TOPK_V2, + BuiltinOperator_SPLIT, + BuiltinOperator_LOG_SOFTMAX, + BuiltinOperator_DELEGATE, + BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM, + BuiltinOperator_CAST, + BuiltinOperator_PRELU, + BuiltinOperator_MAXIMUM, + BuiltinOperator_ARG_MAX, + BuiltinOperator_MINIMUM, + BuiltinOperator_LESS, + BuiltinOperator_NEG, + BuiltinOperator_PADV2, + BuiltinOperator_GREATER, + 
BuiltinOperator_GREATER_EQUAL, + BuiltinOperator_LESS_EQUAL, + BuiltinOperator_SELECT, + BuiltinOperator_SLICE, + BuiltinOperator_SIN, + BuiltinOperator_TRANSPOSE_CONV, + BuiltinOperator_SPARSE_TO_DENSE, + BuiltinOperator_TILE, + BuiltinOperator_EXPAND_DIMS, + BuiltinOperator_EQUAL, + BuiltinOperator_NOT_EQUAL, + BuiltinOperator_LOG, + BuiltinOperator_SUM, + BuiltinOperator_SQRT, + BuiltinOperator_RSQRT, + BuiltinOperator_SHAPE, + BuiltinOperator_POW, + BuiltinOperator_ARG_MIN, + BuiltinOperator_FAKE_QUANT, + BuiltinOperator_REDUCE_PROD, + BuiltinOperator_REDUCE_MAX, + BuiltinOperator_PACK, + BuiltinOperator_LOGICAL_OR, + BuiltinOperator_ONE_HOT, + BuiltinOperator_LOGICAL_AND, + BuiltinOperator_LOGICAL_NOT, + BuiltinOperator_UNPACK, + BuiltinOperator_REDUCE_MIN, + BuiltinOperator_FLOOR_DIV, + BuiltinOperator_REDUCE_ANY, + BuiltinOperator_SQUARE, + BuiltinOperator_ZEROS_LIKE, + BuiltinOperator_FILL, + BuiltinOperator_FLOOR_MOD, + BuiltinOperator_RANGE, + BuiltinOperator_RESIZE_NEAREST_NEIGHBOR, + BuiltinOperator_LEAKY_RELU, + BuiltinOperator_SQUARED_DIFFERENCE, + BuiltinOperator_MIRROR_PAD, + BuiltinOperator_ABS, + BuiltinOperator_SPLIT_V, + BuiltinOperator_UNIQUE, + BuiltinOperator_CEIL, + BuiltinOperator_REVERSE_V2, + BuiltinOperator_ADD_N, + BuiltinOperator_GATHER_ND, + BuiltinOperator_COS, + BuiltinOperator_WHERE, + BuiltinOperator_RANK, + BuiltinOperator_ELU, + BuiltinOperator_REVERSE_SEQUENCE, + BuiltinOperator_MATRIX_DIAG, + BuiltinOperator_QUANTIZE, + BuiltinOperator_MATRIX_SET_DIAG, + BuiltinOperator_ROUND, + BuiltinOperator_HARD_SWISH, + BuiltinOperator_IF, + BuiltinOperator_WHILE, + BuiltinOperator_NON_MAX_SUPPRESSION_V4, + BuiltinOperator_NON_MAX_SUPPRESSION_V5, + BuiltinOperator_SCATTER_ND, + BuiltinOperator_SELECT_V2, + BuiltinOperator_DENSIFY, + BuiltinOperator_SEGMENT_SUM, + BuiltinOperator_BATCH_MATMUL, + BuiltinOperator_PLACEHOLDER_FOR_GREATER_OP_CODES, + BuiltinOperator_CUMSUM, + BuiltinOperator_CALL_ONCE, + BuiltinOperator_BROADCAST_TO, + 
BuiltinOperator_RFFT2D, + BuiltinOperator_CONV_3D, + BuiltinOperator_IMAG, + BuiltinOperator_REAL, + BuiltinOperator_COMPLEX_ABS, + BuiltinOperator_HASHTABLE, + BuiltinOperator_HASHTABLE_FIND, + BuiltinOperator_HASHTABLE_IMPORT, + BuiltinOperator_HASHTABLE_SIZE, + BuiltinOperator_REDUCE_ALL, + BuiltinOperator_CONV_3D_TRANSPOSE, + BuiltinOperator_VAR_HANDLE, + BuiltinOperator_READ_VARIABLE, + BuiltinOperator_ASSIGN_VARIABLE, + BuiltinOperator_BROADCAST_ARGS, + BuiltinOperator_RANDOM_STANDARD_NORMAL, + BuiltinOperator_BUCKETIZE, + BuiltinOperator_RANDOM_UNIFORM, + BuiltinOperator_MULTINOMIAL, + BuiltinOperator_GELU, + BuiltinOperator_DYNAMIC_UPDATE_SLICE, + BuiltinOperator_RELU_0_TO_1, + BuiltinOperator_UNSORTED_SEGMENT_PROD, + BuiltinOperator_UNSORTED_SEGMENT_MAX, + BuiltinOperator_UNSORTED_SEGMENT_SUM, + BuiltinOperator_ATAN2}; + return values; +} + +inline const char *const *EnumNamesBuiltinOperator() +{ + static const char *const names[158] = {"ADD", + "AVERAGE_POOL_2D", + "CONCATENATION", + "CONV_2D", + "DEPTHWISE_CONV_2D", + "DEPTH_TO_SPACE", + "DEQUANTIZE", + "EMBEDDING_LOOKUP", + "FLOOR", + "FULLY_CONNECTED", + "HASHTABLE_LOOKUP", + "L2_NORMALIZATION", + "L2_POOL_2D", + "LOCAL_RESPONSE_NORMALIZATION", + "LOGISTIC", + "LSH_PROJECTION", + "LSTM", + "MAX_POOL_2D", + "MUL", + "RELU", + "RELU_N1_TO_1", + "RELU6", + "RESHAPE", + "RESIZE_BILINEAR", + "RNN", + "SOFTMAX", + "SPACE_TO_DEPTH", + "SVDF", + "TANH", + "CONCAT_EMBEDDINGS", + "SKIP_GRAM", + "CALL", + "CUSTOM", + "EMBEDDING_LOOKUP_SPARSE", + "PAD", + "UNIDIRECTIONAL_SEQUENCE_RNN", + "GATHER", + "BATCH_TO_SPACE_ND", + "SPACE_TO_BATCH_ND", + "TRANSPOSE", + "MEAN", + "SUB", + "DIV", + "SQUEEZE", + "UNIDIRECTIONAL_SEQUENCE_LSTM", + "STRIDED_SLICE", + "BIDIRECTIONAL_SEQUENCE_RNN", + "EXP", + "TOPK_V2", + "SPLIT", + "LOG_SOFTMAX", + "DELEGATE", + "BIDIRECTIONAL_SEQUENCE_LSTM", + "CAST", + "PRELU", + "MAXIMUM", + "ARG_MAX", + "MINIMUM", + "LESS", + "NEG", + "PADV2", + "GREATER", + "GREATER_EQUAL", + "LESS_EQUAL", + 
"SELECT", + "SLICE", + "SIN", + "TRANSPOSE_CONV", + "SPARSE_TO_DENSE", + "TILE", + "EXPAND_DIMS", + "EQUAL", + "NOT_EQUAL", + "LOG", + "SUM", + "SQRT", + "RSQRT", + "SHAPE", + "POW", + "ARG_MIN", + "FAKE_QUANT", + "REDUCE_PROD", + "REDUCE_MAX", + "PACK", + "LOGICAL_OR", + "ONE_HOT", + "LOGICAL_AND", + "LOGICAL_NOT", + "UNPACK", + "REDUCE_MIN", + "FLOOR_DIV", + "REDUCE_ANY", + "SQUARE", + "ZEROS_LIKE", + "FILL", + "FLOOR_MOD", + "RANGE", + "RESIZE_NEAREST_NEIGHBOR", + "LEAKY_RELU", + "SQUARED_DIFFERENCE", + "MIRROR_PAD", + "ABS", + "SPLIT_V", + "UNIQUE", + "CEIL", + "REVERSE_V2", + "ADD_N", + "GATHER_ND", + "COS", + "WHERE", + "RANK", + "ELU", + "REVERSE_SEQUENCE", + "MATRIX_DIAG", + "QUANTIZE", + "MATRIX_SET_DIAG", + "ROUND", + "HARD_SWISH", + "IF", + "WHILE", + "NON_MAX_SUPPRESSION_V4", + "NON_MAX_SUPPRESSION_V5", + "SCATTER_ND", + "SELECT_V2", + "DENSIFY", + "SEGMENT_SUM", + "BATCH_MATMUL", + "PLACEHOLDER_FOR_GREATER_OP_CODES", + "CUMSUM", + "CALL_ONCE", + "BROADCAST_TO", + "RFFT2D", + "CONV_3D", + "IMAG", + "REAL", + "COMPLEX_ABS", + "HASHTABLE", + "HASHTABLE_FIND", + "HASHTABLE_IMPORT", + "HASHTABLE_SIZE", + "REDUCE_ALL", + "CONV_3D_TRANSPOSE", + "VAR_HANDLE", + "READ_VARIABLE", + "ASSIGN_VARIABLE", + "BROADCAST_ARGS", + "RANDOM_STANDARD_NORMAL", + "BUCKETIZE", + "RANDOM_UNIFORM", + "MULTINOMIAL", + "GELU", + "DYNAMIC_UPDATE_SLICE", + "RELU_0_TO_1", + "UNSORTED_SEGMENT_PROD", + "UNSORTED_SEGMENT_MAX", + "UNSORTED_SEGMENT_SUM", + "ATAN2", + nullptr}; + return names; +} + +inline const char *EnumNameBuiltinOperator(BuiltinOperator e) +{ + if (::flatbuffers::IsOutRange(e, BuiltinOperator_ADD, BuiltinOperator_ATAN2)) + return ""; + const size_t index = static_cast<size_t>(e); + return EnumNamesBuiltinOperator()[index]; +} + +enum BuiltinOptions : uint8_t +{ + BuiltinOptions_NONE = 0, + BuiltinOptions_Conv2DOptions = 1, + BuiltinOptions_DepthwiseConv2DOptions = 2, + BuiltinOptions_ConcatEmbeddingsOptions = 3, + BuiltinOptions_LSHProjectionOptions = 4, + 
BuiltinOptions_Pool2DOptions = 5, + BuiltinOptions_SVDFOptions = 6, + BuiltinOptions_RNNOptions = 7, + BuiltinOptions_FullyConnectedOptions = 8, + BuiltinOptions_SoftmaxOptions = 9, + BuiltinOptions_ConcatenationOptions = 10, + BuiltinOptions_AddOptions = 11, + BuiltinOptions_L2NormOptions = 12, + BuiltinOptions_LocalResponseNormalizationOptions = 13, + BuiltinOptions_LSTMOptions = 14, + BuiltinOptions_ResizeBilinearOptions = 15, + BuiltinOptions_CallOptions = 16, + BuiltinOptions_ReshapeOptions = 17, + BuiltinOptions_SkipGramOptions = 18, + BuiltinOptions_SpaceToDepthOptions = 19, + BuiltinOptions_EmbeddingLookupSparseOptions = 20, + BuiltinOptions_MulOptions = 21, + BuiltinOptions_PadOptions = 22, + BuiltinOptions_GatherOptions = 23, + BuiltinOptions_BatchToSpaceNDOptions = 24, + BuiltinOptions_SpaceToBatchNDOptions = 25, + BuiltinOptions_TransposeOptions = 26, + BuiltinOptions_ReducerOptions = 27, + BuiltinOptions_SubOptions = 28, + BuiltinOptions_DivOptions = 29, + BuiltinOptions_SqueezeOptions = 30, + BuiltinOptions_SequenceRNNOptions = 31, + BuiltinOptions_StridedSliceOptions = 32, + BuiltinOptions_ExpOptions = 33, + BuiltinOptions_TopKV2Options = 34, + BuiltinOptions_SplitOptions = 35, + BuiltinOptions_LogSoftmaxOptions = 36, + BuiltinOptions_CastOptions = 37, + BuiltinOptions_DequantizeOptions = 38, + BuiltinOptions_MaximumMinimumOptions = 39, + BuiltinOptions_ArgMaxOptions = 40, + BuiltinOptions_LessOptions = 41, + BuiltinOptions_NegOptions = 42, + BuiltinOptions_PadV2Options = 43, + BuiltinOptions_GreaterOptions = 44, + BuiltinOptions_GreaterEqualOptions = 45, + BuiltinOptions_LessEqualOptions = 46, + BuiltinOptions_SelectOptions = 47, + BuiltinOptions_SliceOptions = 48, + BuiltinOptions_TransposeConvOptions = 49, + BuiltinOptions_SparseToDenseOptions = 50, + BuiltinOptions_TileOptions = 51, + BuiltinOptions_ExpandDimsOptions = 52, + BuiltinOptions_EqualOptions = 53, + BuiltinOptions_NotEqualOptions = 54, + BuiltinOptions_ShapeOptions = 55, + 
BuiltinOptions_PowOptions = 56, + BuiltinOptions_ArgMinOptions = 57, + BuiltinOptions_FakeQuantOptions = 58, + BuiltinOptions_PackOptions = 59, + BuiltinOptions_LogicalOrOptions = 60, + BuiltinOptions_OneHotOptions = 61, + BuiltinOptions_LogicalAndOptions = 62, + BuiltinOptions_LogicalNotOptions = 63, + BuiltinOptions_UnpackOptions = 64, + BuiltinOptions_FloorDivOptions = 65, + BuiltinOptions_SquareOptions = 66, + BuiltinOptions_ZerosLikeOptions = 67, + BuiltinOptions_FillOptions = 68, + BuiltinOptions_BidirectionalSequenceLSTMOptions = 69, + BuiltinOptions_BidirectionalSequenceRNNOptions = 70, + BuiltinOptions_UnidirectionalSequenceLSTMOptions = 71, + BuiltinOptions_FloorModOptions = 72, + BuiltinOptions_RangeOptions = 73, + BuiltinOptions_ResizeNearestNeighborOptions = 74, + BuiltinOptions_LeakyReluOptions = 75, + BuiltinOptions_SquaredDifferenceOptions = 76, + BuiltinOptions_MirrorPadOptions = 77, + BuiltinOptions_AbsOptions = 78, + BuiltinOptions_SplitVOptions = 79, + BuiltinOptions_UniqueOptions = 80, + BuiltinOptions_ReverseV2Options = 81, + BuiltinOptions_AddNOptions = 82, + BuiltinOptions_GatherNdOptions = 83, + BuiltinOptions_CosOptions = 84, + BuiltinOptions_WhereOptions = 85, + BuiltinOptions_RankOptions = 86, + BuiltinOptions_ReverseSequenceOptions = 87, + BuiltinOptions_MatrixDiagOptions = 88, + BuiltinOptions_QuantizeOptions = 89, + BuiltinOptions_MatrixSetDiagOptions = 90, + BuiltinOptions_HardSwishOptions = 91, + BuiltinOptions_IfOptions = 92, + BuiltinOptions_WhileOptions = 93, + BuiltinOptions_DepthToSpaceOptions = 94, + BuiltinOptions_NonMaxSuppressionV4Options = 95, + BuiltinOptions_NonMaxSuppressionV5Options = 96, + BuiltinOptions_ScatterNdOptions = 97, + BuiltinOptions_SelectV2Options = 98, + BuiltinOptions_DensifyOptions = 99, + BuiltinOptions_SegmentSumOptions = 100, + BuiltinOptions_BatchMatMulOptions = 101, + BuiltinOptions_CumsumOptions = 102, + BuiltinOptions_CallOnceOptions = 103, + BuiltinOptions_BroadcastToOptions = 104, + 
BuiltinOptions_Rfft2dOptions = 105, + BuiltinOptions_Conv3DOptions = 106, + BuiltinOptions_HashtableOptions = 107, + BuiltinOptions_HashtableFindOptions = 108, + BuiltinOptions_HashtableImportOptions = 109, + BuiltinOptions_HashtableSizeOptions = 110, + BuiltinOptions_VarHandleOptions = 111, + BuiltinOptions_ReadVariableOptions = 112, + BuiltinOptions_AssignVariableOptions = 113, + BuiltinOptions_RandomOptions = 114, + BuiltinOptions_BucketizeOptions = 115, + BuiltinOptions_GeluOptions = 116, + BuiltinOptions_DynamicUpdateSliceOptions = 117, + BuiltinOptions_UnsortedSegmentProdOptions = 118, + BuiltinOptions_UnsortedSegmentMaxOptions = 119, + BuiltinOptions_UnsortedSegmentSumOptions = 120, + BuiltinOptions_ATan2Options = 121, + BuiltinOptions_MIN = BuiltinOptions_NONE, + BuiltinOptions_MAX = BuiltinOptions_ATan2Options +}; + +inline const BuiltinOptions (&EnumValuesBuiltinOptions())[122] +{ + static const BuiltinOptions values[] = {BuiltinOptions_NONE, + BuiltinOptions_Conv2DOptions, + BuiltinOptions_DepthwiseConv2DOptions, + BuiltinOptions_ConcatEmbeddingsOptions, + BuiltinOptions_LSHProjectionOptions, + BuiltinOptions_Pool2DOptions, + BuiltinOptions_SVDFOptions, + BuiltinOptions_RNNOptions, + BuiltinOptions_FullyConnectedOptions, + BuiltinOptions_SoftmaxOptions, + BuiltinOptions_ConcatenationOptions, + BuiltinOptions_AddOptions, + BuiltinOptions_L2NormOptions, + BuiltinOptions_LocalResponseNormalizationOptions, + BuiltinOptions_LSTMOptions, + BuiltinOptions_ResizeBilinearOptions, + BuiltinOptions_CallOptions, + BuiltinOptions_ReshapeOptions, + BuiltinOptions_SkipGramOptions, + BuiltinOptions_SpaceToDepthOptions, + BuiltinOptions_EmbeddingLookupSparseOptions, + BuiltinOptions_MulOptions, + BuiltinOptions_PadOptions, + BuiltinOptions_GatherOptions, + BuiltinOptions_BatchToSpaceNDOptions, + BuiltinOptions_SpaceToBatchNDOptions, + BuiltinOptions_TransposeOptions, + BuiltinOptions_ReducerOptions, + BuiltinOptions_SubOptions, + BuiltinOptions_DivOptions, + 
BuiltinOptions_SqueezeOptions, + BuiltinOptions_SequenceRNNOptions, + BuiltinOptions_StridedSliceOptions, + BuiltinOptions_ExpOptions, + BuiltinOptions_TopKV2Options, + BuiltinOptions_SplitOptions, + BuiltinOptions_LogSoftmaxOptions, + BuiltinOptions_CastOptions, + BuiltinOptions_DequantizeOptions, + BuiltinOptions_MaximumMinimumOptions, + BuiltinOptions_ArgMaxOptions, + BuiltinOptions_LessOptions, + BuiltinOptions_NegOptions, + BuiltinOptions_PadV2Options, + BuiltinOptions_GreaterOptions, + BuiltinOptions_GreaterEqualOptions, + BuiltinOptions_LessEqualOptions, + BuiltinOptions_SelectOptions, + BuiltinOptions_SliceOptions, + BuiltinOptions_TransposeConvOptions, + BuiltinOptions_SparseToDenseOptions, + BuiltinOptions_TileOptions, + BuiltinOptions_ExpandDimsOptions, + BuiltinOptions_EqualOptions, + BuiltinOptions_NotEqualOptions, + BuiltinOptions_ShapeOptions, + BuiltinOptions_PowOptions, + BuiltinOptions_ArgMinOptions, + BuiltinOptions_FakeQuantOptions, + BuiltinOptions_PackOptions, + BuiltinOptions_LogicalOrOptions, + BuiltinOptions_OneHotOptions, + BuiltinOptions_LogicalAndOptions, + BuiltinOptions_LogicalNotOptions, + BuiltinOptions_UnpackOptions, + BuiltinOptions_FloorDivOptions, + BuiltinOptions_SquareOptions, + BuiltinOptions_ZerosLikeOptions, + BuiltinOptions_FillOptions, + BuiltinOptions_BidirectionalSequenceLSTMOptions, + BuiltinOptions_BidirectionalSequenceRNNOptions, + BuiltinOptions_UnidirectionalSequenceLSTMOptions, + BuiltinOptions_FloorModOptions, + BuiltinOptions_RangeOptions, + BuiltinOptions_ResizeNearestNeighborOptions, + BuiltinOptions_LeakyReluOptions, + BuiltinOptions_SquaredDifferenceOptions, + BuiltinOptions_MirrorPadOptions, + BuiltinOptions_AbsOptions, + BuiltinOptions_SplitVOptions, + BuiltinOptions_UniqueOptions, + BuiltinOptions_ReverseV2Options, + BuiltinOptions_AddNOptions, + BuiltinOptions_GatherNdOptions, + BuiltinOptions_CosOptions, + BuiltinOptions_WhereOptions, + BuiltinOptions_RankOptions, + BuiltinOptions_ReverseSequenceOptions, 
+ BuiltinOptions_MatrixDiagOptions, + BuiltinOptions_QuantizeOptions, + BuiltinOptions_MatrixSetDiagOptions, + BuiltinOptions_HardSwishOptions, + BuiltinOptions_IfOptions, + BuiltinOptions_WhileOptions, + BuiltinOptions_DepthToSpaceOptions, + BuiltinOptions_NonMaxSuppressionV4Options, + BuiltinOptions_NonMaxSuppressionV5Options, + BuiltinOptions_ScatterNdOptions, + BuiltinOptions_SelectV2Options, + BuiltinOptions_DensifyOptions, + BuiltinOptions_SegmentSumOptions, + BuiltinOptions_BatchMatMulOptions, + BuiltinOptions_CumsumOptions, + BuiltinOptions_CallOnceOptions, + BuiltinOptions_BroadcastToOptions, + BuiltinOptions_Rfft2dOptions, + BuiltinOptions_Conv3DOptions, + BuiltinOptions_HashtableOptions, + BuiltinOptions_HashtableFindOptions, + BuiltinOptions_HashtableImportOptions, + BuiltinOptions_HashtableSizeOptions, + BuiltinOptions_VarHandleOptions, + BuiltinOptions_ReadVariableOptions, + BuiltinOptions_AssignVariableOptions, + BuiltinOptions_RandomOptions, + BuiltinOptions_BucketizeOptions, + BuiltinOptions_GeluOptions, + BuiltinOptions_DynamicUpdateSliceOptions, + BuiltinOptions_UnsortedSegmentProdOptions, + BuiltinOptions_UnsortedSegmentMaxOptions, + BuiltinOptions_UnsortedSegmentSumOptions, + BuiltinOptions_ATan2Options}; + return values; +} + +inline const char *const *EnumNamesBuiltinOptions() +{ + static const char *const names[123] = {"NONE", + "Conv2DOptions", + "DepthwiseConv2DOptions", + "ConcatEmbeddingsOptions", + "LSHProjectionOptions", + "Pool2DOptions", + "SVDFOptions", + "RNNOptions", + "FullyConnectedOptions", + "SoftmaxOptions", + "ConcatenationOptions", + "AddOptions", + "L2NormOptions", + "LocalResponseNormalizationOptions", + "LSTMOptions", + "ResizeBilinearOptions", + "CallOptions", + "ReshapeOptions", + "SkipGramOptions", + "SpaceToDepthOptions", + "EmbeddingLookupSparseOptions", + "MulOptions", + "PadOptions", + "GatherOptions", + "BatchToSpaceNDOptions", + "SpaceToBatchNDOptions", + "TransposeOptions", + "ReducerOptions", + "SubOptions", + 
"DivOptions", + "SqueezeOptions", + "SequenceRNNOptions", + "StridedSliceOptions", + "ExpOptions", + "TopKV2Options", + "SplitOptions", + "LogSoftmaxOptions", + "CastOptions", + "DequantizeOptions", + "MaximumMinimumOptions", + "ArgMaxOptions", + "LessOptions", + "NegOptions", + "PadV2Options", + "GreaterOptions", + "GreaterEqualOptions", + "LessEqualOptions", + "SelectOptions", + "SliceOptions", + "TransposeConvOptions", + "SparseToDenseOptions", + "TileOptions", + "ExpandDimsOptions", + "EqualOptions", + "NotEqualOptions", + "ShapeOptions", + "PowOptions", + "ArgMinOptions", + "FakeQuantOptions", + "PackOptions", + "LogicalOrOptions", + "OneHotOptions", + "LogicalAndOptions", + "LogicalNotOptions", + "UnpackOptions", + "FloorDivOptions", + "SquareOptions", + "ZerosLikeOptions", + "FillOptions", + "BidirectionalSequenceLSTMOptions", + "BidirectionalSequenceRNNOptions", + "UnidirectionalSequenceLSTMOptions", + "FloorModOptions", + "RangeOptions", + "ResizeNearestNeighborOptions", + "LeakyReluOptions", + "SquaredDifferenceOptions", + "MirrorPadOptions", + "AbsOptions", + "SplitVOptions", + "UniqueOptions", + "ReverseV2Options", + "AddNOptions", + "GatherNdOptions", + "CosOptions", + "WhereOptions", + "RankOptions", + "ReverseSequenceOptions", + "MatrixDiagOptions", + "QuantizeOptions", + "MatrixSetDiagOptions", + "HardSwishOptions", + "IfOptions", + "WhileOptions", + "DepthToSpaceOptions", + "NonMaxSuppressionV4Options", + "NonMaxSuppressionV5Options", + "ScatterNdOptions", + "SelectV2Options", + "DensifyOptions", + "SegmentSumOptions", + "BatchMatMulOptions", + "CumsumOptions", + "CallOnceOptions", + "BroadcastToOptions", + "Rfft2dOptions", + "Conv3DOptions", + "HashtableOptions", + "HashtableFindOptions", + "HashtableImportOptions", + "HashtableSizeOptions", + "VarHandleOptions", + "ReadVariableOptions", + "AssignVariableOptions", + "RandomOptions", + "BucketizeOptions", + "GeluOptions", + "DynamicUpdateSliceOptions", + "UnsortedSegmentProdOptions", + 
"UnsortedSegmentMaxOptions", + "UnsortedSegmentSumOptions", + "ATan2Options", + nullptr}; + return names; +} + +inline const char *EnumNameBuiltinOptions(BuiltinOptions e) +{ + if (::flatbuffers::IsOutRange(e, BuiltinOptions_NONE, BuiltinOptions_ATan2Options)) + return ""; + const size_t index = static_cast<size_t>(e); + return EnumNamesBuiltinOptions()[index]; +} + +template <typename T> struct BuiltinOptionsTraits +{ + static const BuiltinOptions enum_value = BuiltinOptions_NONE; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::Conv2DOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_Conv2DOptions; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::DepthwiseConv2DOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_DepthwiseConv2DOptions; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::ConcatEmbeddingsOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_ConcatEmbeddingsOptions; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::LSHProjectionOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_LSHProjectionOptions; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::Pool2DOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_Pool2DOptions; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::SVDFOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_SVDFOptions; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::RNNOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_RNNOptions; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::FullyConnectedOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_FullyConnectedOptions; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::SoftmaxOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_SoftmaxOptions; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::ConcatenationOptions> 
+{ + static const BuiltinOptions enum_value = BuiltinOptions_ConcatenationOptions; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::AddOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_AddOptions; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::L2NormOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_L2NormOptions; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::LocalResponseNormalizationOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_LocalResponseNormalizationOptions; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::LSTMOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_LSTMOptions; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::ResizeBilinearOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_ResizeBilinearOptions; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::CallOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_CallOptions; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::ReshapeOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_ReshapeOptions; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::SkipGramOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_SkipGramOptions; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::SpaceToDepthOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_SpaceToDepthOptions; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::EmbeddingLookupSparseOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_EmbeddingLookupSparseOptions; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::MulOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_MulOptions; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::PadOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_PadOptions; +}; + 
+template <> struct BuiltinOptionsTraits<onert_tflite::GatherOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_GatherOptions; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::BatchToSpaceNDOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_BatchToSpaceNDOptions; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::SpaceToBatchNDOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_SpaceToBatchNDOptions; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::TransposeOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_TransposeOptions; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::ReducerOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_ReducerOptions; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::SubOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_SubOptions; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::DivOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_DivOptions; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::SqueezeOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_SqueezeOptions; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::SequenceRNNOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_SequenceRNNOptions; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::StridedSliceOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_StridedSliceOptions; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::ExpOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_ExpOptions; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::TopKV2Options> +{ + static const BuiltinOptions enum_value = BuiltinOptions_TopKV2Options; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::SplitOptions> +{ + static const BuiltinOptions enum_value = 
BuiltinOptions_SplitOptions; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::LogSoftmaxOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_LogSoftmaxOptions; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::CastOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_CastOptions; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::DequantizeOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_DequantizeOptions; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::MaximumMinimumOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_MaximumMinimumOptions; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::ArgMaxOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_ArgMaxOptions; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::LessOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_LessOptions; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::NegOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_NegOptions; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::PadV2Options> +{ + static const BuiltinOptions enum_value = BuiltinOptions_PadV2Options; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::GreaterOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_GreaterOptions; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::GreaterEqualOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_GreaterEqualOptions; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::LessEqualOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_LessEqualOptions; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::SelectOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_SelectOptions; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::SliceOptions> +{ + static const BuiltinOptions 
enum_value = BuiltinOptions_SliceOptions; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::TransposeConvOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_TransposeConvOptions; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::SparseToDenseOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_SparseToDenseOptions; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::TileOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_TileOptions; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::ExpandDimsOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_ExpandDimsOptions; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::EqualOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_EqualOptions; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::NotEqualOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_NotEqualOptions; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::ShapeOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_ShapeOptions; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::PowOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_PowOptions; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::ArgMinOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_ArgMinOptions; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::FakeQuantOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_FakeQuantOptions; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::PackOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_PackOptions; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::LogicalOrOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_LogicalOrOptions; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::OneHotOptions> +{ + static const 
BuiltinOptions enum_value = BuiltinOptions_OneHotOptions; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::LogicalAndOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_LogicalAndOptions; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::LogicalNotOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_LogicalNotOptions; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::UnpackOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_UnpackOptions; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::FloorDivOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_FloorDivOptions; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::SquareOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_SquareOptions; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::ZerosLikeOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_ZerosLikeOptions; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::FillOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_FillOptions; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::BidirectionalSequenceLSTMOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_BidirectionalSequenceLSTMOptions; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::BidirectionalSequenceRNNOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_BidirectionalSequenceRNNOptions; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::UnidirectionalSequenceLSTMOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_UnidirectionalSequenceLSTMOptions; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::FloorModOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_FloorModOptions; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::RangeOptions> +{ + static const BuiltinOptions enum_value = 
BuiltinOptions_RangeOptions; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::ResizeNearestNeighborOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_ResizeNearestNeighborOptions; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::LeakyReluOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_LeakyReluOptions; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::SquaredDifferenceOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_SquaredDifferenceOptions; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::MirrorPadOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_MirrorPadOptions; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::AbsOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_AbsOptions; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::SplitVOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_SplitVOptions; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::UniqueOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_UniqueOptions; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::ReverseV2Options> +{ + static const BuiltinOptions enum_value = BuiltinOptions_ReverseV2Options; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::AddNOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_AddNOptions; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::GatherNdOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_GatherNdOptions; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::CosOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_CosOptions; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::WhereOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_WhereOptions; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::RankOptions> +{ + static 
const BuiltinOptions enum_value = BuiltinOptions_RankOptions; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::ReverseSequenceOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_ReverseSequenceOptions; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::MatrixDiagOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_MatrixDiagOptions; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::QuantizeOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_QuantizeOptions; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::MatrixSetDiagOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_MatrixSetDiagOptions; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::HardSwishOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_HardSwishOptions; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::IfOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_IfOptions; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::WhileOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_WhileOptions; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::DepthToSpaceOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_DepthToSpaceOptions; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::NonMaxSuppressionV4Options> +{ + static const BuiltinOptions enum_value = BuiltinOptions_NonMaxSuppressionV4Options; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::NonMaxSuppressionV5Options> +{ + static const BuiltinOptions enum_value = BuiltinOptions_NonMaxSuppressionV5Options; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::ScatterNdOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_ScatterNdOptions; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::SelectV2Options> +{ + static const BuiltinOptions enum_value = BuiltinOptions_SelectV2Options; 
+}; + +template <> struct BuiltinOptionsTraits<onert_tflite::DensifyOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_DensifyOptions; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::SegmentSumOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_SegmentSumOptions; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::BatchMatMulOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_BatchMatMulOptions; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::CumsumOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_CumsumOptions; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::CallOnceOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_CallOnceOptions; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::BroadcastToOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_BroadcastToOptions; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::Rfft2dOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_Rfft2dOptions; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::Conv3DOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_Conv3DOptions; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::HashtableOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_HashtableOptions; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::HashtableFindOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_HashtableFindOptions; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::HashtableImportOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_HashtableImportOptions; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::HashtableSizeOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_HashtableSizeOptions; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::VarHandleOptions> +{ + 
static const BuiltinOptions enum_value = BuiltinOptions_VarHandleOptions; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::ReadVariableOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_ReadVariableOptions; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::AssignVariableOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_AssignVariableOptions; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::RandomOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_RandomOptions; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::BucketizeOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_BucketizeOptions; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::GeluOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_GeluOptions; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::DynamicUpdateSliceOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_DynamicUpdateSliceOptions; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::UnsortedSegmentProdOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_UnsortedSegmentProdOptions; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::UnsortedSegmentMaxOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_UnsortedSegmentMaxOptions; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::UnsortedSegmentSumOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_UnsortedSegmentSumOptions; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::ATan2Options> +{ + static const BuiltinOptions enum_value = BuiltinOptions_ATan2Options; +}; + +bool VerifyBuiltinOptions(::flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type); +bool VerifyBuiltinOptionsVector(::flatbuffers::Verifier &verifier, + const ::flatbuffers::Vector<::flatbuffers::Offset<void>> *values, + const 
::flatbuffers::Vector<uint8_t> *types); + +enum Padding : int8_t +{ + Padding_SAME = 0, + Padding_VALID = 1, + Padding_MIN = Padding_SAME, + Padding_MAX = Padding_VALID +}; + +inline const Padding (&EnumValuesPadding())[2] +{ + static const Padding values[] = {Padding_SAME, Padding_VALID}; + return values; +} + +inline const char *const *EnumNamesPadding() +{ + static const char *const names[3] = {"SAME", "VALID", nullptr}; + return names; +} + +inline const char *EnumNamePadding(Padding e) +{ + if (::flatbuffers::IsOutRange(e, Padding_SAME, Padding_VALID)) + return ""; + const size_t index = static_cast<size_t>(e); + return EnumNamesPadding()[index]; +} + +enum ActivationFunctionType : int8_t +{ + ActivationFunctionType_NONE = 0, + ActivationFunctionType_RELU = 1, + ActivationFunctionType_RELU_N1_TO_1 = 2, + ActivationFunctionType_RELU6 = 3, + ActivationFunctionType_TANH = 4, + ActivationFunctionType_SIGN_BIT = 5, + ActivationFunctionType_MIN = ActivationFunctionType_NONE, + ActivationFunctionType_MAX = ActivationFunctionType_SIGN_BIT +}; + +inline const ActivationFunctionType (&EnumValuesActivationFunctionType())[6] +{ + static const ActivationFunctionType values[] = { + ActivationFunctionType_NONE, ActivationFunctionType_RELU, ActivationFunctionType_RELU_N1_TO_1, + ActivationFunctionType_RELU6, ActivationFunctionType_TANH, ActivationFunctionType_SIGN_BIT}; + return values; +} + +inline const char *const *EnumNamesActivationFunctionType() +{ + static const char *const names[7] = {"NONE", "RELU", "RELU_N1_TO_1", "RELU6", + "TANH", "SIGN_BIT", nullptr}; + return names; +} + +inline const char *EnumNameActivationFunctionType(ActivationFunctionType e) +{ + if (::flatbuffers::IsOutRange(e, ActivationFunctionType_NONE, ActivationFunctionType_SIGN_BIT)) + return ""; + const size_t index = static_cast<size_t>(e); + return EnumNamesActivationFunctionType()[index]; +} + +enum LSHProjectionType : int8_t +{ + LSHProjectionType_UNKNOWN = 0, + LSHProjectionType_SPARSE = 1, + 
LSHProjectionType_DENSE = 2, + LSHProjectionType_MIN = LSHProjectionType_UNKNOWN, + LSHProjectionType_MAX = LSHProjectionType_DENSE +}; + +inline const LSHProjectionType (&EnumValuesLSHProjectionType())[3] +{ + static const LSHProjectionType values[] = {LSHProjectionType_UNKNOWN, LSHProjectionType_SPARSE, + LSHProjectionType_DENSE}; + return values; +} + +inline const char *const *EnumNamesLSHProjectionType() +{ + static const char *const names[4] = {"UNKNOWN", "SPARSE", "DENSE", nullptr}; + return names; +} + +inline const char *EnumNameLSHProjectionType(LSHProjectionType e) +{ + if (::flatbuffers::IsOutRange(e, LSHProjectionType_UNKNOWN, LSHProjectionType_DENSE)) + return ""; + const size_t index = static_cast<size_t>(e); + return EnumNamesLSHProjectionType()[index]; +} + +enum FullyConnectedOptionsWeightsFormat : int8_t +{ + FullyConnectedOptionsWeightsFormat_DEFAULT = 0, + FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8 = 1, + FullyConnectedOptionsWeightsFormat_MIN = FullyConnectedOptionsWeightsFormat_DEFAULT, + FullyConnectedOptionsWeightsFormat_MAX = FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8 +}; + +inline const FullyConnectedOptionsWeightsFormat (&EnumValuesFullyConnectedOptionsWeightsFormat())[2] +{ + static const FullyConnectedOptionsWeightsFormat values[] = { + FullyConnectedOptionsWeightsFormat_DEFAULT, + FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8}; + return values; +} + +inline const char *const *EnumNamesFullyConnectedOptionsWeightsFormat() +{ + static const char *const names[3] = {"DEFAULT", "SHUFFLED4x16INT8", nullptr}; + return names; +} + +inline const char *EnumNameFullyConnectedOptionsWeightsFormat(FullyConnectedOptionsWeightsFormat e) +{ + if (::flatbuffers::IsOutRange(e, FullyConnectedOptionsWeightsFormat_DEFAULT, + FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8)) + return ""; + const size_t index = static_cast<size_t>(e); + return EnumNamesFullyConnectedOptionsWeightsFormat()[index]; +} + +enum LSTMKernelType : 
int8_t +{ + LSTMKernelType_FULL = 0, + LSTMKernelType_BASIC = 1, + LSTMKernelType_MIN = LSTMKernelType_FULL, + LSTMKernelType_MAX = LSTMKernelType_BASIC +}; + +inline const LSTMKernelType (&EnumValuesLSTMKernelType())[2] +{ + static const LSTMKernelType values[] = {LSTMKernelType_FULL, LSTMKernelType_BASIC}; + return values; +} + +inline const char *const *EnumNamesLSTMKernelType() +{ + static const char *const names[3] = {"FULL", "BASIC", nullptr}; + return names; +} + +inline const char *EnumNameLSTMKernelType(LSTMKernelType e) +{ + if (::flatbuffers::IsOutRange(e, LSTMKernelType_FULL, LSTMKernelType_BASIC)) + return ""; + const size_t index = static_cast<size_t>(e); + return EnumNamesLSTMKernelType()[index]; +} + +enum CombinerType : int8_t +{ + CombinerType_SUM = 0, + CombinerType_MEAN = 1, + CombinerType_SQRTN = 2, + CombinerType_MIN = CombinerType_SUM, + CombinerType_MAX = CombinerType_SQRTN +}; + +inline const CombinerType (&EnumValuesCombinerType())[3] +{ + static const CombinerType values[] = {CombinerType_SUM, CombinerType_MEAN, CombinerType_SQRTN}; + return values; +} + +inline const char *const *EnumNamesCombinerType() +{ + static const char *const names[4] = {"SUM", "MEAN", "SQRTN", nullptr}; + return names; +} + +inline const char *EnumNameCombinerType(CombinerType e) +{ + if (::flatbuffers::IsOutRange(e, CombinerType_SUM, CombinerType_SQRTN)) + return ""; + const size_t index = static_cast<size_t>(e); + return EnumNamesCombinerType()[index]; +} + +enum MirrorPadMode : int8_t +{ + MirrorPadMode_REFLECT = 0, + MirrorPadMode_SYMMETRIC = 1, + MirrorPadMode_MIN = MirrorPadMode_REFLECT, + MirrorPadMode_MAX = MirrorPadMode_SYMMETRIC +}; + +inline const MirrorPadMode (&EnumValuesMirrorPadMode())[2] +{ + static const MirrorPadMode values[] = {MirrorPadMode_REFLECT, MirrorPadMode_SYMMETRIC}; + return values; +} + +inline const char *const *EnumNamesMirrorPadMode() +{ + static const char *const names[3] = {"REFLECT", "SYMMETRIC", nullptr}; + return names; +} + 
+inline const char *EnumNameMirrorPadMode(MirrorPadMode e) +{ + if (::flatbuffers::IsOutRange(e, MirrorPadMode_REFLECT, MirrorPadMode_SYMMETRIC)) + return ""; + const size_t index = static_cast<size_t>(e); + return EnumNamesMirrorPadMode()[index]; +} + +enum CustomOptionsFormat : int8_t +{ + CustomOptionsFormat_FLEXBUFFERS = 0, + CustomOptionsFormat_MIN = CustomOptionsFormat_FLEXBUFFERS, + CustomOptionsFormat_MAX = CustomOptionsFormat_FLEXBUFFERS +}; + +inline const CustomOptionsFormat (&EnumValuesCustomOptionsFormat())[1] +{ + static const CustomOptionsFormat values[] = {CustomOptionsFormat_FLEXBUFFERS}; + return values; +} + +inline const char *const *EnumNamesCustomOptionsFormat() +{ + static const char *const names[2] = {"FLEXBUFFERS", nullptr}; + return names; +} + +inline const char *EnumNameCustomOptionsFormat(CustomOptionsFormat e) +{ + if (::flatbuffers::IsOutRange(e, CustomOptionsFormat_FLEXBUFFERS, + CustomOptionsFormat_FLEXBUFFERS)) + return ""; + const size_t index = static_cast<size_t>(e); + return EnumNamesCustomOptionsFormat()[index]; +} + +struct CustomQuantization FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef CustomQuantizationBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE + { + VT_CUSTOM = 4 + }; + const ::flatbuffers::Vector<uint8_t> *custom() const + { + return GetPointer<const ::flatbuffers::Vector<uint8_t> *>(VT_CUSTOM); + } + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_CUSTOM) && + verifier.VerifyVector(custom()) && verifier.EndTable(); + } +}; + +struct CustomQuantizationBuilder +{ + typedef CustomQuantization Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_custom(::flatbuffers::Offset<::flatbuffers::Vector<uint8_t>> custom) + { + fbb_.AddOffset(CustomQuantization::VT_CUSTOM, custom); + } + explicit CustomQuantizationBuilder(::flatbuffers::FlatBufferBuilder 
&_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<CustomQuantization> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<CustomQuantization>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<CustomQuantization> +CreateCustomQuantization(::flatbuffers::FlatBufferBuilder &_fbb, + ::flatbuffers::Offset<::flatbuffers::Vector<uint8_t>> custom = 0) +{ + CustomQuantizationBuilder builder_(_fbb); + builder_.add_custom(custom); + return builder_.Finish(); +} + +inline ::flatbuffers::Offset<CustomQuantization> +CreateCustomQuantizationDirect(::flatbuffers::FlatBufferBuilder &_fbb, + const std::vector<uint8_t> *custom = nullptr) +{ + if (custom) + { + _fbb.ForceVectorAlignment(custom->size(), sizeof(uint8_t), 16); + } + auto custom__ = custom ? _fbb.CreateVector<uint8_t>(*custom) : 0; + return onert_tflite::CreateCustomQuantization(_fbb, custom__); +} + +struct QuantizationParameters FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef QuantizationParametersBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE + { + VT_MIN = 4, + VT_MAX = 6, + VT_SCALE = 8, + VT_ZERO_POINT = 10, + VT_DETAILS_TYPE = 12, + VT_DETAILS = 14, + VT_QUANTIZED_DIMENSION = 16 + }; + const ::flatbuffers::Vector<float> *min() const + { + return GetPointer<const ::flatbuffers::Vector<float> *>(VT_MIN); + } + const ::flatbuffers::Vector<float> *max() const + { + return GetPointer<const ::flatbuffers::Vector<float> *>(VT_MAX); + } + const ::flatbuffers::Vector<float> *scale() const + { + return GetPointer<const ::flatbuffers::Vector<float> *>(VT_SCALE); + } + const ::flatbuffers::Vector<int64_t> *zero_point() const + { + return GetPointer<const ::flatbuffers::Vector<int64_t> *>(VT_ZERO_POINT); + } + onert_tflite::QuantizationDetails details_type() const + { + return static_cast<onert_tflite::QuantizationDetails>(GetField<uint8_t>(VT_DETAILS_TYPE, 0)); + } + const void *details() const { 
return GetPointer<const void *>(VT_DETAILS); } + template <typename T> const T *details_as() const; + const onert_tflite::CustomQuantization *details_as_CustomQuantization() const + { + return details_type() == onert_tflite::QuantizationDetails_CustomQuantization + ? static_cast<const onert_tflite::CustomQuantization *>(details()) + : nullptr; + } + int32_t quantized_dimension() const { return GetField<int32_t>(VT_QUANTIZED_DIMENSION, 0); } + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_MIN) && + verifier.VerifyVector(min()) && VerifyOffset(verifier, VT_MAX) && + verifier.VerifyVector(max()) && VerifyOffset(verifier, VT_SCALE) && + verifier.VerifyVector(scale()) && VerifyOffset(verifier, VT_ZERO_POINT) && + verifier.VerifyVector(zero_point()) && + VerifyField<uint8_t>(verifier, VT_DETAILS_TYPE, 1) && + VerifyOffset(verifier, VT_DETAILS) && + VerifyQuantizationDetails(verifier, details(), details_type()) && + VerifyField<int32_t>(verifier, VT_QUANTIZED_DIMENSION, 4) && verifier.EndTable(); + } +}; + +template <> +inline const onert_tflite::CustomQuantization * +QuantizationParameters::details_as<onert_tflite::CustomQuantization>() const +{ + return details_as_CustomQuantization(); +} + +struct QuantizationParametersBuilder +{ + typedef QuantizationParameters Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_min(::flatbuffers::Offset<::flatbuffers::Vector<float>> min) + { + fbb_.AddOffset(QuantizationParameters::VT_MIN, min); + } + void add_max(::flatbuffers::Offset<::flatbuffers::Vector<float>> max) + { + fbb_.AddOffset(QuantizationParameters::VT_MAX, max); + } + void add_scale(::flatbuffers::Offset<::flatbuffers::Vector<float>> scale) + { + fbb_.AddOffset(QuantizationParameters::VT_SCALE, scale); + } + void add_zero_point(::flatbuffers::Offset<::flatbuffers::Vector<int64_t>> zero_point) + { + fbb_.AddOffset(QuantizationParameters::VT_ZERO_POINT, 
zero_point); + } + void add_details_type(onert_tflite::QuantizationDetails details_type) + { + fbb_.AddElement<uint8_t>(QuantizationParameters::VT_DETAILS_TYPE, + static_cast<uint8_t>(details_type), 0); + } + void add_details(::flatbuffers::Offset<void> details) + { + fbb_.AddOffset(QuantizationParameters::VT_DETAILS, details); + } + void add_quantized_dimension(int32_t quantized_dimension) + { + fbb_.AddElement<int32_t>(QuantizationParameters::VT_QUANTIZED_DIMENSION, quantized_dimension, + 0); + } + explicit QuantizationParametersBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<QuantizationParameters> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<QuantizationParameters>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<QuantizationParameters> CreateQuantizationParameters( + ::flatbuffers::FlatBufferBuilder &_fbb, + ::flatbuffers::Offset<::flatbuffers::Vector<float>> min = 0, + ::flatbuffers::Offset<::flatbuffers::Vector<float>> max = 0, + ::flatbuffers::Offset<::flatbuffers::Vector<float>> scale = 0, + ::flatbuffers::Offset<::flatbuffers::Vector<int64_t>> zero_point = 0, + onert_tflite::QuantizationDetails details_type = onert_tflite::QuantizationDetails_NONE, + ::flatbuffers::Offset<void> details = 0, int32_t quantized_dimension = 0) +{ + QuantizationParametersBuilder builder_(_fbb); + builder_.add_quantized_dimension(quantized_dimension); + builder_.add_details(details); + builder_.add_zero_point(zero_point); + builder_.add_scale(scale); + builder_.add_max(max); + builder_.add_min(min); + builder_.add_details_type(details_type); + return builder_.Finish(); +} + +inline ::flatbuffers::Offset<QuantizationParameters> CreateQuantizationParametersDirect( + ::flatbuffers::FlatBufferBuilder &_fbb, const std::vector<float> *min = nullptr, + const std::vector<float> *max = nullptr, const std::vector<float> *scale = nullptr, + const 
std::vector<int64_t> *zero_point = nullptr, + onert_tflite::QuantizationDetails details_type = onert_tflite::QuantizationDetails_NONE, + ::flatbuffers::Offset<void> details = 0, int32_t quantized_dimension = 0) +{ + auto min__ = min ? _fbb.CreateVector<float>(*min) : 0; + auto max__ = max ? _fbb.CreateVector<float>(*max) : 0; + auto scale__ = scale ? _fbb.CreateVector<float>(*scale) : 0; + auto zero_point__ = zero_point ? _fbb.CreateVector<int64_t>(*zero_point) : 0; + return onert_tflite::CreateQuantizationParameters(_fbb, min__, max__, scale__, zero_point__, + details_type, details, quantized_dimension); +} + +struct Int32Vector FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef Int32VectorBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE + { + VT_VALUES = 4 + }; + const ::flatbuffers::Vector<int32_t> *values() const + { + return GetPointer<const ::flatbuffers::Vector<int32_t> *>(VT_VALUES); + } + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_VALUES) && + verifier.VerifyVector(values()) && verifier.EndTable(); + } +}; + +struct Int32VectorBuilder +{ + typedef Int32Vector Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_values(::flatbuffers::Offset<::flatbuffers::Vector<int32_t>> values) + { + fbb_.AddOffset(Int32Vector::VT_VALUES, values); + } + explicit Int32VectorBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<Int32Vector> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<Int32Vector>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<Int32Vector> +CreateInt32Vector(::flatbuffers::FlatBufferBuilder &_fbb, + ::flatbuffers::Offset<::flatbuffers::Vector<int32_t>> values = 0) +{ + Int32VectorBuilder builder_(_fbb); + builder_.add_values(values); + return builder_.Finish(); +} + 
+inline ::flatbuffers::Offset<Int32Vector> +CreateInt32VectorDirect(::flatbuffers::FlatBufferBuilder &_fbb, + const std::vector<int32_t> *values = nullptr) +{ + auto values__ = values ? _fbb.CreateVector<int32_t>(*values) : 0; + return onert_tflite::CreateInt32Vector(_fbb, values__); +} + +struct Uint16Vector FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef Uint16VectorBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE + { + VT_VALUES = 4 + }; + const ::flatbuffers::Vector<uint16_t> *values() const + { + return GetPointer<const ::flatbuffers::Vector<uint16_t> *>(VT_VALUES); + } + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_VALUES) && + verifier.VerifyVector(values()) && verifier.EndTable(); + } +}; + +struct Uint16VectorBuilder +{ + typedef Uint16Vector Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_values(::flatbuffers::Offset<::flatbuffers::Vector<uint16_t>> values) + { + fbb_.AddOffset(Uint16Vector::VT_VALUES, values); + } + explicit Uint16VectorBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<Uint16Vector> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<Uint16Vector>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<Uint16Vector> +CreateUint16Vector(::flatbuffers::FlatBufferBuilder &_fbb, + ::flatbuffers::Offset<::flatbuffers::Vector<uint16_t>> values = 0) +{ + Uint16VectorBuilder builder_(_fbb); + builder_.add_values(values); + return builder_.Finish(); +} + +inline ::flatbuffers::Offset<Uint16Vector> +CreateUint16VectorDirect(::flatbuffers::FlatBufferBuilder &_fbb, + const std::vector<uint16_t> *values = nullptr) +{ + if (values) + { + _fbb.ForceVectorAlignment(values->size(), sizeof(uint16_t), 4); + } + auto values__ = values ? 
_fbb.CreateVector<uint16_t>(*values) : 0; + return onert_tflite::CreateUint16Vector(_fbb, values__); +} + +struct Uint8Vector FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef Uint8VectorBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE + { + VT_VALUES = 4 + }; + const ::flatbuffers::Vector<uint8_t> *values() const + { + return GetPointer<const ::flatbuffers::Vector<uint8_t> *>(VT_VALUES); + } + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_VALUES) && + verifier.VerifyVector(values()) && verifier.EndTable(); + } +}; + +struct Uint8VectorBuilder +{ + typedef Uint8Vector Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_values(::flatbuffers::Offset<::flatbuffers::Vector<uint8_t>> values) + { + fbb_.AddOffset(Uint8Vector::VT_VALUES, values); + } + explicit Uint8VectorBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<Uint8Vector> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<Uint8Vector>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<Uint8Vector> +CreateUint8Vector(::flatbuffers::FlatBufferBuilder &_fbb, + ::flatbuffers::Offset<::flatbuffers::Vector<uint8_t>> values = 0) +{ + Uint8VectorBuilder builder_(_fbb); + builder_.add_values(values); + return builder_.Finish(); +} + +inline ::flatbuffers::Offset<Uint8Vector> +CreateUint8VectorDirect(::flatbuffers::FlatBufferBuilder &_fbb, + const std::vector<uint8_t> *values = nullptr) +{ + if (values) + { + _fbb.ForceVectorAlignment(values->size(), sizeof(uint8_t), 4); + } + auto values__ = values ? 
_fbb.CreateVector<uint8_t>(*values) : 0; + return onert_tflite::CreateUint8Vector(_fbb, values__); +} + +struct DimensionMetadata FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef DimensionMetadataBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE + { + VT_FORMAT = 4, + VT_DENSE_SIZE = 6, + VT_ARRAY_SEGMENTS_TYPE = 8, + VT_ARRAY_SEGMENTS = 10, + VT_ARRAY_INDICES_TYPE = 12, + VT_ARRAY_INDICES = 14 + }; + onert_tflite::DimensionType format() const + { + return static_cast<onert_tflite::DimensionType>(GetField<int8_t>(VT_FORMAT, 0)); + } + int32_t dense_size() const { return GetField<int32_t>(VT_DENSE_SIZE, 0); } + onert_tflite::SparseIndexVector array_segments_type() const + { + return static_cast<onert_tflite::SparseIndexVector>( + GetField<uint8_t>(VT_ARRAY_SEGMENTS_TYPE, 0)); + } + const void *array_segments() const { return GetPointer<const void *>(VT_ARRAY_SEGMENTS); } + template <typename T> const T *array_segments_as() const; + const onert_tflite::Int32Vector *array_segments_as_Int32Vector() const + { + return array_segments_type() == onert_tflite::SparseIndexVector_Int32Vector + ? static_cast<const onert_tflite::Int32Vector *>(array_segments()) + : nullptr; + } + const onert_tflite::Uint16Vector *array_segments_as_Uint16Vector() const + { + return array_segments_type() == onert_tflite::SparseIndexVector_Uint16Vector + ? static_cast<const onert_tflite::Uint16Vector *>(array_segments()) + : nullptr; + } + const onert_tflite::Uint8Vector *array_segments_as_Uint8Vector() const + { + return array_segments_type() == onert_tflite::SparseIndexVector_Uint8Vector + ? 
static_cast<const onert_tflite::Uint8Vector *>(array_segments()) + : nullptr; + } + onert_tflite::SparseIndexVector array_indices_type() const + { + return static_cast<onert_tflite::SparseIndexVector>( + GetField<uint8_t>(VT_ARRAY_INDICES_TYPE, 0)); + } + const void *array_indices() const { return GetPointer<const void *>(VT_ARRAY_INDICES); } + template <typename T> const T *array_indices_as() const; + const onert_tflite::Int32Vector *array_indices_as_Int32Vector() const + { + return array_indices_type() == onert_tflite::SparseIndexVector_Int32Vector + ? static_cast<const onert_tflite::Int32Vector *>(array_indices()) + : nullptr; + } + const onert_tflite::Uint16Vector *array_indices_as_Uint16Vector() const + { + return array_indices_type() == onert_tflite::SparseIndexVector_Uint16Vector + ? static_cast<const onert_tflite::Uint16Vector *>(array_indices()) + : nullptr; + } + const onert_tflite::Uint8Vector *array_indices_as_Uint8Vector() const + { + return array_indices_type() == onert_tflite::SparseIndexVector_Uint8Vector + ? 
static_cast<const onert_tflite::Uint8Vector *>(array_indices()) + : nullptr; + } + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && VerifyField<int8_t>(verifier, VT_FORMAT, 1) && + VerifyField<int32_t>(verifier, VT_DENSE_SIZE, 4) && + VerifyField<uint8_t>(verifier, VT_ARRAY_SEGMENTS_TYPE, 1) && + VerifyOffset(verifier, VT_ARRAY_SEGMENTS) && + VerifySparseIndexVector(verifier, array_segments(), array_segments_type()) && + VerifyField<uint8_t>(verifier, VT_ARRAY_INDICES_TYPE, 1) && + VerifyOffset(verifier, VT_ARRAY_INDICES) && + VerifySparseIndexVector(verifier, array_indices(), array_indices_type()) && + verifier.EndTable(); + } +}; + +template <> +inline const onert_tflite::Int32Vector * +DimensionMetadata::array_segments_as<onert_tflite::Int32Vector>() const +{ + return array_segments_as_Int32Vector(); +} + +template <> +inline const onert_tflite::Uint16Vector * +DimensionMetadata::array_segments_as<onert_tflite::Uint16Vector>() const +{ + return array_segments_as_Uint16Vector(); +} + +template <> +inline const onert_tflite::Uint8Vector * +DimensionMetadata::array_segments_as<onert_tflite::Uint8Vector>() const +{ + return array_segments_as_Uint8Vector(); +} + +template <> +inline const onert_tflite::Int32Vector * +DimensionMetadata::array_indices_as<onert_tflite::Int32Vector>() const +{ + return array_indices_as_Int32Vector(); +} + +template <> +inline const onert_tflite::Uint16Vector * +DimensionMetadata::array_indices_as<onert_tflite::Uint16Vector>() const +{ + return array_indices_as_Uint16Vector(); +} + +template <> +inline const onert_tflite::Uint8Vector * +DimensionMetadata::array_indices_as<onert_tflite::Uint8Vector>() const +{ + return array_indices_as_Uint8Vector(); +} + +struct DimensionMetadataBuilder +{ + typedef DimensionMetadata Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_format(onert_tflite::DimensionType format) + { + 
fbb_.AddElement<int8_t>(DimensionMetadata::VT_FORMAT, static_cast<int8_t>(format), 0); + } + void add_dense_size(int32_t dense_size) + { + fbb_.AddElement<int32_t>(DimensionMetadata::VT_DENSE_SIZE, dense_size, 0); + } + void add_array_segments_type(onert_tflite::SparseIndexVector array_segments_type) + { + fbb_.AddElement<uint8_t>(DimensionMetadata::VT_ARRAY_SEGMENTS_TYPE, + static_cast<uint8_t>(array_segments_type), 0); + } + void add_array_segments(::flatbuffers::Offset<void> array_segments) + { + fbb_.AddOffset(DimensionMetadata::VT_ARRAY_SEGMENTS, array_segments); + } + void add_array_indices_type(onert_tflite::SparseIndexVector array_indices_type) + { + fbb_.AddElement<uint8_t>(DimensionMetadata::VT_ARRAY_INDICES_TYPE, + static_cast<uint8_t>(array_indices_type), 0); + } + void add_array_indices(::flatbuffers::Offset<void> array_indices) + { + fbb_.AddOffset(DimensionMetadata::VT_ARRAY_INDICES, array_indices); + } + explicit DimensionMetadataBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<DimensionMetadata> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<DimensionMetadata>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<DimensionMetadata> CreateDimensionMetadata( + ::flatbuffers::FlatBufferBuilder &_fbb, + onert_tflite::DimensionType format = onert_tflite::DimensionType_DENSE, int32_t dense_size = 0, + onert_tflite::SparseIndexVector array_segments_type = onert_tflite::SparseIndexVector_NONE, + ::flatbuffers::Offset<void> array_segments = 0, + onert_tflite::SparseIndexVector array_indices_type = onert_tflite::SparseIndexVector_NONE, + ::flatbuffers::Offset<void> array_indices = 0) +{ + DimensionMetadataBuilder builder_(_fbb); + builder_.add_array_indices(array_indices); + builder_.add_array_segments(array_segments); + builder_.add_dense_size(dense_size); + builder_.add_array_indices_type(array_indices_type); + 
builder_.add_array_segments_type(array_segments_type); + builder_.add_format(format); + return builder_.Finish(); +} + +struct SparsityParameters FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef SparsityParametersBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE + { + VT_TRAVERSAL_ORDER = 4, + VT_BLOCK_MAP = 6, + VT_DIM_METADATA = 8 + }; + const ::flatbuffers::Vector<int32_t> *traversal_order() const + { + return GetPointer<const ::flatbuffers::Vector<int32_t> *>(VT_TRAVERSAL_ORDER); + } + const ::flatbuffers::Vector<int32_t> *block_map() const + { + return GetPointer<const ::flatbuffers::Vector<int32_t> *>(VT_BLOCK_MAP); + } + const ::flatbuffers::Vector<::flatbuffers::Offset<onert_tflite::DimensionMetadata>> * + dim_metadata() const + { + return GetPointer< + const ::flatbuffers::Vector<::flatbuffers::Offset<onert_tflite::DimensionMetadata>> *>( + VT_DIM_METADATA); + } + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_TRAVERSAL_ORDER) && + verifier.VerifyVector(traversal_order()) && VerifyOffset(verifier, VT_BLOCK_MAP) && + verifier.VerifyVector(block_map()) && VerifyOffset(verifier, VT_DIM_METADATA) && + verifier.VerifyVector(dim_metadata()) && verifier.VerifyVectorOfTables(dim_metadata()) && + verifier.EndTable(); + } +}; + +struct SparsityParametersBuilder +{ + typedef SparsityParameters Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_traversal_order(::flatbuffers::Offset<::flatbuffers::Vector<int32_t>> traversal_order) + { + fbb_.AddOffset(SparsityParameters::VT_TRAVERSAL_ORDER, traversal_order); + } + void add_block_map(::flatbuffers::Offset<::flatbuffers::Vector<int32_t>> block_map) + { + fbb_.AddOffset(SparsityParameters::VT_BLOCK_MAP, block_map); + } + void + add_dim_metadata(::flatbuffers::Offset< + ::flatbuffers::Vector<::flatbuffers::Offset<onert_tflite::DimensionMetadata>>> + 
dim_metadata) + { + fbb_.AddOffset(SparsityParameters::VT_DIM_METADATA, dim_metadata); + } + explicit SparsityParametersBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<SparsityParameters> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<SparsityParameters>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<SparsityParameters> CreateSparsityParameters( + ::flatbuffers::FlatBufferBuilder &_fbb, + ::flatbuffers::Offset<::flatbuffers::Vector<int32_t>> traversal_order = 0, + ::flatbuffers::Offset<::flatbuffers::Vector<int32_t>> block_map = 0, + ::flatbuffers::Offset< + ::flatbuffers::Vector<::flatbuffers::Offset<onert_tflite::DimensionMetadata>>> + dim_metadata = 0) +{ + SparsityParametersBuilder builder_(_fbb); + builder_.add_dim_metadata(dim_metadata); + builder_.add_block_map(block_map); + builder_.add_traversal_order(traversal_order); + return builder_.Finish(); +} + +inline ::flatbuffers::Offset<SparsityParameters> CreateSparsityParametersDirect( + ::flatbuffers::FlatBufferBuilder &_fbb, const std::vector<int32_t> *traversal_order = nullptr, + const std::vector<int32_t> *block_map = nullptr, + const std::vector<::flatbuffers::Offset<onert_tflite::DimensionMetadata>> *dim_metadata = nullptr) +{ + auto traversal_order__ = traversal_order ? _fbb.CreateVector<int32_t>(*traversal_order) : 0; + auto block_map__ = block_map ? _fbb.CreateVector<int32_t>(*block_map) : 0; + auto dim_metadata__ = + dim_metadata + ? 
_fbb.CreateVector<::flatbuffers::Offset<onert_tflite::DimensionMetadata>>(*dim_metadata) + : 0; + return onert_tflite::CreateSparsityParameters(_fbb, traversal_order__, block_map__, + dim_metadata__); +} + +struct Tensor FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef TensorBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE + { + VT_SHAPE = 4, + VT_TYPE = 6, + VT_BUFFER = 8, + VT_NAME = 10, + VT_QUANTIZATION = 12, + VT_IS_VARIABLE = 14, + VT_SPARSITY = 16, + VT_SHAPE_SIGNATURE = 18, + VT_HAS_RANK = 20 + }; + const ::flatbuffers::Vector<int32_t> *shape() const + { + return GetPointer<const ::flatbuffers::Vector<int32_t> *>(VT_SHAPE); + } + onert_tflite::TensorType type() const + { + return static_cast<onert_tflite::TensorType>(GetField<int8_t>(VT_TYPE, 0)); + } + uint32_t buffer() const { return GetField<uint32_t>(VT_BUFFER, 0); } + const ::flatbuffers::String *name() const + { + return GetPointer<const ::flatbuffers::String *>(VT_NAME); + } + const onert_tflite::QuantizationParameters *quantization() const + { + return GetPointer<const onert_tflite::QuantizationParameters *>(VT_QUANTIZATION); + } + bool is_variable() const { return GetField<uint8_t>(VT_IS_VARIABLE, 0) != 0; } + const onert_tflite::SparsityParameters *sparsity() const + { + return GetPointer<const onert_tflite::SparsityParameters *>(VT_SPARSITY); + } + const ::flatbuffers::Vector<int32_t> *shape_signature() const + { + return GetPointer<const ::flatbuffers::Vector<int32_t> *>(VT_SHAPE_SIGNATURE); + } + bool has_rank() const { return GetField<uint8_t>(VT_HAS_RANK, 0) != 0; } + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_SHAPE) && + verifier.VerifyVector(shape()) && VerifyField<int8_t>(verifier, VT_TYPE, 1) && + VerifyField<uint32_t>(verifier, VT_BUFFER, 4) && VerifyOffset(verifier, VT_NAME) && + verifier.VerifyString(name()) && VerifyOffset(verifier, VT_QUANTIZATION) && + 
verifier.VerifyTable(quantization()) && + VerifyField<uint8_t>(verifier, VT_IS_VARIABLE, 1) && + VerifyOffset(verifier, VT_SPARSITY) && verifier.VerifyTable(sparsity()) && + VerifyOffset(verifier, VT_SHAPE_SIGNATURE) && verifier.VerifyVector(shape_signature()) && + VerifyField<uint8_t>(verifier, VT_HAS_RANK, 1) && verifier.EndTable(); + } +}; + +struct TensorBuilder +{ + typedef Tensor Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_shape(::flatbuffers::Offset<::flatbuffers::Vector<int32_t>> shape) + { + fbb_.AddOffset(Tensor::VT_SHAPE, shape); + } + void add_type(onert_tflite::TensorType type) + { + fbb_.AddElement<int8_t>(Tensor::VT_TYPE, static_cast<int8_t>(type), 0); + } + void add_buffer(uint32_t buffer) { fbb_.AddElement<uint32_t>(Tensor::VT_BUFFER, buffer, 0); } + void add_name(::flatbuffers::Offset<::flatbuffers::String> name) + { + fbb_.AddOffset(Tensor::VT_NAME, name); + } + void add_quantization(::flatbuffers::Offset<onert_tflite::QuantizationParameters> quantization) + { + fbb_.AddOffset(Tensor::VT_QUANTIZATION, quantization); + } + void add_is_variable(bool is_variable) + { + fbb_.AddElement<uint8_t>(Tensor::VT_IS_VARIABLE, static_cast<uint8_t>(is_variable), 0); + } + void add_sparsity(::flatbuffers::Offset<onert_tflite::SparsityParameters> sparsity) + { + fbb_.AddOffset(Tensor::VT_SPARSITY, sparsity); + } + void add_shape_signature(::flatbuffers::Offset<::flatbuffers::Vector<int32_t>> shape_signature) + { + fbb_.AddOffset(Tensor::VT_SHAPE_SIGNATURE, shape_signature); + } + void add_has_rank(bool has_rank) + { + fbb_.AddElement<uint8_t>(Tensor::VT_HAS_RANK, static_cast<uint8_t>(has_rank), 0); + } + explicit TensorBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<Tensor> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<Tensor>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<Tensor> 
CreateTensor( + ::flatbuffers::FlatBufferBuilder &_fbb, + ::flatbuffers::Offset<::flatbuffers::Vector<int32_t>> shape = 0, + onert_tflite::TensorType type = onert_tflite::TensorType_FLOAT32, uint32_t buffer = 0, + ::flatbuffers::Offset<::flatbuffers::String> name = 0, + ::flatbuffers::Offset<onert_tflite::QuantizationParameters> quantization = 0, + bool is_variable = false, ::flatbuffers::Offset<onert_tflite::SparsityParameters> sparsity = 0, + ::flatbuffers::Offset<::flatbuffers::Vector<int32_t>> shape_signature = 0, bool has_rank = false) +{ + TensorBuilder builder_(_fbb); + builder_.add_shape_signature(shape_signature); + builder_.add_sparsity(sparsity); + builder_.add_quantization(quantization); + builder_.add_name(name); + builder_.add_buffer(buffer); + builder_.add_shape(shape); + builder_.add_has_rank(has_rank); + builder_.add_is_variable(is_variable); + builder_.add_type(type); + return builder_.Finish(); +} + +inline ::flatbuffers::Offset<Tensor> CreateTensorDirect( + ::flatbuffers::FlatBufferBuilder &_fbb, const std::vector<int32_t> *shape = nullptr, + onert_tflite::TensorType type = onert_tflite::TensorType_FLOAT32, uint32_t buffer = 0, + const char *name = nullptr, + ::flatbuffers::Offset<onert_tflite::QuantizationParameters> quantization = 0, + bool is_variable = false, ::flatbuffers::Offset<onert_tflite::SparsityParameters> sparsity = 0, + const std::vector<int32_t> *shape_signature = nullptr, bool has_rank = false) +{ + auto shape__ = shape ? _fbb.CreateVector<int32_t>(*shape) : 0; + auto name__ = name ? _fbb.CreateString(name) : 0; + auto shape_signature__ = shape_signature ? 
_fbb.CreateVector<int32_t>(*shape_signature) : 0; + return onert_tflite::CreateTensor(_fbb, shape__, type, buffer, name__, quantization, is_variable, + sparsity, shape_signature__, has_rank); +} + +struct Conv2DOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef Conv2DOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE + { + VT_PADDING = 4, + VT_STRIDE_W = 6, + VT_STRIDE_H = 8, + VT_FUSED_ACTIVATION_FUNCTION = 10, + VT_DILATION_W_FACTOR = 12, + VT_DILATION_H_FACTOR = 14 + }; + onert_tflite::Padding padding() const + { + return static_cast<onert_tflite::Padding>(GetField<int8_t>(VT_PADDING, 0)); + } + int32_t stride_w() const { return GetField<int32_t>(VT_STRIDE_W, 0); } + int32_t stride_h() const { return GetField<int32_t>(VT_STRIDE_H, 0); } + onert_tflite::ActivationFunctionType fused_activation_function() const + { + return static_cast<onert_tflite::ActivationFunctionType>( + GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0)); + } + int32_t dilation_w_factor() const { return GetField<int32_t>(VT_DILATION_W_FACTOR, 1); } + int32_t dilation_h_factor() const { return GetField<int32_t>(VT_DILATION_H_FACTOR, 1); } + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && VerifyField<int8_t>(verifier, VT_PADDING, 1) && + VerifyField<int32_t>(verifier, VT_STRIDE_W, 4) && + VerifyField<int32_t>(verifier, VT_STRIDE_H, 4) && + VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION, 1) && + VerifyField<int32_t>(verifier, VT_DILATION_W_FACTOR, 4) && + VerifyField<int32_t>(verifier, VT_DILATION_H_FACTOR, 4) && verifier.EndTable(); + } +}; + +struct Conv2DOptionsBuilder +{ + typedef Conv2DOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_padding(onert_tflite::Padding padding) + { + fbb_.AddElement<int8_t>(Conv2DOptions::VT_PADDING, static_cast<int8_t>(padding), 0); + } + void add_stride_w(int32_t stride_w) + { + 
fbb_.AddElement<int32_t>(Conv2DOptions::VT_STRIDE_W, stride_w, 0); + } + void add_stride_h(int32_t stride_h) + { + fbb_.AddElement<int32_t>(Conv2DOptions::VT_STRIDE_H, stride_h, 0); + } + void add_fused_activation_function(onert_tflite::ActivationFunctionType fused_activation_function) + { + fbb_.AddElement<int8_t>(Conv2DOptions::VT_FUSED_ACTIVATION_FUNCTION, + static_cast<int8_t>(fused_activation_function), 0); + } + void add_dilation_w_factor(int32_t dilation_w_factor) + { + fbb_.AddElement<int32_t>(Conv2DOptions::VT_DILATION_W_FACTOR, dilation_w_factor, 1); + } + void add_dilation_h_factor(int32_t dilation_h_factor) + { + fbb_.AddElement<int32_t>(Conv2DOptions::VT_DILATION_H_FACTOR, dilation_h_factor, 1); + } + explicit Conv2DOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<Conv2DOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<Conv2DOptions>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<Conv2DOptions> +CreateConv2DOptions(::flatbuffers::FlatBufferBuilder &_fbb, + onert_tflite::Padding padding = onert_tflite::Padding_SAME, + int32_t stride_w = 0, int32_t stride_h = 0, + onert_tflite::ActivationFunctionType fused_activation_function = + onert_tflite::ActivationFunctionType_NONE, + int32_t dilation_w_factor = 1, int32_t dilation_h_factor = 1) +{ + Conv2DOptionsBuilder builder_(_fbb); + builder_.add_dilation_h_factor(dilation_h_factor); + builder_.add_dilation_w_factor(dilation_w_factor); + builder_.add_stride_h(stride_h); + builder_.add_stride_w(stride_w); + builder_.add_fused_activation_function(fused_activation_function); + builder_.add_padding(padding); + return builder_.Finish(); +} + +struct Conv3DOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef Conv3DOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE + { + VT_PADDING = 4, + VT_STRIDE_D = 6, + VT_STRIDE_W 
= 8, + VT_STRIDE_H = 10, + VT_FUSED_ACTIVATION_FUNCTION = 12, + VT_DILATION_D_FACTOR = 14, + VT_DILATION_W_FACTOR = 16, + VT_DILATION_H_FACTOR = 18 + }; + onert_tflite::Padding padding() const + { + return static_cast<onert_tflite::Padding>(GetField<int8_t>(VT_PADDING, 0)); + } + int32_t stride_d() const { return GetField<int32_t>(VT_STRIDE_D, 0); } + int32_t stride_w() const { return GetField<int32_t>(VT_STRIDE_W, 0); } + int32_t stride_h() const { return GetField<int32_t>(VT_STRIDE_H, 0); } + onert_tflite::ActivationFunctionType fused_activation_function() const + { + return static_cast<onert_tflite::ActivationFunctionType>( + GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0)); + } + int32_t dilation_d_factor() const { return GetField<int32_t>(VT_DILATION_D_FACTOR, 1); } + int32_t dilation_w_factor() const { return GetField<int32_t>(VT_DILATION_W_FACTOR, 1); } + int32_t dilation_h_factor() const { return GetField<int32_t>(VT_DILATION_H_FACTOR, 1); } + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && VerifyField<int8_t>(verifier, VT_PADDING, 1) && + VerifyField<int32_t>(verifier, VT_STRIDE_D, 4) && + VerifyField<int32_t>(verifier, VT_STRIDE_W, 4) && + VerifyField<int32_t>(verifier, VT_STRIDE_H, 4) && + VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION, 1) && + VerifyField<int32_t>(verifier, VT_DILATION_D_FACTOR, 4) && + VerifyField<int32_t>(verifier, VT_DILATION_W_FACTOR, 4) && + VerifyField<int32_t>(verifier, VT_DILATION_H_FACTOR, 4) && verifier.EndTable(); + } +}; + +struct Conv3DOptionsBuilder +{ + typedef Conv3DOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_padding(onert_tflite::Padding padding) + { + fbb_.AddElement<int8_t>(Conv3DOptions::VT_PADDING, static_cast<int8_t>(padding), 0); + } + void add_stride_d(int32_t stride_d) + { + fbb_.AddElement<int32_t>(Conv3DOptions::VT_STRIDE_D, stride_d, 0); + } + void add_stride_w(int32_t stride_w) + { + 
fbb_.AddElement<int32_t>(Conv3DOptions::VT_STRIDE_W, stride_w, 0); + } + void add_stride_h(int32_t stride_h) + { + fbb_.AddElement<int32_t>(Conv3DOptions::VT_STRIDE_H, stride_h, 0); + } + void add_fused_activation_function(onert_tflite::ActivationFunctionType fused_activation_function) + { + fbb_.AddElement<int8_t>(Conv3DOptions::VT_FUSED_ACTIVATION_FUNCTION, + static_cast<int8_t>(fused_activation_function), 0); + } + void add_dilation_d_factor(int32_t dilation_d_factor) + { + fbb_.AddElement<int32_t>(Conv3DOptions::VT_DILATION_D_FACTOR, dilation_d_factor, 1); + } + void add_dilation_w_factor(int32_t dilation_w_factor) + { + fbb_.AddElement<int32_t>(Conv3DOptions::VT_DILATION_W_FACTOR, dilation_w_factor, 1); + } + void add_dilation_h_factor(int32_t dilation_h_factor) + { + fbb_.AddElement<int32_t>(Conv3DOptions::VT_DILATION_H_FACTOR, dilation_h_factor, 1); + } + explicit Conv3DOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<Conv3DOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<Conv3DOptions>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<Conv3DOptions> +CreateConv3DOptions(::flatbuffers::FlatBufferBuilder &_fbb, + onert_tflite::Padding padding = onert_tflite::Padding_SAME, + int32_t stride_d = 0, int32_t stride_w = 0, int32_t stride_h = 0, + onert_tflite::ActivationFunctionType fused_activation_function = + onert_tflite::ActivationFunctionType_NONE, + int32_t dilation_d_factor = 1, int32_t dilation_w_factor = 1, + int32_t dilation_h_factor = 1) +{ + Conv3DOptionsBuilder builder_(_fbb); + builder_.add_dilation_h_factor(dilation_h_factor); + builder_.add_dilation_w_factor(dilation_w_factor); + builder_.add_dilation_d_factor(dilation_d_factor); + builder_.add_stride_h(stride_h); + builder_.add_stride_w(stride_w); + builder_.add_stride_d(stride_d); + builder_.add_fused_activation_function(fused_activation_function); + 
builder_.add_padding(padding); + return builder_.Finish(); +} + +struct Pool2DOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef Pool2DOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE + { + VT_PADDING = 4, + VT_STRIDE_W = 6, + VT_STRIDE_H = 8, + VT_FILTER_WIDTH = 10, + VT_FILTER_HEIGHT = 12, + VT_FUSED_ACTIVATION_FUNCTION = 14 + }; + onert_tflite::Padding padding() const + { + return static_cast<onert_tflite::Padding>(GetField<int8_t>(VT_PADDING, 0)); + } + int32_t stride_w() const { return GetField<int32_t>(VT_STRIDE_W, 0); } + int32_t stride_h() const { return GetField<int32_t>(VT_STRIDE_H, 0); } + int32_t filter_width() const { return GetField<int32_t>(VT_FILTER_WIDTH, 0); } + int32_t filter_height() const { return GetField<int32_t>(VT_FILTER_HEIGHT, 0); } + onert_tflite::ActivationFunctionType fused_activation_function() const + { + return static_cast<onert_tflite::ActivationFunctionType>( + GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0)); + } + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && VerifyField<int8_t>(verifier, VT_PADDING, 1) && + VerifyField<int32_t>(verifier, VT_STRIDE_W, 4) && + VerifyField<int32_t>(verifier, VT_STRIDE_H, 4) && + VerifyField<int32_t>(verifier, VT_FILTER_WIDTH, 4) && + VerifyField<int32_t>(verifier, VT_FILTER_HEIGHT, 4) && + VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION, 1) && verifier.EndTable(); + } +}; + +struct Pool2DOptionsBuilder +{ + typedef Pool2DOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_padding(onert_tflite::Padding padding) + { + fbb_.AddElement<int8_t>(Pool2DOptions::VT_PADDING, static_cast<int8_t>(padding), 0); + } + void add_stride_w(int32_t stride_w) + { + fbb_.AddElement<int32_t>(Pool2DOptions::VT_STRIDE_W, stride_w, 0); + } + void add_stride_h(int32_t stride_h) + { + fbb_.AddElement<int32_t>(Pool2DOptions::VT_STRIDE_H, stride_h, 
0); + } + void add_filter_width(int32_t filter_width) + { + fbb_.AddElement<int32_t>(Pool2DOptions::VT_FILTER_WIDTH, filter_width, 0); + } + void add_filter_height(int32_t filter_height) + { + fbb_.AddElement<int32_t>(Pool2DOptions::VT_FILTER_HEIGHT, filter_height, 0); + } + void add_fused_activation_function(onert_tflite::ActivationFunctionType fused_activation_function) + { + fbb_.AddElement<int8_t>(Pool2DOptions::VT_FUSED_ACTIVATION_FUNCTION, + static_cast<int8_t>(fused_activation_function), 0); + } + explicit Pool2DOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<Pool2DOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<Pool2DOptions>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<Pool2DOptions> +CreatePool2DOptions(::flatbuffers::FlatBufferBuilder &_fbb, + onert_tflite::Padding padding = onert_tflite::Padding_SAME, + int32_t stride_w = 0, int32_t stride_h = 0, int32_t filter_width = 0, + int32_t filter_height = 0, + onert_tflite::ActivationFunctionType fused_activation_function = + onert_tflite::ActivationFunctionType_NONE) +{ + Pool2DOptionsBuilder builder_(_fbb); + builder_.add_filter_height(filter_height); + builder_.add_filter_width(filter_width); + builder_.add_stride_h(stride_h); + builder_.add_stride_w(stride_w); + builder_.add_fused_activation_function(fused_activation_function); + builder_.add_padding(padding); + return builder_.Finish(); +} + +struct DepthwiseConv2DOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef DepthwiseConv2DOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE + { + VT_PADDING = 4, + VT_STRIDE_W = 6, + VT_STRIDE_H = 8, + VT_DEPTH_MULTIPLIER = 10, + VT_FUSED_ACTIVATION_FUNCTION = 12, + VT_DILATION_W_FACTOR = 14, + VT_DILATION_H_FACTOR = 16 + }; + onert_tflite::Padding padding() const + { + return 
static_cast<onert_tflite::Padding>(GetField<int8_t>(VT_PADDING, 0)); + } + int32_t stride_w() const { return GetField<int32_t>(VT_STRIDE_W, 0); } + int32_t stride_h() const { return GetField<int32_t>(VT_STRIDE_H, 0); } + int32_t depth_multiplier() const { return GetField<int32_t>(VT_DEPTH_MULTIPLIER, 0); } + onert_tflite::ActivationFunctionType fused_activation_function() const + { + return static_cast<onert_tflite::ActivationFunctionType>( + GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0)); + } + int32_t dilation_w_factor() const { return GetField<int32_t>(VT_DILATION_W_FACTOR, 1); } + int32_t dilation_h_factor() const { return GetField<int32_t>(VT_DILATION_H_FACTOR, 1); } + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && VerifyField<int8_t>(verifier, VT_PADDING, 1) && + VerifyField<int32_t>(verifier, VT_STRIDE_W, 4) && + VerifyField<int32_t>(verifier, VT_STRIDE_H, 4) && + VerifyField<int32_t>(verifier, VT_DEPTH_MULTIPLIER, 4) && + VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION, 1) && + VerifyField<int32_t>(verifier, VT_DILATION_W_FACTOR, 4) && + VerifyField<int32_t>(verifier, VT_DILATION_H_FACTOR, 4) && verifier.EndTable(); + } +}; + +struct DepthwiseConv2DOptionsBuilder +{ + typedef DepthwiseConv2DOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_padding(onert_tflite::Padding padding) + { + fbb_.AddElement<int8_t>(DepthwiseConv2DOptions::VT_PADDING, static_cast<int8_t>(padding), 0); + } + void add_stride_w(int32_t stride_w) + { + fbb_.AddElement<int32_t>(DepthwiseConv2DOptions::VT_STRIDE_W, stride_w, 0); + } + void add_stride_h(int32_t stride_h) + { + fbb_.AddElement<int32_t>(DepthwiseConv2DOptions::VT_STRIDE_H, stride_h, 0); + } + void add_depth_multiplier(int32_t depth_multiplier) + { + fbb_.AddElement<int32_t>(DepthwiseConv2DOptions::VT_DEPTH_MULTIPLIER, depth_multiplier, 0); + } + void 
add_fused_activation_function(onert_tflite::ActivationFunctionType fused_activation_function) + { + fbb_.AddElement<int8_t>(DepthwiseConv2DOptions::VT_FUSED_ACTIVATION_FUNCTION, + static_cast<int8_t>(fused_activation_function), 0); + } + void add_dilation_w_factor(int32_t dilation_w_factor) + { + fbb_.AddElement<int32_t>(DepthwiseConv2DOptions::VT_DILATION_W_FACTOR, dilation_w_factor, 1); + } + void add_dilation_h_factor(int32_t dilation_h_factor) + { + fbb_.AddElement<int32_t>(DepthwiseConv2DOptions::VT_DILATION_H_FACTOR, dilation_h_factor, 1); + } + explicit DepthwiseConv2DOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<DepthwiseConv2DOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<DepthwiseConv2DOptions>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<DepthwiseConv2DOptions> +CreateDepthwiseConv2DOptions(::flatbuffers::FlatBufferBuilder &_fbb, + onert_tflite::Padding padding = onert_tflite::Padding_SAME, + int32_t stride_w = 0, int32_t stride_h = 0, + int32_t depth_multiplier = 0, + onert_tflite::ActivationFunctionType fused_activation_function = + onert_tflite::ActivationFunctionType_NONE, + int32_t dilation_w_factor = 1, int32_t dilation_h_factor = 1) +{ + DepthwiseConv2DOptionsBuilder builder_(_fbb); + builder_.add_dilation_h_factor(dilation_h_factor); + builder_.add_dilation_w_factor(dilation_w_factor); + builder_.add_depth_multiplier(depth_multiplier); + builder_.add_stride_h(stride_h); + builder_.add_stride_w(stride_w); + builder_.add_fused_activation_function(fused_activation_function); + builder_.add_padding(padding); + return builder_.Finish(); +} + +struct ConcatEmbeddingsOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef ConcatEmbeddingsOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE + { + VT_NUM_CHANNELS = 4, + VT_NUM_COLUMNS_PER_CHANNEL = 6, + 
VT_EMBEDDING_DIM_PER_CHANNEL = 8 + }; + int32_t num_channels() const { return GetField<int32_t>(VT_NUM_CHANNELS, 0); } + const ::flatbuffers::Vector<int32_t> *num_columns_per_channel() const + { + return GetPointer<const ::flatbuffers::Vector<int32_t> *>(VT_NUM_COLUMNS_PER_CHANNEL); + } + const ::flatbuffers::Vector<int32_t> *embedding_dim_per_channel() const + { + return GetPointer<const ::flatbuffers::Vector<int32_t> *>(VT_EMBEDDING_DIM_PER_CHANNEL); + } + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && VerifyField<int32_t>(verifier, VT_NUM_CHANNELS, 4) && + VerifyOffset(verifier, VT_NUM_COLUMNS_PER_CHANNEL) && + verifier.VerifyVector(num_columns_per_channel()) && + VerifyOffset(verifier, VT_EMBEDDING_DIM_PER_CHANNEL) && + verifier.VerifyVector(embedding_dim_per_channel()) && verifier.EndTable(); + } +}; + +struct ConcatEmbeddingsOptionsBuilder +{ + typedef ConcatEmbeddingsOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_num_channels(int32_t num_channels) + { + fbb_.AddElement<int32_t>(ConcatEmbeddingsOptions::VT_NUM_CHANNELS, num_channels, 0); + } + void add_num_columns_per_channel( + ::flatbuffers::Offset<::flatbuffers::Vector<int32_t>> num_columns_per_channel) + { + fbb_.AddOffset(ConcatEmbeddingsOptions::VT_NUM_COLUMNS_PER_CHANNEL, num_columns_per_channel); + } + void add_embedding_dim_per_channel( + ::flatbuffers::Offset<::flatbuffers::Vector<int32_t>> embedding_dim_per_channel) + { + fbb_.AddOffset(ConcatEmbeddingsOptions::VT_EMBEDDING_DIM_PER_CHANNEL, + embedding_dim_per_channel); + } + explicit ConcatEmbeddingsOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<ConcatEmbeddingsOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<ConcatEmbeddingsOptions>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<ConcatEmbeddingsOptions> 
CreateConcatEmbeddingsOptions( + ::flatbuffers::FlatBufferBuilder &_fbb, int32_t num_channels = 0, + ::flatbuffers::Offset<::flatbuffers::Vector<int32_t>> num_columns_per_channel = 0, + ::flatbuffers::Offset<::flatbuffers::Vector<int32_t>> embedding_dim_per_channel = 0) +{ + ConcatEmbeddingsOptionsBuilder builder_(_fbb); + builder_.add_embedding_dim_per_channel(embedding_dim_per_channel); + builder_.add_num_columns_per_channel(num_columns_per_channel); + builder_.add_num_channels(num_channels); + return builder_.Finish(); +} + +inline ::flatbuffers::Offset<ConcatEmbeddingsOptions> +CreateConcatEmbeddingsOptionsDirect(::flatbuffers::FlatBufferBuilder &_fbb, + int32_t num_channels = 0, + const std::vector<int32_t> *num_columns_per_channel = nullptr, + const std::vector<int32_t> *embedding_dim_per_channel = nullptr) +{ + auto num_columns_per_channel__ = + num_columns_per_channel ? _fbb.CreateVector<int32_t>(*num_columns_per_channel) : 0; + auto embedding_dim_per_channel__ = + embedding_dim_per_channel ? 
_fbb.CreateVector<int32_t>(*embedding_dim_per_channel) : 0; + return onert_tflite::CreateConcatEmbeddingsOptions(_fbb, num_channels, num_columns_per_channel__, + embedding_dim_per_channel__); +} + +struct LSHProjectionOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef LSHProjectionOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE + { + VT_TYPE = 4 + }; + onert_tflite::LSHProjectionType type() const + { + return static_cast<onert_tflite::LSHProjectionType>(GetField<int8_t>(VT_TYPE, 0)); + } + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && VerifyField<int8_t>(verifier, VT_TYPE, 1) && + verifier.EndTable(); + } +}; + +struct LSHProjectionOptionsBuilder +{ + typedef LSHProjectionOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_type(onert_tflite::LSHProjectionType type) + { + fbb_.AddElement<int8_t>(LSHProjectionOptions::VT_TYPE, static_cast<int8_t>(type), 0); + } + explicit LSHProjectionOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<LSHProjectionOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<LSHProjectionOptions>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<LSHProjectionOptions> CreateLSHProjectionOptions( + ::flatbuffers::FlatBufferBuilder &_fbb, + onert_tflite::LSHProjectionType type = onert_tflite::LSHProjectionType_UNKNOWN) +{ + LSHProjectionOptionsBuilder builder_(_fbb); + builder_.add_type(type); + return builder_.Finish(); +} + +struct SVDFOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef SVDFOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE + { + VT_RANK = 4, + VT_FUSED_ACTIVATION_FUNCTION = 6, + VT_ASYMMETRIC_QUANTIZE_INPUTS = 8 + }; + int32_t rank() const { return 
GetField<int32_t>(VT_RANK, 0); } + onert_tflite::ActivationFunctionType fused_activation_function() const + { + return static_cast<onert_tflite::ActivationFunctionType>( + GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0)); + } + bool asymmetric_quantize_inputs() const + { + return GetField<uint8_t>(VT_ASYMMETRIC_QUANTIZE_INPUTS, 0) != 0; + } + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && VerifyField<int32_t>(verifier, VT_RANK, 4) && + VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION, 1) && + VerifyField<uint8_t>(verifier, VT_ASYMMETRIC_QUANTIZE_INPUTS, 1) && verifier.EndTable(); + } +}; + +struct SVDFOptionsBuilder +{ + typedef SVDFOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_rank(int32_t rank) { fbb_.AddElement<int32_t>(SVDFOptions::VT_RANK, rank, 0); } + void add_fused_activation_function(onert_tflite::ActivationFunctionType fused_activation_function) + { + fbb_.AddElement<int8_t>(SVDFOptions::VT_FUSED_ACTIVATION_FUNCTION, + static_cast<int8_t>(fused_activation_function), 0); + } + void add_asymmetric_quantize_inputs(bool asymmetric_quantize_inputs) + { + fbb_.AddElement<uint8_t>(SVDFOptions::VT_ASYMMETRIC_QUANTIZE_INPUTS, + static_cast<uint8_t>(asymmetric_quantize_inputs), 0); + } + explicit SVDFOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<SVDFOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<SVDFOptions>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<SVDFOptions> +CreateSVDFOptions(::flatbuffers::FlatBufferBuilder &_fbb, int32_t rank = 0, + onert_tflite::ActivationFunctionType fused_activation_function = + onert_tflite::ActivationFunctionType_NONE, + bool asymmetric_quantize_inputs = false) +{ + SVDFOptionsBuilder builder_(_fbb); + builder_.add_rank(rank); + 
builder_.add_asymmetric_quantize_inputs(asymmetric_quantize_inputs); + builder_.add_fused_activation_function(fused_activation_function); + return builder_.Finish(); +} + +struct RNNOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef RNNOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE + { + VT_FUSED_ACTIVATION_FUNCTION = 4, + VT_ASYMMETRIC_QUANTIZE_INPUTS = 6 + }; + onert_tflite::ActivationFunctionType fused_activation_function() const + { + return static_cast<onert_tflite::ActivationFunctionType>( + GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0)); + } + bool asymmetric_quantize_inputs() const + { + return GetField<uint8_t>(VT_ASYMMETRIC_QUANTIZE_INPUTS, 0) != 0; + } + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && + VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION, 1) && + VerifyField<uint8_t>(verifier, VT_ASYMMETRIC_QUANTIZE_INPUTS, 1) && verifier.EndTable(); + } +}; + +struct RNNOptionsBuilder +{ + typedef RNNOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_fused_activation_function(onert_tflite::ActivationFunctionType fused_activation_function) + { + fbb_.AddElement<int8_t>(RNNOptions::VT_FUSED_ACTIVATION_FUNCTION, + static_cast<int8_t>(fused_activation_function), 0); + } + void add_asymmetric_quantize_inputs(bool asymmetric_quantize_inputs) + { + fbb_.AddElement<uint8_t>(RNNOptions::VT_ASYMMETRIC_QUANTIZE_INPUTS, + static_cast<uint8_t>(asymmetric_quantize_inputs), 0); + } + explicit RNNOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<RNNOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<RNNOptions>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<RNNOptions> +CreateRNNOptions(::flatbuffers::FlatBufferBuilder &_fbb, + 
onert_tflite::ActivationFunctionType fused_activation_function = + onert_tflite::ActivationFunctionType_NONE, + bool asymmetric_quantize_inputs = false) +{ + RNNOptionsBuilder builder_(_fbb); + builder_.add_asymmetric_quantize_inputs(asymmetric_quantize_inputs); + builder_.add_fused_activation_function(fused_activation_function); + return builder_.Finish(); +} + +struct SequenceRNNOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef SequenceRNNOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE + { + VT_TIME_MAJOR = 4, + VT_FUSED_ACTIVATION_FUNCTION = 6, + VT_ASYMMETRIC_QUANTIZE_INPUTS = 8 + }; + bool time_major() const { return GetField<uint8_t>(VT_TIME_MAJOR, 0) != 0; } + onert_tflite::ActivationFunctionType fused_activation_function() const + { + return static_cast<onert_tflite::ActivationFunctionType>( + GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0)); + } + bool asymmetric_quantize_inputs() const + { + return GetField<uint8_t>(VT_ASYMMETRIC_QUANTIZE_INPUTS, 0) != 0; + } + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && VerifyField<uint8_t>(verifier, VT_TIME_MAJOR, 1) && + VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION, 1) && + VerifyField<uint8_t>(verifier, VT_ASYMMETRIC_QUANTIZE_INPUTS, 1) && verifier.EndTable(); + } +}; + +struct SequenceRNNOptionsBuilder +{ + typedef SequenceRNNOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_time_major(bool time_major) + { + fbb_.AddElement<uint8_t>(SequenceRNNOptions::VT_TIME_MAJOR, static_cast<uint8_t>(time_major), + 0); + } + void add_fused_activation_function(onert_tflite::ActivationFunctionType fused_activation_function) + { + fbb_.AddElement<int8_t>(SequenceRNNOptions::VT_FUSED_ACTIVATION_FUNCTION, + static_cast<int8_t>(fused_activation_function), 0); + } + void add_asymmetric_quantize_inputs(bool asymmetric_quantize_inputs) + { + 
fbb_.AddElement<uint8_t>(SequenceRNNOptions::VT_ASYMMETRIC_QUANTIZE_INPUTS, + static_cast<uint8_t>(asymmetric_quantize_inputs), 0); + } + explicit SequenceRNNOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<SequenceRNNOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<SequenceRNNOptions>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<SequenceRNNOptions> +CreateSequenceRNNOptions(::flatbuffers::FlatBufferBuilder &_fbb, bool time_major = false, + onert_tflite::ActivationFunctionType fused_activation_function = + onert_tflite::ActivationFunctionType_NONE, + bool asymmetric_quantize_inputs = false) +{ + SequenceRNNOptionsBuilder builder_(_fbb); + builder_.add_asymmetric_quantize_inputs(asymmetric_quantize_inputs); + builder_.add_fused_activation_function(fused_activation_function); + builder_.add_time_major(time_major); + return builder_.Finish(); +} + +struct BidirectionalSequenceRNNOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef BidirectionalSequenceRNNOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE + { + VT_TIME_MAJOR = 4, + VT_FUSED_ACTIVATION_FUNCTION = 6, + VT_MERGE_OUTPUTS = 8, + VT_ASYMMETRIC_QUANTIZE_INPUTS = 10 + }; + bool time_major() const { return GetField<uint8_t>(VT_TIME_MAJOR, 0) != 0; } + onert_tflite::ActivationFunctionType fused_activation_function() const + { + return static_cast<onert_tflite::ActivationFunctionType>( + GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0)); + } + bool merge_outputs() const { return GetField<uint8_t>(VT_MERGE_OUTPUTS, 0) != 0; } + bool asymmetric_quantize_inputs() const + { + return GetField<uint8_t>(VT_ASYMMETRIC_QUANTIZE_INPUTS, 0) != 0; + } + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && VerifyField<uint8_t>(verifier, VT_TIME_MAJOR, 1) && + 
VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION, 1) && + VerifyField<uint8_t>(verifier, VT_MERGE_OUTPUTS, 1) && + VerifyField<uint8_t>(verifier, VT_ASYMMETRIC_QUANTIZE_INPUTS, 1) && verifier.EndTable(); + } +}; + +struct BidirectionalSequenceRNNOptionsBuilder +{ + typedef BidirectionalSequenceRNNOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_time_major(bool time_major) + { + fbb_.AddElement<uint8_t>(BidirectionalSequenceRNNOptions::VT_TIME_MAJOR, + static_cast<uint8_t>(time_major), 0); + } + void add_fused_activation_function(onert_tflite::ActivationFunctionType fused_activation_function) + { + fbb_.AddElement<int8_t>(BidirectionalSequenceRNNOptions::VT_FUSED_ACTIVATION_FUNCTION, + static_cast<int8_t>(fused_activation_function), 0); + } + void add_merge_outputs(bool merge_outputs) + { + fbb_.AddElement<uint8_t>(BidirectionalSequenceRNNOptions::VT_MERGE_OUTPUTS, + static_cast<uint8_t>(merge_outputs), 0); + } + void add_asymmetric_quantize_inputs(bool asymmetric_quantize_inputs) + { + fbb_.AddElement<uint8_t>(BidirectionalSequenceRNNOptions::VT_ASYMMETRIC_QUANTIZE_INPUTS, + static_cast<uint8_t>(asymmetric_quantize_inputs), 0); + } + explicit BidirectionalSequenceRNNOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<BidirectionalSequenceRNNOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<BidirectionalSequenceRNNOptions>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<BidirectionalSequenceRNNOptions> CreateBidirectionalSequenceRNNOptions( + ::flatbuffers::FlatBufferBuilder &_fbb, bool time_major = false, + onert_tflite::ActivationFunctionType fused_activation_function = + onert_tflite::ActivationFunctionType_NONE, + bool merge_outputs = false, bool asymmetric_quantize_inputs = false) +{ + BidirectionalSequenceRNNOptionsBuilder builder_(_fbb); + 
builder_.add_asymmetric_quantize_inputs(asymmetric_quantize_inputs); + builder_.add_merge_outputs(merge_outputs); + builder_.add_fused_activation_function(fused_activation_function); + builder_.add_time_major(time_major); + return builder_.Finish(); +} + +struct FullyConnectedOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef FullyConnectedOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE + { + VT_FUSED_ACTIVATION_FUNCTION = 4, + VT_WEIGHTS_FORMAT = 6, + VT_KEEP_NUM_DIMS = 8, + VT_ASYMMETRIC_QUANTIZE_INPUTS = 10 + }; + onert_tflite::ActivationFunctionType fused_activation_function() const + { + return static_cast<onert_tflite::ActivationFunctionType>( + GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0)); + } + onert_tflite::FullyConnectedOptionsWeightsFormat weights_format() const + { + return static_cast<onert_tflite::FullyConnectedOptionsWeightsFormat>( + GetField<int8_t>(VT_WEIGHTS_FORMAT, 0)); + } + bool keep_num_dims() const { return GetField<uint8_t>(VT_KEEP_NUM_DIMS, 0) != 0; } + bool asymmetric_quantize_inputs() const + { + return GetField<uint8_t>(VT_ASYMMETRIC_QUANTIZE_INPUTS, 0) != 0; + } + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && + VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION, 1) && + VerifyField<int8_t>(verifier, VT_WEIGHTS_FORMAT, 1) && + VerifyField<uint8_t>(verifier, VT_KEEP_NUM_DIMS, 1) && + VerifyField<uint8_t>(verifier, VT_ASYMMETRIC_QUANTIZE_INPUTS, 1) && verifier.EndTable(); + } +}; + +struct FullyConnectedOptionsBuilder +{ + typedef FullyConnectedOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_fused_activation_function(onert_tflite::ActivationFunctionType fused_activation_function) + { + fbb_.AddElement<int8_t>(FullyConnectedOptions::VT_FUSED_ACTIVATION_FUNCTION, + static_cast<int8_t>(fused_activation_function), 0); + } + void 
add_weights_format(onert_tflite::FullyConnectedOptionsWeightsFormat weights_format) + { + fbb_.AddElement<int8_t>(FullyConnectedOptions::VT_WEIGHTS_FORMAT, + static_cast<int8_t>(weights_format), 0); + } + void add_keep_num_dims(bool keep_num_dims) + { + fbb_.AddElement<uint8_t>(FullyConnectedOptions::VT_KEEP_NUM_DIMS, + static_cast<uint8_t>(keep_num_dims), 0); + } + void add_asymmetric_quantize_inputs(bool asymmetric_quantize_inputs) + { + fbb_.AddElement<uint8_t>(FullyConnectedOptions::VT_ASYMMETRIC_QUANTIZE_INPUTS, + static_cast<uint8_t>(asymmetric_quantize_inputs), 0); + } + explicit FullyConnectedOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<FullyConnectedOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<FullyConnectedOptions>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<FullyConnectedOptions> +CreateFullyConnectedOptions(::flatbuffers::FlatBufferBuilder &_fbb, + onert_tflite::ActivationFunctionType fused_activation_function = + onert_tflite::ActivationFunctionType_NONE, + onert_tflite::FullyConnectedOptionsWeightsFormat weights_format = + onert_tflite::FullyConnectedOptionsWeightsFormat_DEFAULT, + bool keep_num_dims = false, bool asymmetric_quantize_inputs = false) +{ + FullyConnectedOptionsBuilder builder_(_fbb); + builder_.add_asymmetric_quantize_inputs(asymmetric_quantize_inputs); + builder_.add_keep_num_dims(keep_num_dims); + builder_.add_weights_format(weights_format); + builder_.add_fused_activation_function(fused_activation_function); + return builder_.Finish(); +} + +struct SoftmaxOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef SoftmaxOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE + { + VT_BETA = 4 + }; + float beta() const { return GetField<float>(VT_BETA, 0.0f); } + bool Verify(::flatbuffers::Verifier &verifier) const + { + return 
VerifyTableStart(verifier) && VerifyField<float>(verifier, VT_BETA, 4) && + verifier.EndTable(); + } +}; + +struct SoftmaxOptionsBuilder +{ + typedef SoftmaxOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_beta(float beta) { fbb_.AddElement<float>(SoftmaxOptions::VT_BETA, beta, 0.0f); } + explicit SoftmaxOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<SoftmaxOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<SoftmaxOptions>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<SoftmaxOptions> +CreateSoftmaxOptions(::flatbuffers::FlatBufferBuilder &_fbb, float beta = 0.0f) +{ + SoftmaxOptionsBuilder builder_(_fbb); + builder_.add_beta(beta); + return builder_.Finish(); +} + +struct ConcatenationOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef ConcatenationOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE + { + VT_AXIS = 4, + VT_FUSED_ACTIVATION_FUNCTION = 6 + }; + int32_t axis() const { return GetField<int32_t>(VT_AXIS, 0); } + onert_tflite::ActivationFunctionType fused_activation_function() const + { + return static_cast<onert_tflite::ActivationFunctionType>( + GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0)); + } + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && VerifyField<int32_t>(verifier, VT_AXIS, 4) && + VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION, 1) && verifier.EndTable(); + } +}; + +struct ConcatenationOptionsBuilder +{ + typedef ConcatenationOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_axis(int32_t axis) { fbb_.AddElement<int32_t>(ConcatenationOptions::VT_AXIS, axis, 0); } + void add_fused_activation_function(onert_tflite::ActivationFunctionType fused_activation_function) + { + 
fbb_.AddElement<int8_t>(ConcatenationOptions::VT_FUSED_ACTIVATION_FUNCTION, + static_cast<int8_t>(fused_activation_function), 0); + } + explicit ConcatenationOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<ConcatenationOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<ConcatenationOptions>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<ConcatenationOptions> +CreateConcatenationOptions(::flatbuffers::FlatBufferBuilder &_fbb, int32_t axis = 0, + onert_tflite::ActivationFunctionType fused_activation_function = + onert_tflite::ActivationFunctionType_NONE) +{ + ConcatenationOptionsBuilder builder_(_fbb); + builder_.add_axis(axis); + builder_.add_fused_activation_function(fused_activation_function); + return builder_.Finish(); +} + +struct AddOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef AddOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE + { + VT_FUSED_ACTIVATION_FUNCTION = 4, + VT_POT_SCALE_INT16 = 6 + }; + onert_tflite::ActivationFunctionType fused_activation_function() const + { + return static_cast<onert_tflite::ActivationFunctionType>( + GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0)); + } + bool pot_scale_int16() const { return GetField<uint8_t>(VT_POT_SCALE_INT16, 1) != 0; } + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && + VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION, 1) && + VerifyField<uint8_t>(verifier, VT_POT_SCALE_INT16, 1) && verifier.EndTable(); + } +}; + +struct AddOptionsBuilder +{ + typedef AddOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_fused_activation_function(onert_tflite::ActivationFunctionType fused_activation_function) + { + fbb_.AddElement<int8_t>(AddOptions::VT_FUSED_ACTIVATION_FUNCTION, + 
static_cast<int8_t>(fused_activation_function), 0); + } + void add_pot_scale_int16(bool pot_scale_int16) + { + fbb_.AddElement<uint8_t>(AddOptions::VT_POT_SCALE_INT16, static_cast<uint8_t>(pot_scale_int16), + 1); + } + explicit AddOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<AddOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<AddOptions>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<AddOptions> +CreateAddOptions(::flatbuffers::FlatBufferBuilder &_fbb, + onert_tflite::ActivationFunctionType fused_activation_function = + onert_tflite::ActivationFunctionType_NONE, + bool pot_scale_int16 = true) +{ + AddOptionsBuilder builder_(_fbb); + builder_.add_pot_scale_int16(pot_scale_int16); + builder_.add_fused_activation_function(fused_activation_function); + return builder_.Finish(); +} + +struct MulOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef MulOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE + { + VT_FUSED_ACTIVATION_FUNCTION = 4 + }; + onert_tflite::ActivationFunctionType fused_activation_function() const + { + return static_cast<onert_tflite::ActivationFunctionType>( + GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0)); + } + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && + VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION, 1) && verifier.EndTable(); + } +}; + +struct MulOptionsBuilder +{ + typedef MulOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_fused_activation_function(onert_tflite::ActivationFunctionType fused_activation_function) + { + fbb_.AddElement<int8_t>(MulOptions::VT_FUSED_ACTIVATION_FUNCTION, + static_cast<int8_t>(fused_activation_function), 0); + } + explicit MulOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + 
{ + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<MulOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<MulOptions>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<MulOptions> +CreateMulOptions(::flatbuffers::FlatBufferBuilder &_fbb, + onert_tflite::ActivationFunctionType fused_activation_function = + onert_tflite::ActivationFunctionType_NONE) +{ + MulOptionsBuilder builder_(_fbb); + builder_.add_fused_activation_function(fused_activation_function); + return builder_.Finish(); +} + +struct L2NormOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef L2NormOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE + { + VT_FUSED_ACTIVATION_FUNCTION = 4 + }; + onert_tflite::ActivationFunctionType fused_activation_function() const + { + return static_cast<onert_tflite::ActivationFunctionType>( + GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0)); + } + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && + VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION, 1) && verifier.EndTable(); + } +}; + +struct L2NormOptionsBuilder +{ + typedef L2NormOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_fused_activation_function(onert_tflite::ActivationFunctionType fused_activation_function) + { + fbb_.AddElement<int8_t>(L2NormOptions::VT_FUSED_ACTIVATION_FUNCTION, + static_cast<int8_t>(fused_activation_function), 0); + } + explicit L2NormOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<L2NormOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<L2NormOptions>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<L2NormOptions> +CreateL2NormOptions(::flatbuffers::FlatBufferBuilder &_fbb, + onert_tflite::ActivationFunctionType 
fused_activation_function = + onert_tflite::ActivationFunctionType_NONE) +{ + L2NormOptionsBuilder builder_(_fbb); + builder_.add_fused_activation_function(fused_activation_function); + return builder_.Finish(); +} + +struct LocalResponseNormalizationOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef LocalResponseNormalizationOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE + { + VT_RADIUS = 4, + VT_BIAS = 6, + VT_ALPHA = 8, + VT_BETA = 10 + }; + int32_t radius() const { return GetField<int32_t>(VT_RADIUS, 0); } + float bias() const { return GetField<float>(VT_BIAS, 0.0f); } + float alpha() const { return GetField<float>(VT_ALPHA, 0.0f); } + float beta() const { return GetField<float>(VT_BETA, 0.0f); } + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && VerifyField<int32_t>(verifier, VT_RADIUS, 4) && + VerifyField<float>(verifier, VT_BIAS, 4) && VerifyField<float>(verifier, VT_ALPHA, 4) && + VerifyField<float>(verifier, VT_BETA, 4) && verifier.EndTable(); + } +}; + +struct LocalResponseNormalizationOptionsBuilder +{ + typedef LocalResponseNormalizationOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_radius(int32_t radius) + { + fbb_.AddElement<int32_t>(LocalResponseNormalizationOptions::VT_RADIUS, radius, 0); + } + void add_bias(float bias) + { + fbb_.AddElement<float>(LocalResponseNormalizationOptions::VT_BIAS, bias, 0.0f); + } + void add_alpha(float alpha) + { + fbb_.AddElement<float>(LocalResponseNormalizationOptions::VT_ALPHA, alpha, 0.0f); + } + void add_beta(float beta) + { + fbb_.AddElement<float>(LocalResponseNormalizationOptions::VT_BETA, beta, 0.0f); + } + explicit LocalResponseNormalizationOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<LocalResponseNormalizationOptions> Finish() + { + const auto end = 
fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<LocalResponseNormalizationOptions>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<LocalResponseNormalizationOptions> +CreateLocalResponseNormalizationOptions(::flatbuffers::FlatBufferBuilder &_fbb, int32_t radius = 0, + float bias = 0.0f, float alpha = 0.0f, float beta = 0.0f) +{ + LocalResponseNormalizationOptionsBuilder builder_(_fbb); + builder_.add_beta(beta); + builder_.add_alpha(alpha); + builder_.add_bias(bias); + builder_.add_radius(radius); + return builder_.Finish(); +} + +struct LSTMOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef LSTMOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE + { + VT_FUSED_ACTIVATION_FUNCTION = 4, + VT_CELL_CLIP = 6, + VT_PROJ_CLIP = 8, + VT_KERNEL_TYPE = 10, + VT_ASYMMETRIC_QUANTIZE_INPUTS = 12 + }; + onert_tflite::ActivationFunctionType fused_activation_function() const + { + return static_cast<onert_tflite::ActivationFunctionType>( + GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0)); + } + float cell_clip() const { return GetField<float>(VT_CELL_CLIP, 0.0f); } + float proj_clip() const { return GetField<float>(VT_PROJ_CLIP, 0.0f); } + onert_tflite::LSTMKernelType kernel_type() const + { + return static_cast<onert_tflite::LSTMKernelType>(GetField<int8_t>(VT_KERNEL_TYPE, 0)); + } + bool asymmetric_quantize_inputs() const + { + return GetField<uint8_t>(VT_ASYMMETRIC_QUANTIZE_INPUTS, 0) != 0; + } + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && + VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION, 1) && + VerifyField<float>(verifier, VT_CELL_CLIP, 4) && + VerifyField<float>(verifier, VT_PROJ_CLIP, 4) && + VerifyField<int8_t>(verifier, VT_KERNEL_TYPE, 1) && + VerifyField<uint8_t>(verifier, VT_ASYMMETRIC_QUANTIZE_INPUTS, 1) && verifier.EndTable(); + } +}; + +struct LSTMOptionsBuilder +{ + typedef LSTMOptions Table; + 
::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_fused_activation_function(onert_tflite::ActivationFunctionType fused_activation_function) + { + fbb_.AddElement<int8_t>(LSTMOptions::VT_FUSED_ACTIVATION_FUNCTION, + static_cast<int8_t>(fused_activation_function), 0); + } + void add_cell_clip(float cell_clip) + { + fbb_.AddElement<float>(LSTMOptions::VT_CELL_CLIP, cell_clip, 0.0f); + } + void add_proj_clip(float proj_clip) + { + fbb_.AddElement<float>(LSTMOptions::VT_PROJ_CLIP, proj_clip, 0.0f); + } + void add_kernel_type(onert_tflite::LSTMKernelType kernel_type) + { + fbb_.AddElement<int8_t>(LSTMOptions::VT_KERNEL_TYPE, static_cast<int8_t>(kernel_type), 0); + } + void add_asymmetric_quantize_inputs(bool asymmetric_quantize_inputs) + { + fbb_.AddElement<uint8_t>(LSTMOptions::VT_ASYMMETRIC_QUANTIZE_INPUTS, + static_cast<uint8_t>(asymmetric_quantize_inputs), 0); + } + explicit LSTMOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<LSTMOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<LSTMOptions>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<LSTMOptions> +CreateLSTMOptions(::flatbuffers::FlatBufferBuilder &_fbb, + onert_tflite::ActivationFunctionType fused_activation_function = + onert_tflite::ActivationFunctionType_NONE, + float cell_clip = 0.0f, float proj_clip = 0.0f, + onert_tflite::LSTMKernelType kernel_type = onert_tflite::LSTMKernelType_FULL, + bool asymmetric_quantize_inputs = false) +{ + LSTMOptionsBuilder builder_(_fbb); + builder_.add_proj_clip(proj_clip); + builder_.add_cell_clip(cell_clip); + builder_.add_asymmetric_quantize_inputs(asymmetric_quantize_inputs); + builder_.add_kernel_type(kernel_type); + builder_.add_fused_activation_function(fused_activation_function); + return builder_.Finish(); +} + +struct UnidirectionalSequenceLSTMOptions FLATBUFFERS_FINAL_CLASS : private 
::flatbuffers::Table +{ + typedef UnidirectionalSequenceLSTMOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE + { + VT_FUSED_ACTIVATION_FUNCTION = 4, + VT_CELL_CLIP = 6, + VT_PROJ_CLIP = 8, + VT_TIME_MAJOR = 10, + VT_ASYMMETRIC_QUANTIZE_INPUTS = 12 + }; + onert_tflite::ActivationFunctionType fused_activation_function() const + { + return static_cast<onert_tflite::ActivationFunctionType>( + GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0)); + } + float cell_clip() const { return GetField<float>(VT_CELL_CLIP, 0.0f); } + float proj_clip() const { return GetField<float>(VT_PROJ_CLIP, 0.0f); } + bool time_major() const { return GetField<uint8_t>(VT_TIME_MAJOR, 0) != 0; } + bool asymmetric_quantize_inputs() const + { + return GetField<uint8_t>(VT_ASYMMETRIC_QUANTIZE_INPUTS, 0) != 0; + } + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && + VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION, 1) && + VerifyField<float>(verifier, VT_CELL_CLIP, 4) && + VerifyField<float>(verifier, VT_PROJ_CLIP, 4) && + VerifyField<uint8_t>(verifier, VT_TIME_MAJOR, 1) && + VerifyField<uint8_t>(verifier, VT_ASYMMETRIC_QUANTIZE_INPUTS, 1) && verifier.EndTable(); + } +}; + +struct UnidirectionalSequenceLSTMOptionsBuilder +{ + typedef UnidirectionalSequenceLSTMOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_fused_activation_function(onert_tflite::ActivationFunctionType fused_activation_function) + { + fbb_.AddElement<int8_t>(UnidirectionalSequenceLSTMOptions::VT_FUSED_ACTIVATION_FUNCTION, + static_cast<int8_t>(fused_activation_function), 0); + } + void add_cell_clip(float cell_clip) + { + fbb_.AddElement<float>(UnidirectionalSequenceLSTMOptions::VT_CELL_CLIP, cell_clip, 0.0f); + } + void add_proj_clip(float proj_clip) + { + fbb_.AddElement<float>(UnidirectionalSequenceLSTMOptions::VT_PROJ_CLIP, proj_clip, 0.0f); + } + void add_time_major(bool 
time_major) + { + fbb_.AddElement<uint8_t>(UnidirectionalSequenceLSTMOptions::VT_TIME_MAJOR, + static_cast<uint8_t>(time_major), 0); + } + void add_asymmetric_quantize_inputs(bool asymmetric_quantize_inputs) + { + fbb_.AddElement<uint8_t>(UnidirectionalSequenceLSTMOptions::VT_ASYMMETRIC_QUANTIZE_INPUTS, + static_cast<uint8_t>(asymmetric_quantize_inputs), 0); + } + explicit UnidirectionalSequenceLSTMOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<UnidirectionalSequenceLSTMOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<UnidirectionalSequenceLSTMOptions>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<UnidirectionalSequenceLSTMOptions> +CreateUnidirectionalSequenceLSTMOptions( + ::flatbuffers::FlatBufferBuilder &_fbb, + onert_tflite::ActivationFunctionType fused_activation_function = + onert_tflite::ActivationFunctionType_NONE, + float cell_clip = 0.0f, float proj_clip = 0.0f, bool time_major = false, + bool asymmetric_quantize_inputs = false) +{ + UnidirectionalSequenceLSTMOptionsBuilder builder_(_fbb); + builder_.add_proj_clip(proj_clip); + builder_.add_cell_clip(cell_clip); + builder_.add_asymmetric_quantize_inputs(asymmetric_quantize_inputs); + builder_.add_time_major(time_major); + builder_.add_fused_activation_function(fused_activation_function); + return builder_.Finish(); +} + +struct BidirectionalSequenceLSTMOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef BidirectionalSequenceLSTMOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE + { + VT_FUSED_ACTIVATION_FUNCTION = 4, + VT_CELL_CLIP = 6, + VT_PROJ_CLIP = 8, + VT_MERGE_OUTPUTS = 10, + VT_TIME_MAJOR = 12, + VT_ASYMMETRIC_QUANTIZE_INPUTS = 14 + }; + onert_tflite::ActivationFunctionType fused_activation_function() const + { + return static_cast<onert_tflite::ActivationFunctionType>( + 
GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0)); + } + float cell_clip() const { return GetField<float>(VT_CELL_CLIP, 0.0f); } + float proj_clip() const { return GetField<float>(VT_PROJ_CLIP, 0.0f); } + bool merge_outputs() const { return GetField<uint8_t>(VT_MERGE_OUTPUTS, 0) != 0; } + bool time_major() const { return GetField<uint8_t>(VT_TIME_MAJOR, 1) != 0; } + bool asymmetric_quantize_inputs() const + { + return GetField<uint8_t>(VT_ASYMMETRIC_QUANTIZE_INPUTS, 0) != 0; + } + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && + VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION, 1) && + VerifyField<float>(verifier, VT_CELL_CLIP, 4) && + VerifyField<float>(verifier, VT_PROJ_CLIP, 4) && + VerifyField<uint8_t>(verifier, VT_MERGE_OUTPUTS, 1) && + VerifyField<uint8_t>(verifier, VT_TIME_MAJOR, 1) && + VerifyField<uint8_t>(verifier, VT_ASYMMETRIC_QUANTIZE_INPUTS, 1) && verifier.EndTable(); + } +}; + +struct BidirectionalSequenceLSTMOptionsBuilder +{ + typedef BidirectionalSequenceLSTMOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_fused_activation_function(onert_tflite::ActivationFunctionType fused_activation_function) + { + fbb_.AddElement<int8_t>(BidirectionalSequenceLSTMOptions::VT_FUSED_ACTIVATION_FUNCTION, + static_cast<int8_t>(fused_activation_function), 0); + } + void add_cell_clip(float cell_clip) + { + fbb_.AddElement<float>(BidirectionalSequenceLSTMOptions::VT_CELL_CLIP, cell_clip, 0.0f); + } + void add_proj_clip(float proj_clip) + { + fbb_.AddElement<float>(BidirectionalSequenceLSTMOptions::VT_PROJ_CLIP, proj_clip, 0.0f); + } + void add_merge_outputs(bool merge_outputs) + { + fbb_.AddElement<uint8_t>(BidirectionalSequenceLSTMOptions::VT_MERGE_OUTPUTS, + static_cast<uint8_t>(merge_outputs), 0); + } + void add_time_major(bool time_major) + { + fbb_.AddElement<uint8_t>(BidirectionalSequenceLSTMOptions::VT_TIME_MAJOR, + static_cast<uint8_t>(time_major), 
1); + } + void add_asymmetric_quantize_inputs(bool asymmetric_quantize_inputs) + { + fbb_.AddElement<uint8_t>(BidirectionalSequenceLSTMOptions::VT_ASYMMETRIC_QUANTIZE_INPUTS, + static_cast<uint8_t>(asymmetric_quantize_inputs), 0); + } + explicit BidirectionalSequenceLSTMOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<BidirectionalSequenceLSTMOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<BidirectionalSequenceLSTMOptions>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<BidirectionalSequenceLSTMOptions> +CreateBidirectionalSequenceLSTMOptions( + ::flatbuffers::FlatBufferBuilder &_fbb, + onert_tflite::ActivationFunctionType fused_activation_function = + onert_tflite::ActivationFunctionType_NONE, + float cell_clip = 0.0f, float proj_clip = 0.0f, bool merge_outputs = false, + bool time_major = true, bool asymmetric_quantize_inputs = false) +{ + BidirectionalSequenceLSTMOptionsBuilder builder_(_fbb); + builder_.add_proj_clip(proj_clip); + builder_.add_cell_clip(cell_clip); + builder_.add_asymmetric_quantize_inputs(asymmetric_quantize_inputs); + builder_.add_time_major(time_major); + builder_.add_merge_outputs(merge_outputs); + builder_.add_fused_activation_function(fused_activation_function); + return builder_.Finish(); +} + +struct ResizeBilinearOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef ResizeBilinearOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE + { + VT_ALIGN_CORNERS = 8, + VT_HALF_PIXEL_CENTERS = 10 + }; + bool align_corners() const { return GetField<uint8_t>(VT_ALIGN_CORNERS, 0) != 0; } + bool half_pixel_centers() const { return GetField<uint8_t>(VT_HALF_PIXEL_CENTERS, 0) != 0; } + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && VerifyField<uint8_t>(verifier, VT_ALIGN_CORNERS, 1) && + 
VerifyField<uint8_t>(verifier, VT_HALF_PIXEL_CENTERS, 1) && verifier.EndTable(); + } +}; + +struct ResizeBilinearOptionsBuilder +{ + typedef ResizeBilinearOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_align_corners(bool align_corners) + { + fbb_.AddElement<uint8_t>(ResizeBilinearOptions::VT_ALIGN_CORNERS, + static_cast<uint8_t>(align_corners), 0); + } + void add_half_pixel_centers(bool half_pixel_centers) + { + fbb_.AddElement<uint8_t>(ResizeBilinearOptions::VT_HALF_PIXEL_CENTERS, + static_cast<uint8_t>(half_pixel_centers), 0); + } + explicit ResizeBilinearOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<ResizeBilinearOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<ResizeBilinearOptions>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<ResizeBilinearOptions> +CreateResizeBilinearOptions(::flatbuffers::FlatBufferBuilder &_fbb, bool align_corners = false, + bool half_pixel_centers = false) +{ + ResizeBilinearOptionsBuilder builder_(_fbb); + builder_.add_half_pixel_centers(half_pixel_centers); + builder_.add_align_corners(align_corners); + return builder_.Finish(); +} + +struct ResizeNearestNeighborOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef ResizeNearestNeighborOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE + { + VT_ALIGN_CORNERS = 4, + VT_HALF_PIXEL_CENTERS = 6 + }; + bool align_corners() const { return GetField<uint8_t>(VT_ALIGN_CORNERS, 0) != 0; } + bool half_pixel_centers() const { return GetField<uint8_t>(VT_HALF_PIXEL_CENTERS, 0) != 0; } + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && VerifyField<uint8_t>(verifier, VT_ALIGN_CORNERS, 1) && + VerifyField<uint8_t>(verifier, VT_HALF_PIXEL_CENTERS, 1) && verifier.EndTable(); + } +}; + +struct 
ResizeNearestNeighborOptionsBuilder +{ + typedef ResizeNearestNeighborOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_align_corners(bool align_corners) + { + fbb_.AddElement<uint8_t>(ResizeNearestNeighborOptions::VT_ALIGN_CORNERS, + static_cast<uint8_t>(align_corners), 0); + } + void add_half_pixel_centers(bool half_pixel_centers) + { + fbb_.AddElement<uint8_t>(ResizeNearestNeighborOptions::VT_HALF_PIXEL_CENTERS, + static_cast<uint8_t>(half_pixel_centers), 0); + } + explicit ResizeNearestNeighborOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<ResizeNearestNeighborOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<ResizeNearestNeighborOptions>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<ResizeNearestNeighborOptions> +CreateResizeNearestNeighborOptions(::flatbuffers::FlatBufferBuilder &_fbb, + bool align_corners = false, bool half_pixel_centers = false) +{ + ResizeNearestNeighborOptionsBuilder builder_(_fbb); + builder_.add_half_pixel_centers(half_pixel_centers); + builder_.add_align_corners(align_corners); + return builder_.Finish(); +} + +struct CallOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef CallOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE + { + VT_SUBGRAPH = 4 + }; + uint32_t subgraph() const { return GetField<uint32_t>(VT_SUBGRAPH, 0); } + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && VerifyField<uint32_t>(verifier, VT_SUBGRAPH, 4) && + verifier.EndTable(); + } +}; + +struct CallOptionsBuilder +{ + typedef CallOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_subgraph(uint32_t subgraph) + { + fbb_.AddElement<uint32_t>(CallOptions::VT_SUBGRAPH, subgraph, 0); + } + explicit 
CallOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<CallOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<CallOptions>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<CallOptions> CreateCallOptions(::flatbuffers::FlatBufferBuilder &_fbb, + uint32_t subgraph = 0) +{ + CallOptionsBuilder builder_(_fbb); + builder_.add_subgraph(subgraph); + return builder_.Finish(); +} + +struct PadOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef PadOptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && verifier.EndTable(); + } +}; + +struct PadOptionsBuilder +{ + typedef PadOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit PadOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<PadOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<PadOptions>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<PadOptions> CreatePadOptions(::flatbuffers::FlatBufferBuilder &_fbb) +{ + PadOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +struct PadV2Options FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef PadV2OptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && verifier.EndTable(); + } +}; + +struct PadV2OptionsBuilder +{ + typedef PadV2Options Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit PadV2OptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<PadV2Options> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<PadV2Options>(end); + 
return o; + } +}; + +inline ::flatbuffers::Offset<PadV2Options> +CreatePadV2Options(::flatbuffers::FlatBufferBuilder &_fbb) +{ + PadV2OptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +struct ReshapeOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef ReshapeOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE + { + VT_NEW_SHAPE = 4 + }; + const ::flatbuffers::Vector<int32_t> *new_shape() const + { + return GetPointer<const ::flatbuffers::Vector<int32_t> *>(VT_NEW_SHAPE); + } + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_NEW_SHAPE) && + verifier.VerifyVector(new_shape()) && verifier.EndTable(); + } +}; + +struct ReshapeOptionsBuilder +{ + typedef ReshapeOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_new_shape(::flatbuffers::Offset<::flatbuffers::Vector<int32_t>> new_shape) + { + fbb_.AddOffset(ReshapeOptions::VT_NEW_SHAPE, new_shape); + } + explicit ReshapeOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<ReshapeOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<ReshapeOptions>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<ReshapeOptions> +CreateReshapeOptions(::flatbuffers::FlatBufferBuilder &_fbb, + ::flatbuffers::Offset<::flatbuffers::Vector<int32_t>> new_shape = 0) +{ + ReshapeOptionsBuilder builder_(_fbb); + builder_.add_new_shape(new_shape); + return builder_.Finish(); +} + +inline ::flatbuffers::Offset<ReshapeOptions> +CreateReshapeOptionsDirect(::flatbuffers::FlatBufferBuilder &_fbb, + const std::vector<int32_t> *new_shape = nullptr) +{ + auto new_shape__ = new_shape ? 
_fbb.CreateVector<int32_t>(*new_shape) : 0; + return onert_tflite::CreateReshapeOptions(_fbb, new_shape__); +} + +struct SpaceToBatchNDOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef SpaceToBatchNDOptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && verifier.EndTable(); + } +}; + +struct SpaceToBatchNDOptionsBuilder +{ + typedef SpaceToBatchNDOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit SpaceToBatchNDOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<SpaceToBatchNDOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<SpaceToBatchNDOptions>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<SpaceToBatchNDOptions> +CreateSpaceToBatchNDOptions(::flatbuffers::FlatBufferBuilder &_fbb) +{ + SpaceToBatchNDOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +struct BatchToSpaceNDOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef BatchToSpaceNDOptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && verifier.EndTable(); + } +}; + +struct BatchToSpaceNDOptionsBuilder +{ + typedef BatchToSpaceNDOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit BatchToSpaceNDOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<BatchToSpaceNDOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<BatchToSpaceNDOptions>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<BatchToSpaceNDOptions> +CreateBatchToSpaceNDOptions(::flatbuffers::FlatBufferBuilder &_fbb) +{ + BatchToSpaceNDOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + 
+struct SkipGramOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef SkipGramOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE + { + VT_NGRAM_SIZE = 4, + VT_MAX_SKIP_SIZE = 6, + VT_INCLUDE_ALL_NGRAMS = 8 + }; + int32_t ngram_size() const { return GetField<int32_t>(VT_NGRAM_SIZE, 0); } + int32_t max_skip_size() const { return GetField<int32_t>(VT_MAX_SKIP_SIZE, 0); } + bool include_all_ngrams() const { return GetField<uint8_t>(VT_INCLUDE_ALL_NGRAMS, 0) != 0; } + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && VerifyField<int32_t>(verifier, VT_NGRAM_SIZE, 4) && + VerifyField<int32_t>(verifier, VT_MAX_SKIP_SIZE, 4) && + VerifyField<uint8_t>(verifier, VT_INCLUDE_ALL_NGRAMS, 1) && verifier.EndTable(); + } +}; + +struct SkipGramOptionsBuilder +{ + typedef SkipGramOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_ngram_size(int32_t ngram_size) + { + fbb_.AddElement<int32_t>(SkipGramOptions::VT_NGRAM_SIZE, ngram_size, 0); + } + void add_max_skip_size(int32_t max_skip_size) + { + fbb_.AddElement<int32_t>(SkipGramOptions::VT_MAX_SKIP_SIZE, max_skip_size, 0); + } + void add_include_all_ngrams(bool include_all_ngrams) + { + fbb_.AddElement<uint8_t>(SkipGramOptions::VT_INCLUDE_ALL_NGRAMS, + static_cast<uint8_t>(include_all_ngrams), 0); + } + explicit SkipGramOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<SkipGramOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<SkipGramOptions>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<SkipGramOptions> +CreateSkipGramOptions(::flatbuffers::FlatBufferBuilder &_fbb, int32_t ngram_size = 0, + int32_t max_skip_size = 0, bool include_all_ngrams = false) +{ + SkipGramOptionsBuilder builder_(_fbb); + builder_.add_max_skip_size(max_skip_size); + 
builder_.add_ngram_size(ngram_size); + builder_.add_include_all_ngrams(include_all_ngrams); + return builder_.Finish(); +} + +struct SpaceToDepthOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef SpaceToDepthOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE + { + VT_BLOCK_SIZE = 4 + }; + int32_t block_size() const { return GetField<int32_t>(VT_BLOCK_SIZE, 0); } + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && VerifyField<int32_t>(verifier, VT_BLOCK_SIZE, 4) && + verifier.EndTable(); + } +}; + +struct SpaceToDepthOptionsBuilder +{ + typedef SpaceToDepthOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_block_size(int32_t block_size) + { + fbb_.AddElement<int32_t>(SpaceToDepthOptions::VT_BLOCK_SIZE, block_size, 0); + } + explicit SpaceToDepthOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<SpaceToDepthOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<SpaceToDepthOptions>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<SpaceToDepthOptions> +CreateSpaceToDepthOptions(::flatbuffers::FlatBufferBuilder &_fbb, int32_t block_size = 0) +{ + SpaceToDepthOptionsBuilder builder_(_fbb); + builder_.add_block_size(block_size); + return builder_.Finish(); +} + +struct DepthToSpaceOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef DepthToSpaceOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE + { + VT_BLOCK_SIZE = 4 + }; + int32_t block_size() const { return GetField<int32_t>(VT_BLOCK_SIZE, 0); } + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && VerifyField<int32_t>(verifier, VT_BLOCK_SIZE, 4) && + verifier.EndTable(); + } +}; + +struct DepthToSpaceOptionsBuilder +{ + 
typedef DepthToSpaceOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_block_size(int32_t block_size) + { + fbb_.AddElement<int32_t>(DepthToSpaceOptions::VT_BLOCK_SIZE, block_size, 0); + } + explicit DepthToSpaceOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<DepthToSpaceOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<DepthToSpaceOptions>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<DepthToSpaceOptions> +CreateDepthToSpaceOptions(::flatbuffers::FlatBufferBuilder &_fbb, int32_t block_size = 0) +{ + DepthToSpaceOptionsBuilder builder_(_fbb); + builder_.add_block_size(block_size); + return builder_.Finish(); +} + +struct SubOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef SubOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE + { + VT_FUSED_ACTIVATION_FUNCTION = 4, + VT_POT_SCALE_INT16 = 6 + }; + onert_tflite::ActivationFunctionType fused_activation_function() const + { + return static_cast<onert_tflite::ActivationFunctionType>( + GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0)); + } + bool pot_scale_int16() const { return GetField<uint8_t>(VT_POT_SCALE_INT16, 1) != 0; } + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && + VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION, 1) && + VerifyField<uint8_t>(verifier, VT_POT_SCALE_INT16, 1) && verifier.EndTable(); + } +}; + +struct SubOptionsBuilder +{ + typedef SubOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_fused_activation_function(onert_tflite::ActivationFunctionType fused_activation_function) + { + fbb_.AddElement<int8_t>(SubOptions::VT_FUSED_ACTIVATION_FUNCTION, + static_cast<int8_t>(fused_activation_function), 0); + } + void 
add_pot_scale_int16(bool pot_scale_int16) + { + fbb_.AddElement<uint8_t>(SubOptions::VT_POT_SCALE_INT16, static_cast<uint8_t>(pot_scale_int16), + 1); + } + explicit SubOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<SubOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<SubOptions>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<SubOptions> +CreateSubOptions(::flatbuffers::FlatBufferBuilder &_fbb, + onert_tflite::ActivationFunctionType fused_activation_function = + onert_tflite::ActivationFunctionType_NONE, + bool pot_scale_int16 = true) +{ + SubOptionsBuilder builder_(_fbb); + builder_.add_pot_scale_int16(pot_scale_int16); + builder_.add_fused_activation_function(fused_activation_function); + return builder_.Finish(); +} + +struct DivOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef DivOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE + { + VT_FUSED_ACTIVATION_FUNCTION = 4 + }; + onert_tflite::ActivationFunctionType fused_activation_function() const + { + return static_cast<onert_tflite::ActivationFunctionType>( + GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0)); + } + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && + VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION, 1) && verifier.EndTable(); + } +}; + +struct DivOptionsBuilder +{ + typedef DivOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_fused_activation_function(onert_tflite::ActivationFunctionType fused_activation_function) + { + fbb_.AddElement<int8_t>(DivOptions::VT_FUSED_ACTIVATION_FUNCTION, + static_cast<int8_t>(fused_activation_function), 0); + } + explicit DivOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + 
::flatbuffers::Offset<DivOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<DivOptions>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<DivOptions> +CreateDivOptions(::flatbuffers::FlatBufferBuilder &_fbb, + onert_tflite::ActivationFunctionType fused_activation_function = + onert_tflite::ActivationFunctionType_NONE) +{ + DivOptionsBuilder builder_(_fbb); + builder_.add_fused_activation_function(fused_activation_function); + return builder_.Finish(); +} + +struct TopKV2Options FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef TopKV2OptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && verifier.EndTable(); + } +}; + +struct TopKV2OptionsBuilder +{ + typedef TopKV2Options Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit TopKV2OptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<TopKV2Options> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<TopKV2Options>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<TopKV2Options> +CreateTopKV2Options(::flatbuffers::FlatBufferBuilder &_fbb) +{ + TopKV2OptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +struct EmbeddingLookupSparseOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef EmbeddingLookupSparseOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE + { + VT_COMBINER = 4 + }; + onert_tflite::CombinerType combiner() const + { + return static_cast<onert_tflite::CombinerType>(GetField<int8_t>(VT_COMBINER, 0)); + } + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && VerifyField<int8_t>(verifier, VT_COMBINER, 1) && + verifier.EndTable(); + } +}; + +struct EmbeddingLookupSparseOptionsBuilder +{ + typedef 
EmbeddingLookupSparseOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_combiner(onert_tflite::CombinerType combiner) + { + fbb_.AddElement<int8_t>(EmbeddingLookupSparseOptions::VT_COMBINER, + static_cast<int8_t>(combiner), 0); + } + explicit EmbeddingLookupSparseOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<EmbeddingLookupSparseOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<EmbeddingLookupSparseOptions>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<EmbeddingLookupSparseOptions> CreateEmbeddingLookupSparseOptions( + ::flatbuffers::FlatBufferBuilder &_fbb, + onert_tflite::CombinerType combiner = onert_tflite::CombinerType_SUM) +{ + EmbeddingLookupSparseOptionsBuilder builder_(_fbb); + builder_.add_combiner(combiner); + return builder_.Finish(); +} + +struct GatherOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef GatherOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE + { + VT_AXIS = 4, + VT_BATCH_DIMS = 6 + }; + int32_t axis() const { return GetField<int32_t>(VT_AXIS, 0); } + int32_t batch_dims() const { return GetField<int32_t>(VT_BATCH_DIMS, 0); } + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && VerifyField<int32_t>(verifier, VT_AXIS, 4) && + VerifyField<int32_t>(verifier, VT_BATCH_DIMS, 4) && verifier.EndTable(); + } +}; + +struct GatherOptionsBuilder +{ + typedef GatherOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_axis(int32_t axis) { fbb_.AddElement<int32_t>(GatherOptions::VT_AXIS, axis, 0); } + void add_batch_dims(int32_t batch_dims) + { + fbb_.AddElement<int32_t>(GatherOptions::VT_BATCH_DIMS, batch_dims, 0); + } + explicit GatherOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) 
+ { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<GatherOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<GatherOptions>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<GatherOptions> +CreateGatherOptions(::flatbuffers::FlatBufferBuilder &_fbb, int32_t axis = 0, + int32_t batch_dims = 0) +{ + GatherOptionsBuilder builder_(_fbb); + builder_.add_batch_dims(batch_dims); + builder_.add_axis(axis); + return builder_.Finish(); +} + +struct TransposeOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef TransposeOptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && verifier.EndTable(); + } +}; + +struct TransposeOptionsBuilder +{ + typedef TransposeOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit TransposeOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<TransposeOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<TransposeOptions>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<TransposeOptions> +CreateTransposeOptions(::flatbuffers::FlatBufferBuilder &_fbb) +{ + TransposeOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +struct ExpOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef ExpOptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && verifier.EndTable(); + } +}; + +struct ExpOptionsBuilder +{ + typedef ExpOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit ExpOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<ExpOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = 
::flatbuffers::Offset<ExpOptions>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<ExpOptions> CreateExpOptions(::flatbuffers::FlatBufferBuilder &_fbb) +{ + ExpOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +struct CosOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef CosOptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && verifier.EndTable(); + } +}; + +struct CosOptionsBuilder +{ + typedef CosOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit CosOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<CosOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<CosOptions>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<CosOptions> CreateCosOptions(::flatbuffers::FlatBufferBuilder &_fbb) +{ + CosOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +struct ReducerOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef ReducerOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE + { + VT_KEEP_DIMS = 4 + }; + bool keep_dims() const { return GetField<uint8_t>(VT_KEEP_DIMS, 0) != 0; } + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && VerifyField<uint8_t>(verifier, VT_KEEP_DIMS, 1) && + verifier.EndTable(); + } +}; + +struct ReducerOptionsBuilder +{ + typedef ReducerOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_keep_dims(bool keep_dims) + { + fbb_.AddElement<uint8_t>(ReducerOptions::VT_KEEP_DIMS, static_cast<uint8_t>(keep_dims), 0); + } + explicit ReducerOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<ReducerOptions> Finish() + { + 
const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<ReducerOptions>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<ReducerOptions> +CreateReducerOptions(::flatbuffers::FlatBufferBuilder &_fbb, bool keep_dims = false) +{ + ReducerOptionsBuilder builder_(_fbb); + builder_.add_keep_dims(keep_dims); + return builder_.Finish(); +} + +struct SqueezeOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef SqueezeOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE + { + VT_SQUEEZE_DIMS = 4 + }; + const ::flatbuffers::Vector<int32_t> *squeeze_dims() const + { + return GetPointer<const ::flatbuffers::Vector<int32_t> *>(VT_SQUEEZE_DIMS); + } + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_SQUEEZE_DIMS) && + verifier.VerifyVector(squeeze_dims()) && verifier.EndTable(); + } +}; + +struct SqueezeOptionsBuilder +{ + typedef SqueezeOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_squeeze_dims(::flatbuffers::Offset<::flatbuffers::Vector<int32_t>> squeeze_dims) + { + fbb_.AddOffset(SqueezeOptions::VT_SQUEEZE_DIMS, squeeze_dims); + } + explicit SqueezeOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<SqueezeOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<SqueezeOptions>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<SqueezeOptions> +CreateSqueezeOptions(::flatbuffers::FlatBufferBuilder &_fbb, + ::flatbuffers::Offset<::flatbuffers::Vector<int32_t>> squeeze_dims = 0) +{ + SqueezeOptionsBuilder builder_(_fbb); + builder_.add_squeeze_dims(squeeze_dims); + return builder_.Finish(); +} + +inline ::flatbuffers::Offset<SqueezeOptions> +CreateSqueezeOptionsDirect(::flatbuffers::FlatBufferBuilder &_fbb, + const std::vector<int32_t> 
*squeeze_dims = nullptr) +{ + auto squeeze_dims__ = squeeze_dims ? _fbb.CreateVector<int32_t>(*squeeze_dims) : 0; + return onert_tflite::CreateSqueezeOptions(_fbb, squeeze_dims__); +} + +struct SplitOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef SplitOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE + { + VT_NUM_SPLITS = 4 + }; + int32_t num_splits() const { return GetField<int32_t>(VT_NUM_SPLITS, 0); } + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && VerifyField<int32_t>(verifier, VT_NUM_SPLITS, 4) && + verifier.EndTable(); + } +}; + +struct SplitOptionsBuilder +{ + typedef SplitOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_num_splits(int32_t num_splits) + { + fbb_.AddElement<int32_t>(SplitOptions::VT_NUM_SPLITS, num_splits, 0); + } + explicit SplitOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<SplitOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<SplitOptions>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<SplitOptions> +CreateSplitOptions(::flatbuffers::FlatBufferBuilder &_fbb, int32_t num_splits = 0) +{ + SplitOptionsBuilder builder_(_fbb); + builder_.add_num_splits(num_splits); + return builder_.Finish(); +} + +struct SplitVOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef SplitVOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE + { + VT_NUM_SPLITS = 4 + }; + int32_t num_splits() const { return GetField<int32_t>(VT_NUM_SPLITS, 0); } + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && VerifyField<int32_t>(verifier, VT_NUM_SPLITS, 4) && + verifier.EndTable(); + } +}; + +struct SplitVOptionsBuilder +{ + typedef SplitVOptions Table; + 
::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_num_splits(int32_t num_splits) + { + fbb_.AddElement<int32_t>(SplitVOptions::VT_NUM_SPLITS, num_splits, 0); + } + explicit SplitVOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<SplitVOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<SplitVOptions>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<SplitVOptions> +CreateSplitVOptions(::flatbuffers::FlatBufferBuilder &_fbb, int32_t num_splits = 0) +{ + SplitVOptionsBuilder builder_(_fbb); + builder_.add_num_splits(num_splits); + return builder_.Finish(); +} + +struct StridedSliceOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef StridedSliceOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE + { + VT_BEGIN_MASK = 4, + VT_END_MASK = 6, + VT_ELLIPSIS_MASK = 8, + VT_NEW_AXIS_MASK = 10, + VT_SHRINK_AXIS_MASK = 12 + }; + int32_t begin_mask() const { return GetField<int32_t>(VT_BEGIN_MASK, 0); } + int32_t end_mask() const { return GetField<int32_t>(VT_END_MASK, 0); } + int32_t ellipsis_mask() const { return GetField<int32_t>(VT_ELLIPSIS_MASK, 0); } + int32_t new_axis_mask() const { return GetField<int32_t>(VT_NEW_AXIS_MASK, 0); } + int32_t shrink_axis_mask() const { return GetField<int32_t>(VT_SHRINK_AXIS_MASK, 0); } + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && VerifyField<int32_t>(verifier, VT_BEGIN_MASK, 4) && + VerifyField<int32_t>(verifier, VT_END_MASK, 4) && + VerifyField<int32_t>(verifier, VT_ELLIPSIS_MASK, 4) && + VerifyField<int32_t>(verifier, VT_NEW_AXIS_MASK, 4) && + VerifyField<int32_t>(verifier, VT_SHRINK_AXIS_MASK, 4) && verifier.EndTable(); + } +}; + +struct StridedSliceOptionsBuilder +{ + typedef StridedSliceOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + 
::flatbuffers::uoffset_t start_; + void add_begin_mask(int32_t begin_mask) + { + fbb_.AddElement<int32_t>(StridedSliceOptions::VT_BEGIN_MASK, begin_mask, 0); + } + void add_end_mask(int32_t end_mask) + { + fbb_.AddElement<int32_t>(StridedSliceOptions::VT_END_MASK, end_mask, 0); + } + void add_ellipsis_mask(int32_t ellipsis_mask) + { + fbb_.AddElement<int32_t>(StridedSliceOptions::VT_ELLIPSIS_MASK, ellipsis_mask, 0); + } + void add_new_axis_mask(int32_t new_axis_mask) + { + fbb_.AddElement<int32_t>(StridedSliceOptions::VT_NEW_AXIS_MASK, new_axis_mask, 0); + } + void add_shrink_axis_mask(int32_t shrink_axis_mask) + { + fbb_.AddElement<int32_t>(StridedSliceOptions::VT_SHRINK_AXIS_MASK, shrink_axis_mask, 0); + } + explicit StridedSliceOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<StridedSliceOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<StridedSliceOptions>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<StridedSliceOptions> +CreateStridedSliceOptions(::flatbuffers::FlatBufferBuilder &_fbb, int32_t begin_mask = 0, + int32_t end_mask = 0, int32_t ellipsis_mask = 0, + int32_t new_axis_mask = 0, int32_t shrink_axis_mask = 0) +{ + StridedSliceOptionsBuilder builder_(_fbb); + builder_.add_shrink_axis_mask(shrink_axis_mask); + builder_.add_new_axis_mask(new_axis_mask); + builder_.add_ellipsis_mask(ellipsis_mask); + builder_.add_end_mask(end_mask); + builder_.add_begin_mask(begin_mask); + return builder_.Finish(); +} + +struct LogSoftmaxOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef LogSoftmaxOptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && verifier.EndTable(); + } +}; + +struct LogSoftmaxOptionsBuilder +{ + typedef LogSoftmaxOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit 
LogSoftmaxOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<LogSoftmaxOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<LogSoftmaxOptions>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<LogSoftmaxOptions> +CreateLogSoftmaxOptions(::flatbuffers::FlatBufferBuilder &_fbb) +{ + LogSoftmaxOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +struct CastOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef CastOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE + { + VT_IN_DATA_TYPE = 4, + VT_OUT_DATA_TYPE = 6 + }; + onert_tflite::TensorType in_data_type() const + { + return static_cast<onert_tflite::TensorType>(GetField<int8_t>(VT_IN_DATA_TYPE, 0)); + } + onert_tflite::TensorType out_data_type() const + { + return static_cast<onert_tflite::TensorType>(GetField<int8_t>(VT_OUT_DATA_TYPE, 0)); + } + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && VerifyField<int8_t>(verifier, VT_IN_DATA_TYPE, 1) && + VerifyField<int8_t>(verifier, VT_OUT_DATA_TYPE, 1) && verifier.EndTable(); + } +}; + +struct CastOptionsBuilder +{ + typedef CastOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_in_data_type(onert_tflite::TensorType in_data_type) + { + fbb_.AddElement<int8_t>(CastOptions::VT_IN_DATA_TYPE, static_cast<int8_t>(in_data_type), 0); + } + void add_out_data_type(onert_tflite::TensorType out_data_type) + { + fbb_.AddElement<int8_t>(CastOptions::VT_OUT_DATA_TYPE, static_cast<int8_t>(out_data_type), 0); + } + explicit CastOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<CastOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<CastOptions>(end); + return o; + 
} +}; + +inline ::flatbuffers::Offset<CastOptions> +CreateCastOptions(::flatbuffers::FlatBufferBuilder &_fbb, + onert_tflite::TensorType in_data_type = onert_tflite::TensorType_FLOAT32, + onert_tflite::TensorType out_data_type = onert_tflite::TensorType_FLOAT32) +{ + CastOptionsBuilder builder_(_fbb); + builder_.add_out_data_type(out_data_type); + builder_.add_in_data_type(in_data_type); + return builder_.Finish(); +} + +struct DequantizeOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef DequantizeOptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && verifier.EndTable(); + } +}; + +struct DequantizeOptionsBuilder +{ + typedef DequantizeOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit DequantizeOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<DequantizeOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<DequantizeOptions>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<DequantizeOptions> +CreateDequantizeOptions(::flatbuffers::FlatBufferBuilder &_fbb) +{ + DequantizeOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +struct MaximumMinimumOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef MaximumMinimumOptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && verifier.EndTable(); + } +}; + +struct MaximumMinimumOptionsBuilder +{ + typedef MaximumMinimumOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit MaximumMinimumOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<MaximumMinimumOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = 
::flatbuffers::Offset<MaximumMinimumOptions>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<MaximumMinimumOptions> +CreateMaximumMinimumOptions(::flatbuffers::FlatBufferBuilder &_fbb) +{ + MaximumMinimumOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +struct TileOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef TileOptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && verifier.EndTable(); + } +}; + +struct TileOptionsBuilder +{ + typedef TileOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit TileOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<TileOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<TileOptions>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<TileOptions> CreateTileOptions(::flatbuffers::FlatBufferBuilder &_fbb) +{ + TileOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +struct ArgMaxOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef ArgMaxOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE + { + VT_OUTPUT_TYPE = 4 + }; + onert_tflite::TensorType output_type() const + { + return static_cast<onert_tflite::TensorType>(GetField<int8_t>(VT_OUTPUT_TYPE, 0)); + } + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && VerifyField<int8_t>(verifier, VT_OUTPUT_TYPE, 1) && + verifier.EndTable(); + } +}; + +struct ArgMaxOptionsBuilder +{ + typedef ArgMaxOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_output_type(onert_tflite::TensorType output_type) + { + fbb_.AddElement<int8_t>(ArgMaxOptions::VT_OUTPUT_TYPE, static_cast<int8_t>(output_type), 0); + } + explicit 
ArgMaxOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<ArgMaxOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<ArgMaxOptions>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<ArgMaxOptions> +CreateArgMaxOptions(::flatbuffers::FlatBufferBuilder &_fbb, + onert_tflite::TensorType output_type = onert_tflite::TensorType_FLOAT32) +{ + ArgMaxOptionsBuilder builder_(_fbb); + builder_.add_output_type(output_type); + return builder_.Finish(); +} + +struct ArgMinOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef ArgMinOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE + { + VT_OUTPUT_TYPE = 4 + }; + onert_tflite::TensorType output_type() const + { + return static_cast<onert_tflite::TensorType>(GetField<int8_t>(VT_OUTPUT_TYPE, 0)); + } + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && VerifyField<int8_t>(verifier, VT_OUTPUT_TYPE, 1) && + verifier.EndTable(); + } +}; + +struct ArgMinOptionsBuilder +{ + typedef ArgMinOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_output_type(onert_tflite::TensorType output_type) + { + fbb_.AddElement<int8_t>(ArgMinOptions::VT_OUTPUT_TYPE, static_cast<int8_t>(output_type), 0); + } + explicit ArgMinOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<ArgMinOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<ArgMinOptions>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<ArgMinOptions> +CreateArgMinOptions(::flatbuffers::FlatBufferBuilder &_fbb, + onert_tflite::TensorType output_type = onert_tflite::TensorType_FLOAT32) +{ + ArgMinOptionsBuilder builder_(_fbb); + builder_.add_output_type(output_type); + return 
builder_.Finish(); +} + +struct GreaterOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef GreaterOptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && verifier.EndTable(); + } +}; + +struct GreaterOptionsBuilder +{ + typedef GreaterOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit GreaterOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<GreaterOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<GreaterOptions>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<GreaterOptions> +CreateGreaterOptions(::flatbuffers::FlatBufferBuilder &_fbb) +{ + GreaterOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +struct GreaterEqualOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef GreaterEqualOptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && verifier.EndTable(); + } +}; + +struct GreaterEqualOptionsBuilder +{ + typedef GreaterEqualOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit GreaterEqualOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<GreaterEqualOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<GreaterEqualOptions>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<GreaterEqualOptions> +CreateGreaterEqualOptions(::flatbuffers::FlatBufferBuilder &_fbb) +{ + GreaterEqualOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +struct LessOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef LessOptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const + { + 
return VerifyTableStart(verifier) && verifier.EndTable(); + } +}; + +struct LessOptionsBuilder +{ + typedef LessOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit LessOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<LessOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<LessOptions>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<LessOptions> CreateLessOptions(::flatbuffers::FlatBufferBuilder &_fbb) +{ + LessOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +struct LessEqualOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef LessEqualOptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && verifier.EndTable(); + } +}; + +struct LessEqualOptionsBuilder +{ + typedef LessEqualOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit LessEqualOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<LessEqualOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<LessEqualOptions>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<LessEqualOptions> +CreateLessEqualOptions(::flatbuffers::FlatBufferBuilder &_fbb) +{ + LessEqualOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +struct NegOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef NegOptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && verifier.EndTable(); + } +}; + +struct NegOptionsBuilder +{ + typedef NegOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit NegOptionsBuilder(::flatbuffers::FlatBufferBuilder 
&_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<NegOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<NegOptions>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<NegOptions> CreateNegOptions(::flatbuffers::FlatBufferBuilder &_fbb) +{ + NegOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +struct SelectOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef SelectOptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && verifier.EndTable(); + } +}; + +struct SelectOptionsBuilder +{ + typedef SelectOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit SelectOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<SelectOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<SelectOptions>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<SelectOptions> +CreateSelectOptions(::flatbuffers::FlatBufferBuilder &_fbb) +{ + SelectOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +struct SliceOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef SliceOptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && verifier.EndTable(); + } +}; + +struct SliceOptionsBuilder +{ + typedef SliceOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit SliceOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<SliceOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<SliceOptions>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<SliceOptions> 
+CreateSliceOptions(::flatbuffers::FlatBufferBuilder &_fbb) +{ + SliceOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +struct TransposeConvOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef TransposeConvOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE + { + VT_PADDING = 4, + VT_STRIDE_W = 6, + VT_STRIDE_H = 8 + }; + onert_tflite::Padding padding() const + { + return static_cast<onert_tflite::Padding>(GetField<int8_t>(VT_PADDING, 0)); + } + int32_t stride_w() const { return GetField<int32_t>(VT_STRIDE_W, 0); } + int32_t stride_h() const { return GetField<int32_t>(VT_STRIDE_H, 0); } + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && VerifyField<int8_t>(verifier, VT_PADDING, 1) && + VerifyField<int32_t>(verifier, VT_STRIDE_W, 4) && + VerifyField<int32_t>(verifier, VT_STRIDE_H, 4) && verifier.EndTable(); + } +}; + +struct TransposeConvOptionsBuilder +{ + typedef TransposeConvOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_padding(onert_tflite::Padding padding) + { + fbb_.AddElement<int8_t>(TransposeConvOptions::VT_PADDING, static_cast<int8_t>(padding), 0); + } + void add_stride_w(int32_t stride_w) + { + fbb_.AddElement<int32_t>(TransposeConvOptions::VT_STRIDE_W, stride_w, 0); + } + void add_stride_h(int32_t stride_h) + { + fbb_.AddElement<int32_t>(TransposeConvOptions::VT_STRIDE_H, stride_h, 0); + } + explicit TransposeConvOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<TransposeConvOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<TransposeConvOptions>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<TransposeConvOptions> +CreateTransposeConvOptions(::flatbuffers::FlatBufferBuilder &_fbb, + onert_tflite::Padding padding = onert_tflite::Padding_SAME, + 
int32_t stride_w = 0, int32_t stride_h = 0) +{ + TransposeConvOptionsBuilder builder_(_fbb); + builder_.add_stride_h(stride_h); + builder_.add_stride_w(stride_w); + builder_.add_padding(padding); + return builder_.Finish(); +} + +struct ExpandDimsOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef ExpandDimsOptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && verifier.EndTable(); + } +}; + +struct ExpandDimsOptionsBuilder +{ + typedef ExpandDimsOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit ExpandDimsOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<ExpandDimsOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<ExpandDimsOptions>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<ExpandDimsOptions> +CreateExpandDimsOptions(::flatbuffers::FlatBufferBuilder &_fbb) +{ + ExpandDimsOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +struct SparseToDenseOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef SparseToDenseOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE + { + VT_VALIDATE_INDICES = 4 + }; + bool validate_indices() const { return GetField<uint8_t>(VT_VALIDATE_INDICES, 0) != 0; } + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && VerifyField<uint8_t>(verifier, VT_VALIDATE_INDICES, 1) && + verifier.EndTable(); + } +}; + +struct SparseToDenseOptionsBuilder +{ + typedef SparseToDenseOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_validate_indices(bool validate_indices) + { + fbb_.AddElement<uint8_t>(SparseToDenseOptions::VT_VALIDATE_INDICES, + static_cast<uint8_t>(validate_indices), 0); + } + explicit 
SparseToDenseOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<SparseToDenseOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<SparseToDenseOptions>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<SparseToDenseOptions> +CreateSparseToDenseOptions(::flatbuffers::FlatBufferBuilder &_fbb, bool validate_indices = false) +{ + SparseToDenseOptionsBuilder builder_(_fbb); + builder_.add_validate_indices(validate_indices); + return builder_.Finish(); +} + +struct EqualOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef EqualOptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && verifier.EndTable(); + } +}; + +struct EqualOptionsBuilder +{ + typedef EqualOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit EqualOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<EqualOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<EqualOptions>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<EqualOptions> +CreateEqualOptions(::flatbuffers::FlatBufferBuilder &_fbb) +{ + EqualOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +struct NotEqualOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef NotEqualOptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && verifier.EndTable(); + } +}; + +struct NotEqualOptionsBuilder +{ + typedef NotEqualOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit NotEqualOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + 
::flatbuffers::Offset<NotEqualOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<NotEqualOptions>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<NotEqualOptions> +CreateNotEqualOptions(::flatbuffers::FlatBufferBuilder &_fbb) +{ + NotEqualOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +struct ShapeOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef ShapeOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE + { + VT_OUT_TYPE = 4 + }; + onert_tflite::TensorType out_type() const + { + return static_cast<onert_tflite::TensorType>(GetField<int8_t>(VT_OUT_TYPE, 0)); + } + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && VerifyField<int8_t>(verifier, VT_OUT_TYPE, 1) && + verifier.EndTable(); + } +}; + +struct ShapeOptionsBuilder +{ + typedef ShapeOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_out_type(onert_tflite::TensorType out_type) + { + fbb_.AddElement<int8_t>(ShapeOptions::VT_OUT_TYPE, static_cast<int8_t>(out_type), 0); + } + explicit ShapeOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<ShapeOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<ShapeOptions>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<ShapeOptions> +CreateShapeOptions(::flatbuffers::FlatBufferBuilder &_fbb, + onert_tflite::TensorType out_type = onert_tflite::TensorType_FLOAT32) +{ + ShapeOptionsBuilder builder_(_fbb); + builder_.add_out_type(out_type); + return builder_.Finish(); +} + +struct RankOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef RankOptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && verifier.EndTable(); + } +}; + 
+struct RankOptionsBuilder +{ + typedef RankOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit RankOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<RankOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<RankOptions>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<RankOptions> CreateRankOptions(::flatbuffers::FlatBufferBuilder &_fbb) +{ + RankOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +struct PowOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef PowOptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && verifier.EndTable(); + } +}; + +struct PowOptionsBuilder +{ + typedef PowOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit PowOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<PowOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<PowOptions>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<PowOptions> CreatePowOptions(::flatbuffers::FlatBufferBuilder &_fbb) +{ + PowOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +struct FakeQuantOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef FakeQuantOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE + { + VT_MIN = 4, + VT_MAX = 6, + VT_NUM_BITS = 8, + VT_NARROW_RANGE = 10 + }; + float min() const { return GetField<float>(VT_MIN, 0.0f); } + float max() const { return GetField<float>(VT_MAX, 0.0f); } + int32_t num_bits() const { return GetField<int32_t>(VT_NUM_BITS, 0); } + bool narrow_range() const { return GetField<uint8_t>(VT_NARROW_RANGE, 0) != 0; } + bool 
Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && VerifyField<float>(verifier, VT_MIN, 4) && + VerifyField<float>(verifier, VT_MAX, 4) && + VerifyField<int32_t>(verifier, VT_NUM_BITS, 4) && + VerifyField<uint8_t>(verifier, VT_NARROW_RANGE, 1) && verifier.EndTable(); + } +}; + +struct FakeQuantOptionsBuilder +{ + typedef FakeQuantOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_min(float min) { fbb_.AddElement<float>(FakeQuantOptions::VT_MIN, min, 0.0f); } + void add_max(float max) { fbb_.AddElement<float>(FakeQuantOptions::VT_MAX, max, 0.0f); } + void add_num_bits(int32_t num_bits) + { + fbb_.AddElement<int32_t>(FakeQuantOptions::VT_NUM_BITS, num_bits, 0); + } + void add_narrow_range(bool narrow_range) + { + fbb_.AddElement<uint8_t>(FakeQuantOptions::VT_NARROW_RANGE, static_cast<uint8_t>(narrow_range), + 0); + } + explicit FakeQuantOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<FakeQuantOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<FakeQuantOptions>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<FakeQuantOptions> +CreateFakeQuantOptions(::flatbuffers::FlatBufferBuilder &_fbb, float min = 0.0f, float max = 0.0f, + int32_t num_bits = 0, bool narrow_range = false) +{ + FakeQuantOptionsBuilder builder_(_fbb); + builder_.add_num_bits(num_bits); + builder_.add_max(max); + builder_.add_min(min); + builder_.add_narrow_range(narrow_range); + return builder_.Finish(); +} + +struct PackOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef PackOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE + { + VT_VALUES_COUNT = 4, + VT_AXIS = 6 + }; + int32_t values_count() const { return GetField<int32_t>(VT_VALUES_COUNT, 0); } + int32_t axis() const { return GetField<int32_t>(VT_AXIS, 0); } 
+ bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && VerifyField<int32_t>(verifier, VT_VALUES_COUNT, 4) && + VerifyField<int32_t>(verifier, VT_AXIS, 4) && verifier.EndTable(); + } +}; + +struct PackOptionsBuilder +{ + typedef PackOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_values_count(int32_t values_count) + { + fbb_.AddElement<int32_t>(PackOptions::VT_VALUES_COUNT, values_count, 0); + } + void add_axis(int32_t axis) { fbb_.AddElement<int32_t>(PackOptions::VT_AXIS, axis, 0); } + explicit PackOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<PackOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<PackOptions>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<PackOptions> CreatePackOptions(::flatbuffers::FlatBufferBuilder &_fbb, + int32_t values_count = 0, + int32_t axis = 0) +{ + PackOptionsBuilder builder_(_fbb); + builder_.add_axis(axis); + builder_.add_values_count(values_count); + return builder_.Finish(); +} + +struct LogicalOrOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef LogicalOrOptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && verifier.EndTable(); + } +}; + +struct LogicalOrOptionsBuilder +{ + typedef LogicalOrOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit LogicalOrOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<LogicalOrOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<LogicalOrOptions>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<LogicalOrOptions> +CreateLogicalOrOptions(::flatbuffers::FlatBufferBuilder &_fbb) +{ + 
LogicalOrOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +struct OneHotOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef OneHotOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE + { + VT_AXIS = 4 + }; + int32_t axis() const { return GetField<int32_t>(VT_AXIS, 0); } + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && VerifyField<int32_t>(verifier, VT_AXIS, 4) && + verifier.EndTable(); + } +}; + +struct OneHotOptionsBuilder +{ + typedef OneHotOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_axis(int32_t axis) { fbb_.AddElement<int32_t>(OneHotOptions::VT_AXIS, axis, 0); } + explicit OneHotOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<OneHotOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<OneHotOptions>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<OneHotOptions> +CreateOneHotOptions(::flatbuffers::FlatBufferBuilder &_fbb, int32_t axis = 0) +{ + OneHotOptionsBuilder builder_(_fbb); + builder_.add_axis(axis); + return builder_.Finish(); +} + +struct AbsOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef AbsOptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && verifier.EndTable(); + } +}; + +struct AbsOptionsBuilder +{ + typedef AbsOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit AbsOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<AbsOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<AbsOptions>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<AbsOptions> 
CreateAbsOptions(::flatbuffers::FlatBufferBuilder &_fbb) +{ + AbsOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +struct HardSwishOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef HardSwishOptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && verifier.EndTable(); + } +}; + +struct HardSwishOptionsBuilder +{ + typedef HardSwishOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit HardSwishOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<HardSwishOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<HardSwishOptions>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<HardSwishOptions> +CreateHardSwishOptions(::flatbuffers::FlatBufferBuilder &_fbb) +{ + HardSwishOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +struct LogicalAndOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef LogicalAndOptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && verifier.EndTable(); + } +}; + +struct LogicalAndOptionsBuilder +{ + typedef LogicalAndOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit LogicalAndOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<LogicalAndOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<LogicalAndOptions>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<LogicalAndOptions> +CreateLogicalAndOptions(::flatbuffers::FlatBufferBuilder &_fbb) +{ + LogicalAndOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +struct LogicalNotOptions FLATBUFFERS_FINAL_CLASS : private 
::flatbuffers::Table +{ + typedef LogicalNotOptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && verifier.EndTable(); + } +}; + +struct LogicalNotOptionsBuilder +{ + typedef LogicalNotOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit LogicalNotOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<LogicalNotOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<LogicalNotOptions>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<LogicalNotOptions> +CreateLogicalNotOptions(::flatbuffers::FlatBufferBuilder &_fbb) +{ + LogicalNotOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +struct UnpackOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef UnpackOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE + { + VT_NUM = 4, + VT_AXIS = 6 + }; + int32_t num() const { return GetField<int32_t>(VT_NUM, 0); } + int32_t axis() const { return GetField<int32_t>(VT_AXIS, 0); } + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && VerifyField<int32_t>(verifier, VT_NUM, 4) && + VerifyField<int32_t>(verifier, VT_AXIS, 4) && verifier.EndTable(); + } +}; + +struct UnpackOptionsBuilder +{ + typedef UnpackOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_num(int32_t num) { fbb_.AddElement<int32_t>(UnpackOptions::VT_NUM, num, 0); } + void add_axis(int32_t axis) { fbb_.AddElement<int32_t>(UnpackOptions::VT_AXIS, axis, 0); } + explicit UnpackOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<UnpackOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = 
::flatbuffers::Offset<UnpackOptions>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<UnpackOptions> +CreateUnpackOptions(::flatbuffers::FlatBufferBuilder &_fbb, int32_t num = 0, int32_t axis = 0) +{ + UnpackOptionsBuilder builder_(_fbb); + builder_.add_axis(axis); + builder_.add_num(num); + return builder_.Finish(); +} + +struct FloorDivOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef FloorDivOptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && verifier.EndTable(); + } +}; + +struct FloorDivOptionsBuilder +{ + typedef FloorDivOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit FloorDivOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<FloorDivOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<FloorDivOptions>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<FloorDivOptions> +CreateFloorDivOptions(::flatbuffers::FlatBufferBuilder &_fbb) +{ + FloorDivOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +struct SquareOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef SquareOptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && verifier.EndTable(); + } +}; + +struct SquareOptionsBuilder +{ + typedef SquareOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit SquareOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<SquareOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<SquareOptions>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<SquareOptions> +CreateSquareOptions(::flatbuffers::FlatBufferBuilder 
&_fbb) +{ + SquareOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +struct ZerosLikeOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef ZerosLikeOptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && verifier.EndTable(); + } +}; + +struct ZerosLikeOptionsBuilder +{ + typedef ZerosLikeOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit ZerosLikeOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<ZerosLikeOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<ZerosLikeOptions>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<ZerosLikeOptions> +CreateZerosLikeOptions(::flatbuffers::FlatBufferBuilder &_fbb) +{ + ZerosLikeOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +struct FillOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef FillOptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && verifier.EndTable(); + } +}; + +struct FillOptionsBuilder +{ + typedef FillOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit FillOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<FillOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<FillOptions>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<FillOptions> CreateFillOptions(::flatbuffers::FlatBufferBuilder &_fbb) +{ + FillOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +struct FloorModOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef FloorModOptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const + { 
+ return VerifyTableStart(verifier) && verifier.EndTable(); + } +}; + +struct FloorModOptionsBuilder +{ + typedef FloorModOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit FloorModOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<FloorModOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<FloorModOptions>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<FloorModOptions> +CreateFloorModOptions(::flatbuffers::FlatBufferBuilder &_fbb) +{ + FloorModOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +struct RangeOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef RangeOptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && verifier.EndTable(); + } +}; + +struct RangeOptionsBuilder +{ + typedef RangeOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit RangeOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<RangeOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<RangeOptions>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<RangeOptions> +CreateRangeOptions(::flatbuffers::FlatBufferBuilder &_fbb) +{ + RangeOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +struct LeakyReluOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef LeakyReluOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE + { + VT_ALPHA = 4 + }; + float alpha() const { return GetField<float>(VT_ALPHA, 0.0f); } + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && VerifyField<float>(verifier, VT_ALPHA, 4) && + 
verifier.EndTable(); + } +}; + +struct LeakyReluOptionsBuilder +{ + typedef LeakyReluOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_alpha(float alpha) { fbb_.AddElement<float>(LeakyReluOptions::VT_ALPHA, alpha, 0.0f); } + explicit LeakyReluOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<LeakyReluOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<LeakyReluOptions>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<LeakyReluOptions> +CreateLeakyReluOptions(::flatbuffers::FlatBufferBuilder &_fbb, float alpha = 0.0f) +{ + LeakyReluOptionsBuilder builder_(_fbb); + builder_.add_alpha(alpha); + return builder_.Finish(); +} + +struct SquaredDifferenceOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef SquaredDifferenceOptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && verifier.EndTable(); + } +}; + +struct SquaredDifferenceOptionsBuilder +{ + typedef SquaredDifferenceOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit SquaredDifferenceOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<SquaredDifferenceOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<SquaredDifferenceOptions>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<SquaredDifferenceOptions> +CreateSquaredDifferenceOptions(::flatbuffers::FlatBufferBuilder &_fbb) +{ + SquaredDifferenceOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +struct MirrorPadOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef MirrorPadOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE + { + VT_MODE = 4 
+ }; + onert_tflite::MirrorPadMode mode() const + { + return static_cast<onert_tflite::MirrorPadMode>(GetField<int8_t>(VT_MODE, 0)); + } + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && VerifyField<int8_t>(verifier, VT_MODE, 1) && + verifier.EndTable(); + } +}; + +struct MirrorPadOptionsBuilder +{ + typedef MirrorPadOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_mode(onert_tflite::MirrorPadMode mode) + { + fbb_.AddElement<int8_t>(MirrorPadOptions::VT_MODE, static_cast<int8_t>(mode), 0); + } + explicit MirrorPadOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<MirrorPadOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<MirrorPadOptions>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<MirrorPadOptions> +CreateMirrorPadOptions(::flatbuffers::FlatBufferBuilder &_fbb, + onert_tflite::MirrorPadMode mode = onert_tflite::MirrorPadMode_REFLECT) +{ + MirrorPadOptionsBuilder builder_(_fbb); + builder_.add_mode(mode); + return builder_.Finish(); +} + +struct UniqueOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef UniqueOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE + { + VT_IDX_OUT_TYPE = 4 + }; + onert_tflite::TensorType idx_out_type() const + { + return static_cast<onert_tflite::TensorType>(GetField<int8_t>(VT_IDX_OUT_TYPE, 2)); + } + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && VerifyField<int8_t>(verifier, VT_IDX_OUT_TYPE, 1) && + verifier.EndTable(); + } +}; + +struct UniqueOptionsBuilder +{ + typedef UniqueOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_idx_out_type(onert_tflite::TensorType idx_out_type) + { + 
fbb_.AddElement<int8_t>(UniqueOptions::VT_IDX_OUT_TYPE, static_cast<int8_t>(idx_out_type), 2); + } + explicit UniqueOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<UniqueOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<UniqueOptions>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<UniqueOptions> +CreateUniqueOptions(::flatbuffers::FlatBufferBuilder &_fbb, + onert_tflite::TensorType idx_out_type = onert_tflite::TensorType_INT32) +{ + UniqueOptionsBuilder builder_(_fbb); + builder_.add_idx_out_type(idx_out_type); + return builder_.Finish(); +} + +struct ReverseV2Options FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef ReverseV2OptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && verifier.EndTable(); + } +}; + +struct ReverseV2OptionsBuilder +{ + typedef ReverseV2Options Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit ReverseV2OptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<ReverseV2Options> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<ReverseV2Options>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<ReverseV2Options> +CreateReverseV2Options(::flatbuffers::FlatBufferBuilder &_fbb) +{ + ReverseV2OptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +struct AddNOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef AddNOptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && verifier.EndTable(); + } +}; + +struct AddNOptionsBuilder +{ + typedef AddNOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit 
AddNOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<AddNOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<AddNOptions>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<AddNOptions> CreateAddNOptions(::flatbuffers::FlatBufferBuilder &_fbb) +{ + AddNOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +struct GatherNdOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef GatherNdOptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && verifier.EndTable(); + } +}; + +struct GatherNdOptionsBuilder +{ + typedef GatherNdOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit GatherNdOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<GatherNdOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<GatherNdOptions>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<GatherNdOptions> +CreateGatherNdOptions(::flatbuffers::FlatBufferBuilder &_fbb) +{ + GatherNdOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +struct WhereOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef WhereOptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && verifier.EndTable(); + } +}; + +struct WhereOptionsBuilder +{ + typedef WhereOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit WhereOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<WhereOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<WhereOptions>(end); + return o; + 
} +}; + +inline ::flatbuffers::Offset<WhereOptions> +CreateWhereOptions(::flatbuffers::FlatBufferBuilder &_fbb) +{ + WhereOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +struct ReverseSequenceOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef ReverseSequenceOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE + { + VT_SEQ_DIM = 4, + VT_BATCH_DIM = 6 + }; + int32_t seq_dim() const { return GetField<int32_t>(VT_SEQ_DIM, 0); } + int32_t batch_dim() const { return GetField<int32_t>(VT_BATCH_DIM, 0); } + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && VerifyField<int32_t>(verifier, VT_SEQ_DIM, 4) && + VerifyField<int32_t>(verifier, VT_BATCH_DIM, 4) && verifier.EndTable(); + } +}; + +struct ReverseSequenceOptionsBuilder +{ + typedef ReverseSequenceOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_seq_dim(int32_t seq_dim) + { + fbb_.AddElement<int32_t>(ReverseSequenceOptions::VT_SEQ_DIM, seq_dim, 0); + } + void add_batch_dim(int32_t batch_dim) + { + fbb_.AddElement<int32_t>(ReverseSequenceOptions::VT_BATCH_DIM, batch_dim, 0); + } + explicit ReverseSequenceOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<ReverseSequenceOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<ReverseSequenceOptions>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<ReverseSequenceOptions> +CreateReverseSequenceOptions(::flatbuffers::FlatBufferBuilder &_fbb, int32_t seq_dim = 0, + int32_t batch_dim = 0) +{ + ReverseSequenceOptionsBuilder builder_(_fbb); + builder_.add_batch_dim(batch_dim); + builder_.add_seq_dim(seq_dim); + return builder_.Finish(); +} + +struct MatrixDiagOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef MatrixDiagOptionsBuilder Builder; + bool 
Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && verifier.EndTable(); + } +}; + +struct MatrixDiagOptionsBuilder +{ + typedef MatrixDiagOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit MatrixDiagOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<MatrixDiagOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<MatrixDiagOptions>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<MatrixDiagOptions> +CreateMatrixDiagOptions(::flatbuffers::FlatBufferBuilder &_fbb) +{ + MatrixDiagOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +struct QuantizeOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef QuantizeOptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && verifier.EndTable(); + } +}; + +struct QuantizeOptionsBuilder +{ + typedef QuantizeOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit QuantizeOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<QuantizeOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<QuantizeOptions>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<QuantizeOptions> +CreateQuantizeOptions(::flatbuffers::FlatBufferBuilder &_fbb) +{ + QuantizeOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +struct MatrixSetDiagOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef MatrixSetDiagOptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && verifier.EndTable(); + } +}; + +struct MatrixSetDiagOptionsBuilder +{ + typedef MatrixSetDiagOptions Table; + 
::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit MatrixSetDiagOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<MatrixSetDiagOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<MatrixSetDiagOptions>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<MatrixSetDiagOptions> +CreateMatrixSetDiagOptions(::flatbuffers::FlatBufferBuilder &_fbb) +{ + MatrixSetDiagOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +struct IfOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef IfOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE + { + VT_THEN_SUBGRAPH_INDEX = 4, + VT_ELSE_SUBGRAPH_INDEX = 6 + }; + int32_t then_subgraph_index() const { return GetField<int32_t>(VT_THEN_SUBGRAPH_INDEX, 0); } + int32_t else_subgraph_index() const { return GetField<int32_t>(VT_ELSE_SUBGRAPH_INDEX, 0); } + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && + VerifyField<int32_t>(verifier, VT_THEN_SUBGRAPH_INDEX, 4) && + VerifyField<int32_t>(verifier, VT_ELSE_SUBGRAPH_INDEX, 4) && verifier.EndTable(); + } +}; + +struct IfOptionsBuilder +{ + typedef IfOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_then_subgraph_index(int32_t then_subgraph_index) + { + fbb_.AddElement<int32_t>(IfOptions::VT_THEN_SUBGRAPH_INDEX, then_subgraph_index, 0); + } + void add_else_subgraph_index(int32_t else_subgraph_index) + { + fbb_.AddElement<int32_t>(IfOptions::VT_ELSE_SUBGRAPH_INDEX, else_subgraph_index, 0); + } + explicit IfOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<IfOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<IfOptions>(end); + return o; + } 
+}; + +inline ::flatbuffers::Offset<IfOptions> CreateIfOptions(::flatbuffers::FlatBufferBuilder &_fbb, + int32_t then_subgraph_index = 0, + int32_t else_subgraph_index = 0) +{ + IfOptionsBuilder builder_(_fbb); + builder_.add_else_subgraph_index(else_subgraph_index); + builder_.add_then_subgraph_index(then_subgraph_index); + return builder_.Finish(); +} + +struct CallOnceOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef CallOnceOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE + { + VT_INIT_SUBGRAPH_INDEX = 4 + }; + int32_t init_subgraph_index() const { return GetField<int32_t>(VT_INIT_SUBGRAPH_INDEX, 0); } + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && + VerifyField<int32_t>(verifier, VT_INIT_SUBGRAPH_INDEX, 4) && verifier.EndTable(); + } +}; + +struct CallOnceOptionsBuilder +{ + typedef CallOnceOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_init_subgraph_index(int32_t init_subgraph_index) + { + fbb_.AddElement<int32_t>(CallOnceOptions::VT_INIT_SUBGRAPH_INDEX, init_subgraph_index, 0); + } + explicit CallOnceOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<CallOnceOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<CallOnceOptions>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<CallOnceOptions> +CreateCallOnceOptions(::flatbuffers::FlatBufferBuilder &_fbb, int32_t init_subgraph_index = 0) +{ + CallOnceOptionsBuilder builder_(_fbb); + builder_.add_init_subgraph_index(init_subgraph_index); + return builder_.Finish(); +} + +struct WhileOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef WhileOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE + { + VT_COND_SUBGRAPH_INDEX = 4, + VT_BODY_SUBGRAPH_INDEX = 
6 + }; + int32_t cond_subgraph_index() const { return GetField<int32_t>(VT_COND_SUBGRAPH_INDEX, 0); } + int32_t body_subgraph_index() const { return GetField<int32_t>(VT_BODY_SUBGRAPH_INDEX, 0); } + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && + VerifyField<int32_t>(verifier, VT_COND_SUBGRAPH_INDEX, 4) && + VerifyField<int32_t>(verifier, VT_BODY_SUBGRAPH_INDEX, 4) && verifier.EndTable(); + } +}; + +struct WhileOptionsBuilder +{ + typedef WhileOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_cond_subgraph_index(int32_t cond_subgraph_index) + { + fbb_.AddElement<int32_t>(WhileOptions::VT_COND_SUBGRAPH_INDEX, cond_subgraph_index, 0); + } + void add_body_subgraph_index(int32_t body_subgraph_index) + { + fbb_.AddElement<int32_t>(WhileOptions::VT_BODY_SUBGRAPH_INDEX, body_subgraph_index, 0); + } + explicit WhileOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<WhileOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<WhileOptions>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<WhileOptions> +CreateWhileOptions(::flatbuffers::FlatBufferBuilder &_fbb, int32_t cond_subgraph_index = 0, + int32_t body_subgraph_index = 0) +{ + WhileOptionsBuilder builder_(_fbb); + builder_.add_body_subgraph_index(body_subgraph_index); + builder_.add_cond_subgraph_index(cond_subgraph_index); + return builder_.Finish(); +} + +struct NonMaxSuppressionV4Options FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef NonMaxSuppressionV4OptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && verifier.EndTable(); + } +}; + +struct NonMaxSuppressionV4OptionsBuilder +{ + typedef NonMaxSuppressionV4Options Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + 
explicit NonMaxSuppressionV4OptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<NonMaxSuppressionV4Options> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<NonMaxSuppressionV4Options>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<NonMaxSuppressionV4Options> +CreateNonMaxSuppressionV4Options(::flatbuffers::FlatBufferBuilder &_fbb) +{ + NonMaxSuppressionV4OptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +struct NonMaxSuppressionV5Options FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef NonMaxSuppressionV5OptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && verifier.EndTable(); + } +}; + +struct NonMaxSuppressionV5OptionsBuilder +{ + typedef NonMaxSuppressionV5Options Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit NonMaxSuppressionV5OptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<NonMaxSuppressionV5Options> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<NonMaxSuppressionV5Options>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<NonMaxSuppressionV5Options> +CreateNonMaxSuppressionV5Options(::flatbuffers::FlatBufferBuilder &_fbb) +{ + NonMaxSuppressionV5OptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +struct ScatterNdOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef ScatterNdOptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && verifier.EndTable(); + } +}; + +struct ScatterNdOptionsBuilder +{ + typedef ScatterNdOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit 
ScatterNdOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<ScatterNdOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<ScatterNdOptions>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<ScatterNdOptions> +CreateScatterNdOptions(::flatbuffers::FlatBufferBuilder &_fbb) +{ + ScatterNdOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +struct SelectV2Options FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef SelectV2OptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && verifier.EndTable(); + } +}; + +struct SelectV2OptionsBuilder +{ + typedef SelectV2Options Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit SelectV2OptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<SelectV2Options> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<SelectV2Options>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<SelectV2Options> +CreateSelectV2Options(::flatbuffers::FlatBufferBuilder &_fbb) +{ + SelectV2OptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +struct DensifyOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef DensifyOptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && verifier.EndTable(); + } +}; + +struct DensifyOptionsBuilder +{ + typedef DensifyOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit DensifyOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<DensifyOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = 
::flatbuffers::Offset<DensifyOptions>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<DensifyOptions> +CreateDensifyOptions(::flatbuffers::FlatBufferBuilder &_fbb) +{ + DensifyOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +struct SegmentSumOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef SegmentSumOptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && verifier.EndTable(); + } +}; + +struct SegmentSumOptionsBuilder +{ + typedef SegmentSumOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit SegmentSumOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<SegmentSumOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<SegmentSumOptions>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<SegmentSumOptions> +CreateSegmentSumOptions(::flatbuffers::FlatBufferBuilder &_fbb) +{ + SegmentSumOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +struct BatchMatMulOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef BatchMatMulOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE + { + VT_ADJ_X = 4, + VT_ADJ_Y = 6, + VT_ASYMMETRIC_QUANTIZE_INPUTS = 8 + }; + bool adj_x() const { return GetField<uint8_t>(VT_ADJ_X, 0) != 0; } + bool adj_y() const { return GetField<uint8_t>(VT_ADJ_Y, 0) != 0; } + bool asymmetric_quantize_inputs() const + { + return GetField<uint8_t>(VT_ASYMMETRIC_QUANTIZE_INPUTS, 0) != 0; + } + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && VerifyField<uint8_t>(verifier, VT_ADJ_X, 1) && + VerifyField<uint8_t>(verifier, VT_ADJ_Y, 1) && + VerifyField<uint8_t>(verifier, VT_ASYMMETRIC_QUANTIZE_INPUTS, 1) && verifier.EndTable(); + } +}; + +struct 
BatchMatMulOptionsBuilder +{ + typedef BatchMatMulOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_adj_x(bool adj_x) + { + fbb_.AddElement<uint8_t>(BatchMatMulOptions::VT_ADJ_X, static_cast<uint8_t>(adj_x), 0); + } + void add_adj_y(bool adj_y) + { + fbb_.AddElement<uint8_t>(BatchMatMulOptions::VT_ADJ_Y, static_cast<uint8_t>(adj_y), 0); + } + void add_asymmetric_quantize_inputs(bool asymmetric_quantize_inputs) + { + fbb_.AddElement<uint8_t>(BatchMatMulOptions::VT_ASYMMETRIC_QUANTIZE_INPUTS, + static_cast<uint8_t>(asymmetric_quantize_inputs), 0); + } + explicit BatchMatMulOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<BatchMatMulOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<BatchMatMulOptions>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<BatchMatMulOptions> +CreateBatchMatMulOptions(::flatbuffers::FlatBufferBuilder &_fbb, bool adj_x = false, + bool adj_y = false, bool asymmetric_quantize_inputs = false) +{ + BatchMatMulOptionsBuilder builder_(_fbb); + builder_.add_asymmetric_quantize_inputs(asymmetric_quantize_inputs); + builder_.add_adj_y(adj_y); + builder_.add_adj_x(adj_x); + return builder_.Finish(); +} + +struct CumsumOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef CumsumOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE + { + VT_EXCLUSIVE = 4, + VT_REVERSE = 6 + }; + bool exclusive() const { return GetField<uint8_t>(VT_EXCLUSIVE, 0) != 0; } + bool reverse() const { return GetField<uint8_t>(VT_REVERSE, 0) != 0; } + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && VerifyField<uint8_t>(verifier, VT_EXCLUSIVE, 1) && + VerifyField<uint8_t>(verifier, VT_REVERSE, 1) && verifier.EndTable(); + } +}; + +struct CumsumOptionsBuilder +{ + typedef CumsumOptions 
Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_exclusive(bool exclusive) + { + fbb_.AddElement<uint8_t>(CumsumOptions::VT_EXCLUSIVE, static_cast<uint8_t>(exclusive), 0); + } + void add_reverse(bool reverse) + { + fbb_.AddElement<uint8_t>(CumsumOptions::VT_REVERSE, static_cast<uint8_t>(reverse), 0); + } + explicit CumsumOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<CumsumOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<CumsumOptions>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<CumsumOptions> +CreateCumsumOptions(::flatbuffers::FlatBufferBuilder &_fbb, bool exclusive = false, + bool reverse = false) +{ + CumsumOptionsBuilder builder_(_fbb); + builder_.add_reverse(reverse); + builder_.add_exclusive(exclusive); + return builder_.Finish(); +} + +struct BroadcastToOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef BroadcastToOptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && verifier.EndTable(); + } +}; + +struct BroadcastToOptionsBuilder +{ + typedef BroadcastToOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit BroadcastToOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<BroadcastToOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<BroadcastToOptions>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<BroadcastToOptions> +CreateBroadcastToOptions(::flatbuffers::FlatBufferBuilder &_fbb) +{ + BroadcastToOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +struct Rfft2dOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef Rfft2dOptionsBuilder Builder; + bool 
Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && verifier.EndTable(); + } +}; + +struct Rfft2dOptionsBuilder +{ + typedef Rfft2dOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit Rfft2dOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<Rfft2dOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<Rfft2dOptions>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<Rfft2dOptions> +CreateRfft2dOptions(::flatbuffers::FlatBufferBuilder &_fbb) +{ + Rfft2dOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +struct HashtableOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef HashtableOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE + { + VT_TABLE_ID = 4, + VT_KEY_DTYPE = 6, + VT_VALUE_DTYPE = 8 + }; + int32_t table_id() const { return GetField<int32_t>(VT_TABLE_ID, 0); } + onert_tflite::TensorType key_dtype() const + { + return static_cast<onert_tflite::TensorType>(GetField<int8_t>(VT_KEY_DTYPE, 0)); + } + onert_tflite::TensorType value_dtype() const + { + return static_cast<onert_tflite::TensorType>(GetField<int8_t>(VT_VALUE_DTYPE, 0)); + } + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && VerifyField<int32_t>(verifier, VT_TABLE_ID, 4) && + VerifyField<int8_t>(verifier, VT_KEY_DTYPE, 1) && + VerifyField<int8_t>(verifier, VT_VALUE_DTYPE, 1) && verifier.EndTable(); + } +}; + +struct HashtableOptionsBuilder +{ + typedef HashtableOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_table_id(int32_t table_id) + { + fbb_.AddElement<int32_t>(HashtableOptions::VT_TABLE_ID, table_id, 0); + } + void add_key_dtype(onert_tflite::TensorType key_dtype) + { + 
fbb_.AddElement<int8_t>(HashtableOptions::VT_KEY_DTYPE, static_cast<int8_t>(key_dtype), 0); + } + void add_value_dtype(onert_tflite::TensorType value_dtype) + { + fbb_.AddElement<int8_t>(HashtableOptions::VT_VALUE_DTYPE, static_cast<int8_t>(value_dtype), 0); + } + explicit HashtableOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<HashtableOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<HashtableOptions>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<HashtableOptions> +CreateHashtableOptions(::flatbuffers::FlatBufferBuilder &_fbb, int32_t table_id = 0, + onert_tflite::TensorType key_dtype = onert_tflite::TensorType_FLOAT32, + onert_tflite::TensorType value_dtype = onert_tflite::TensorType_FLOAT32) +{ + HashtableOptionsBuilder builder_(_fbb); + builder_.add_table_id(table_id); + builder_.add_value_dtype(value_dtype); + builder_.add_key_dtype(key_dtype); + return builder_.Finish(); +} + +struct HashtableFindOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef HashtableFindOptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && verifier.EndTable(); + } +}; + +struct HashtableFindOptionsBuilder +{ + typedef HashtableFindOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit HashtableFindOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<HashtableFindOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<HashtableFindOptions>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<HashtableFindOptions> +CreateHashtableFindOptions(::flatbuffers::FlatBufferBuilder &_fbb) +{ + HashtableFindOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +struct 
HashtableImportOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef HashtableImportOptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && verifier.EndTable(); + } +}; + +struct HashtableImportOptionsBuilder +{ + typedef HashtableImportOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit HashtableImportOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<HashtableImportOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<HashtableImportOptions>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<HashtableImportOptions> +CreateHashtableImportOptions(::flatbuffers::FlatBufferBuilder &_fbb) +{ + HashtableImportOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +struct HashtableSizeOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef HashtableSizeOptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && verifier.EndTable(); + } +}; + +struct HashtableSizeOptionsBuilder +{ + typedef HashtableSizeOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit HashtableSizeOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<HashtableSizeOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<HashtableSizeOptions>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<HashtableSizeOptions> +CreateHashtableSizeOptions(::flatbuffers::FlatBufferBuilder &_fbb) +{ + HashtableSizeOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +struct VarHandleOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef VarHandleOptionsBuilder 
Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE + { + VT_CONTAINER = 4, + VT_SHARED_NAME = 6 + }; + const ::flatbuffers::String *container() const + { + return GetPointer<const ::flatbuffers::String *>(VT_CONTAINER); + } + const ::flatbuffers::String *shared_name() const + { + return GetPointer<const ::flatbuffers::String *>(VT_SHARED_NAME); + } + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_CONTAINER) && + verifier.VerifyString(container()) && VerifyOffset(verifier, VT_SHARED_NAME) && + verifier.VerifyString(shared_name()) && verifier.EndTable(); + } +}; + +struct VarHandleOptionsBuilder +{ + typedef VarHandleOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_container(::flatbuffers::Offset<::flatbuffers::String> container) + { + fbb_.AddOffset(VarHandleOptions::VT_CONTAINER, container); + } + void add_shared_name(::flatbuffers::Offset<::flatbuffers::String> shared_name) + { + fbb_.AddOffset(VarHandleOptions::VT_SHARED_NAME, shared_name); + } + explicit VarHandleOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<VarHandleOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<VarHandleOptions>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<VarHandleOptions> +CreateVarHandleOptions(::flatbuffers::FlatBufferBuilder &_fbb, + ::flatbuffers::Offset<::flatbuffers::String> container = 0, + ::flatbuffers::Offset<::flatbuffers::String> shared_name = 0) +{ + VarHandleOptionsBuilder builder_(_fbb); + builder_.add_shared_name(shared_name); + builder_.add_container(container); + return builder_.Finish(); +} + +inline ::flatbuffers::Offset<VarHandleOptions> +CreateVarHandleOptionsDirect(::flatbuffers::FlatBufferBuilder &_fbb, + const char *container = nullptr, const char *shared_name = nullptr) 
+{ + auto container__ = container ? _fbb.CreateString(container) : 0; + auto shared_name__ = shared_name ? _fbb.CreateString(shared_name) : 0; + return onert_tflite::CreateVarHandleOptions(_fbb, container__, shared_name__); +} + +struct ReadVariableOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef ReadVariableOptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && verifier.EndTable(); + } +}; + +struct ReadVariableOptionsBuilder +{ + typedef ReadVariableOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit ReadVariableOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<ReadVariableOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<ReadVariableOptions>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<ReadVariableOptions> +CreateReadVariableOptions(::flatbuffers::FlatBufferBuilder &_fbb) +{ + ReadVariableOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +struct AssignVariableOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef AssignVariableOptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && verifier.EndTable(); + } +}; + +struct AssignVariableOptionsBuilder +{ + typedef AssignVariableOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit AssignVariableOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<AssignVariableOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<AssignVariableOptions>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<AssignVariableOptions> 
+CreateAssignVariableOptions(::flatbuffers::FlatBufferBuilder &_fbb) +{ + AssignVariableOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +struct RandomOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef RandomOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE + { + VT_SEED = 4, + VT_SEED2 = 6 + }; + int64_t seed() const { return GetField<int64_t>(VT_SEED, 0); } + int64_t seed2() const { return GetField<int64_t>(VT_SEED2, 0); } + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && VerifyField<int64_t>(verifier, VT_SEED, 8) && + VerifyField<int64_t>(verifier, VT_SEED2, 8) && verifier.EndTable(); + } +}; + +struct RandomOptionsBuilder +{ + typedef RandomOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_seed(int64_t seed) { fbb_.AddElement<int64_t>(RandomOptions::VT_SEED, seed, 0); } + void add_seed2(int64_t seed2) { fbb_.AddElement<int64_t>(RandomOptions::VT_SEED2, seed2, 0); } + explicit RandomOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<RandomOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<RandomOptions>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<RandomOptions> +CreateRandomOptions(::flatbuffers::FlatBufferBuilder &_fbb, int64_t seed = 0, int64_t seed2 = 0) +{ + RandomOptionsBuilder builder_(_fbb); + builder_.add_seed2(seed2); + builder_.add_seed(seed); + return builder_.Finish(); +} + +struct BucketizeOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef BucketizeOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE + { + VT_BOUNDARIES = 4 + }; + const ::flatbuffers::Vector<float> *boundaries() const + { + return GetPointer<const ::flatbuffers::Vector<float> *>(VT_BOUNDARIES); + } + bool 
Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_BOUNDARIES) && + verifier.VerifyVector(boundaries()) && verifier.EndTable(); + } +}; + +struct BucketizeOptionsBuilder +{ + typedef BucketizeOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_boundaries(::flatbuffers::Offset<::flatbuffers::Vector<float>> boundaries) + { + fbb_.AddOffset(BucketizeOptions::VT_BOUNDARIES, boundaries); + } + explicit BucketizeOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<BucketizeOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<BucketizeOptions>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<BucketizeOptions> +CreateBucketizeOptions(::flatbuffers::FlatBufferBuilder &_fbb, + ::flatbuffers::Offset<::flatbuffers::Vector<float>> boundaries = 0) +{ + BucketizeOptionsBuilder builder_(_fbb); + builder_.add_boundaries(boundaries); + return builder_.Finish(); +} + +inline ::flatbuffers::Offset<BucketizeOptions> +CreateBucketizeOptionsDirect(::flatbuffers::FlatBufferBuilder &_fbb, + const std::vector<float> *boundaries = nullptr) +{ + auto boundaries__ = boundaries ? 
_fbb.CreateVector<float>(*boundaries) : 0; + return onert_tflite::CreateBucketizeOptions(_fbb, boundaries__); +} + +struct GeluOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef GeluOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE + { + VT_APPROXIMATE = 4 + }; + bool approximate() const { return GetField<uint8_t>(VT_APPROXIMATE, 0) != 0; } + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && VerifyField<uint8_t>(verifier, VT_APPROXIMATE, 1) && + verifier.EndTable(); + } +}; + +struct GeluOptionsBuilder +{ + typedef GeluOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_approximate(bool approximate) + { + fbb_.AddElement<uint8_t>(GeluOptions::VT_APPROXIMATE, static_cast<uint8_t>(approximate), 0); + } + explicit GeluOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<GeluOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<GeluOptions>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<GeluOptions> CreateGeluOptions(::flatbuffers::FlatBufferBuilder &_fbb, + bool approximate = false) +{ + GeluOptionsBuilder builder_(_fbb); + builder_.add_approximate(approximate); + return builder_.Finish(); +} + +struct DynamicUpdateSliceOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef DynamicUpdateSliceOptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && verifier.EndTable(); + } +}; + +struct DynamicUpdateSliceOptionsBuilder +{ + typedef DynamicUpdateSliceOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit DynamicUpdateSliceOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + 
::flatbuffers::Offset<DynamicUpdateSliceOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<DynamicUpdateSliceOptions>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<DynamicUpdateSliceOptions> +CreateDynamicUpdateSliceOptions(::flatbuffers::FlatBufferBuilder &_fbb) +{ + DynamicUpdateSliceOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +struct UnsortedSegmentProdOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef UnsortedSegmentProdOptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && verifier.EndTable(); + } +}; + +struct UnsortedSegmentProdOptionsBuilder +{ + typedef UnsortedSegmentProdOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit UnsortedSegmentProdOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<UnsortedSegmentProdOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<UnsortedSegmentProdOptions>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<UnsortedSegmentProdOptions> +CreateUnsortedSegmentProdOptions(::flatbuffers::FlatBufferBuilder &_fbb) +{ + UnsortedSegmentProdOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +struct UnsortedSegmentMaxOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef UnsortedSegmentMaxOptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && verifier.EndTable(); + } +}; + +struct UnsortedSegmentMaxOptionsBuilder +{ + typedef UnsortedSegmentMaxOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit UnsortedSegmentMaxOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + 
::flatbuffers::Offset<UnsortedSegmentMaxOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<UnsortedSegmentMaxOptions>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<UnsortedSegmentMaxOptions> +CreateUnsortedSegmentMaxOptions(::flatbuffers::FlatBufferBuilder &_fbb) +{ + UnsortedSegmentMaxOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +struct UnsortedSegmentSumOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef UnsortedSegmentSumOptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && verifier.EndTable(); + } +}; + +struct UnsortedSegmentSumOptionsBuilder +{ + typedef UnsortedSegmentSumOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit UnsortedSegmentSumOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<UnsortedSegmentSumOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<UnsortedSegmentSumOptions>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<UnsortedSegmentSumOptions> +CreateUnsortedSegmentSumOptions(::flatbuffers::FlatBufferBuilder &_fbb) +{ + UnsortedSegmentSumOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +struct ATan2Options FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef ATan2OptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && verifier.EndTable(); + } +}; + +struct ATan2OptionsBuilder +{ + typedef ATan2Options Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit ATan2OptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<ATan2Options> Finish() + { + const auto end = fbb_.EndTable(start_); + auto 
o = ::flatbuffers::Offset<ATan2Options>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<ATan2Options> +CreateATan2Options(::flatbuffers::FlatBufferBuilder &_fbb) +{ + ATan2OptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +struct OperatorCode FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef OperatorCodeBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE + { + VT_DEPRECATED_BUILTIN_CODE = 4, + VT_CUSTOM_CODE = 6, + VT_VERSION = 8, + VT_BUILTIN_CODE = 10 + }; + int8_t deprecated_builtin_code() const { return GetField<int8_t>(VT_DEPRECATED_BUILTIN_CODE, 0); } + const ::flatbuffers::String *custom_code() const + { + return GetPointer<const ::flatbuffers::String *>(VT_CUSTOM_CODE); + } + int32_t version() const { return GetField<int32_t>(VT_VERSION, 1); } + onert_tflite::BuiltinOperator builtin_code() const + { + return static_cast<onert_tflite::BuiltinOperator>(GetField<int32_t>(VT_BUILTIN_CODE, 0)); + } + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && + VerifyField<int8_t>(verifier, VT_DEPRECATED_BUILTIN_CODE, 1) && + VerifyOffset(verifier, VT_CUSTOM_CODE) && verifier.VerifyString(custom_code()) && + VerifyField<int32_t>(verifier, VT_VERSION, 4) && + VerifyField<int32_t>(verifier, VT_BUILTIN_CODE, 4) && verifier.EndTable(); + } +}; + +struct OperatorCodeBuilder +{ + typedef OperatorCode Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_deprecated_builtin_code(int8_t deprecated_builtin_code) + { + fbb_.AddElement<int8_t>(OperatorCode::VT_DEPRECATED_BUILTIN_CODE, deprecated_builtin_code, 0); + } + void add_custom_code(::flatbuffers::Offset<::flatbuffers::String> custom_code) + { + fbb_.AddOffset(OperatorCode::VT_CUSTOM_CODE, custom_code); + } + void add_version(int32_t version) + { + fbb_.AddElement<int32_t>(OperatorCode::VT_VERSION, version, 1); + } + void add_builtin_code(onert_tflite::BuiltinOperator 
builtin_code) + { + fbb_.AddElement<int32_t>(OperatorCode::VT_BUILTIN_CODE, static_cast<int32_t>(builtin_code), 0); + } + explicit OperatorCodeBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<OperatorCode> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<OperatorCode>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<OperatorCode> +CreateOperatorCode(::flatbuffers::FlatBufferBuilder &_fbb, int8_t deprecated_builtin_code = 0, + ::flatbuffers::Offset<::flatbuffers::String> custom_code = 0, + int32_t version = 1, + onert_tflite::BuiltinOperator builtin_code = onert_tflite::BuiltinOperator_ADD) +{ + OperatorCodeBuilder builder_(_fbb); + builder_.add_builtin_code(builtin_code); + builder_.add_version(version); + builder_.add_custom_code(custom_code); + builder_.add_deprecated_builtin_code(deprecated_builtin_code); + return builder_.Finish(); +} + +inline ::flatbuffers::Offset<OperatorCode> CreateOperatorCodeDirect( + ::flatbuffers::FlatBufferBuilder &_fbb, int8_t deprecated_builtin_code = 0, + const char *custom_code = nullptr, int32_t version = 1, + onert_tflite::BuiltinOperator builtin_code = onert_tflite::BuiltinOperator_ADD) +{ + auto custom_code__ = custom_code ? 
_fbb.CreateString(custom_code) : 0; + return onert_tflite::CreateOperatorCode(_fbb, deprecated_builtin_code, custom_code__, version, + builtin_code); +} + +struct Operator FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef OperatorBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE + { + VT_OPCODE_INDEX = 4, + VT_INPUTS = 6, + VT_OUTPUTS = 8, + VT_BUILTIN_OPTIONS_TYPE = 10, + VT_BUILTIN_OPTIONS = 12, + VT_CUSTOM_OPTIONS = 14, + VT_CUSTOM_OPTIONS_FORMAT = 16, + VT_MUTATING_VARIABLE_INPUTS = 18, + VT_INTERMEDIATES = 20 + }; + uint32_t opcode_index() const { return GetField<uint32_t>(VT_OPCODE_INDEX, 0); } + const ::flatbuffers::Vector<int32_t> *inputs() const + { + return GetPointer<const ::flatbuffers::Vector<int32_t> *>(VT_INPUTS); + } + const ::flatbuffers::Vector<int32_t> *outputs() const + { + return GetPointer<const ::flatbuffers::Vector<int32_t> *>(VT_OUTPUTS); + } + onert_tflite::BuiltinOptions builtin_options_type() const + { + return static_cast<onert_tflite::BuiltinOptions>(GetField<uint8_t>(VT_BUILTIN_OPTIONS_TYPE, 0)); + } + const void *builtin_options() const { return GetPointer<const void *>(VT_BUILTIN_OPTIONS); } + template <typename T> const T *builtin_options_as() const; + const onert_tflite::Conv2DOptions *builtin_options_as_Conv2DOptions() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_Conv2DOptions + ? static_cast<const onert_tflite::Conv2DOptions *>(builtin_options()) + : nullptr; + } + const onert_tflite::DepthwiseConv2DOptions *builtin_options_as_DepthwiseConv2DOptions() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_DepthwiseConv2DOptions + ? static_cast<const onert_tflite::DepthwiseConv2DOptions *>(builtin_options()) + : nullptr; + } + const onert_tflite::ConcatEmbeddingsOptions *builtin_options_as_ConcatEmbeddingsOptions() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_ConcatEmbeddingsOptions + ? 
static_cast<const onert_tflite::ConcatEmbeddingsOptions *>(builtin_options()) + : nullptr; + } + const onert_tflite::LSHProjectionOptions *builtin_options_as_LSHProjectionOptions() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_LSHProjectionOptions + ? static_cast<const onert_tflite::LSHProjectionOptions *>(builtin_options()) + : nullptr; + } + const onert_tflite::Pool2DOptions *builtin_options_as_Pool2DOptions() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_Pool2DOptions + ? static_cast<const onert_tflite::Pool2DOptions *>(builtin_options()) + : nullptr; + } + const onert_tflite::SVDFOptions *builtin_options_as_SVDFOptions() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_SVDFOptions + ? static_cast<const onert_tflite::SVDFOptions *>(builtin_options()) + : nullptr; + } + const onert_tflite::RNNOptions *builtin_options_as_RNNOptions() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_RNNOptions + ? static_cast<const onert_tflite::RNNOptions *>(builtin_options()) + : nullptr; + } + const onert_tflite::FullyConnectedOptions *builtin_options_as_FullyConnectedOptions() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_FullyConnectedOptions + ? static_cast<const onert_tflite::FullyConnectedOptions *>(builtin_options()) + : nullptr; + } + const onert_tflite::SoftmaxOptions *builtin_options_as_SoftmaxOptions() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_SoftmaxOptions + ? static_cast<const onert_tflite::SoftmaxOptions *>(builtin_options()) + : nullptr; + } + const onert_tflite::ConcatenationOptions *builtin_options_as_ConcatenationOptions() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_ConcatenationOptions + ? 
static_cast<const onert_tflite::ConcatenationOptions *>(builtin_options()) + : nullptr; + } + const onert_tflite::AddOptions *builtin_options_as_AddOptions() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_AddOptions + ? static_cast<const onert_tflite::AddOptions *>(builtin_options()) + : nullptr; + } + const onert_tflite::L2NormOptions *builtin_options_as_L2NormOptions() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_L2NormOptions + ? static_cast<const onert_tflite::L2NormOptions *>(builtin_options()) + : nullptr; + } + const onert_tflite::LocalResponseNormalizationOptions * + builtin_options_as_LocalResponseNormalizationOptions() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_LocalResponseNormalizationOptions + ? static_cast<const onert_tflite::LocalResponseNormalizationOptions *>( + builtin_options()) + : nullptr; + } + const onert_tflite::LSTMOptions *builtin_options_as_LSTMOptions() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_LSTMOptions + ? static_cast<const onert_tflite::LSTMOptions *>(builtin_options()) + : nullptr; + } + const onert_tflite::ResizeBilinearOptions *builtin_options_as_ResizeBilinearOptions() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_ResizeBilinearOptions + ? static_cast<const onert_tflite::ResizeBilinearOptions *>(builtin_options()) + : nullptr; + } + const onert_tflite::CallOptions *builtin_options_as_CallOptions() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_CallOptions + ? static_cast<const onert_tflite::CallOptions *>(builtin_options()) + : nullptr; + } + const onert_tflite::ReshapeOptions *builtin_options_as_ReshapeOptions() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_ReshapeOptions + ? 
static_cast<const onert_tflite::ReshapeOptions *>(builtin_options()) + : nullptr; + } + const onert_tflite::SkipGramOptions *builtin_options_as_SkipGramOptions() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_SkipGramOptions + ? static_cast<const onert_tflite::SkipGramOptions *>(builtin_options()) + : nullptr; + } + const onert_tflite::SpaceToDepthOptions *builtin_options_as_SpaceToDepthOptions() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_SpaceToDepthOptions + ? static_cast<const onert_tflite::SpaceToDepthOptions *>(builtin_options()) + : nullptr; + } + const onert_tflite::EmbeddingLookupSparseOptions * + builtin_options_as_EmbeddingLookupSparseOptions() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_EmbeddingLookupSparseOptions + ? static_cast<const onert_tflite::EmbeddingLookupSparseOptions *>(builtin_options()) + : nullptr; + } + const onert_tflite::MulOptions *builtin_options_as_MulOptions() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_MulOptions + ? static_cast<const onert_tflite::MulOptions *>(builtin_options()) + : nullptr; + } + const onert_tflite::PadOptions *builtin_options_as_PadOptions() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_PadOptions + ? static_cast<const onert_tflite::PadOptions *>(builtin_options()) + : nullptr; + } + const onert_tflite::GatherOptions *builtin_options_as_GatherOptions() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_GatherOptions + ? static_cast<const onert_tflite::GatherOptions *>(builtin_options()) + : nullptr; + } + const onert_tflite::BatchToSpaceNDOptions *builtin_options_as_BatchToSpaceNDOptions() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_BatchToSpaceNDOptions + ? 
static_cast<const onert_tflite::BatchToSpaceNDOptions *>(builtin_options()) + : nullptr; + } + const onert_tflite::SpaceToBatchNDOptions *builtin_options_as_SpaceToBatchNDOptions() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_SpaceToBatchNDOptions + ? static_cast<const onert_tflite::SpaceToBatchNDOptions *>(builtin_options()) + : nullptr; + } + const onert_tflite::TransposeOptions *builtin_options_as_TransposeOptions() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_TransposeOptions + ? static_cast<const onert_tflite::TransposeOptions *>(builtin_options()) + : nullptr; + } + const onert_tflite::ReducerOptions *builtin_options_as_ReducerOptions() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_ReducerOptions + ? static_cast<const onert_tflite::ReducerOptions *>(builtin_options()) + : nullptr; + } + const onert_tflite::SubOptions *builtin_options_as_SubOptions() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_SubOptions + ? static_cast<const onert_tflite::SubOptions *>(builtin_options()) + : nullptr; + } + const onert_tflite::DivOptions *builtin_options_as_DivOptions() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_DivOptions + ? static_cast<const onert_tflite::DivOptions *>(builtin_options()) + : nullptr; + } + const onert_tflite::SqueezeOptions *builtin_options_as_SqueezeOptions() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_SqueezeOptions + ? static_cast<const onert_tflite::SqueezeOptions *>(builtin_options()) + : nullptr; + } + const onert_tflite::SequenceRNNOptions *builtin_options_as_SequenceRNNOptions() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_SequenceRNNOptions + ? 
static_cast<const onert_tflite::SequenceRNNOptions *>(builtin_options()) + : nullptr; + } + const onert_tflite::StridedSliceOptions *builtin_options_as_StridedSliceOptions() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_StridedSliceOptions + ? static_cast<const onert_tflite::StridedSliceOptions *>(builtin_options()) + : nullptr; + } + const onert_tflite::ExpOptions *builtin_options_as_ExpOptions() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_ExpOptions + ? static_cast<const onert_tflite::ExpOptions *>(builtin_options()) + : nullptr; + } + const onert_tflite::TopKV2Options *builtin_options_as_TopKV2Options() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_TopKV2Options + ? static_cast<const onert_tflite::TopKV2Options *>(builtin_options()) + : nullptr; + } + const onert_tflite::SplitOptions *builtin_options_as_SplitOptions() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_SplitOptions + ? static_cast<const onert_tflite::SplitOptions *>(builtin_options()) + : nullptr; + } + const onert_tflite::LogSoftmaxOptions *builtin_options_as_LogSoftmaxOptions() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_LogSoftmaxOptions + ? static_cast<const onert_tflite::LogSoftmaxOptions *>(builtin_options()) + : nullptr; + } + const onert_tflite::CastOptions *builtin_options_as_CastOptions() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_CastOptions + ? static_cast<const onert_tflite::CastOptions *>(builtin_options()) + : nullptr; + } + const onert_tflite::DequantizeOptions *builtin_options_as_DequantizeOptions() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_DequantizeOptions + ? 
static_cast<const onert_tflite::DequantizeOptions *>(builtin_options()) + : nullptr; + } + const onert_tflite::MaximumMinimumOptions *builtin_options_as_MaximumMinimumOptions() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_MaximumMinimumOptions + ? static_cast<const onert_tflite::MaximumMinimumOptions *>(builtin_options()) + : nullptr; + } + const onert_tflite::ArgMaxOptions *builtin_options_as_ArgMaxOptions() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_ArgMaxOptions + ? static_cast<const onert_tflite::ArgMaxOptions *>(builtin_options()) + : nullptr; + } + const onert_tflite::LessOptions *builtin_options_as_LessOptions() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_LessOptions + ? static_cast<const onert_tflite::LessOptions *>(builtin_options()) + : nullptr; + } + const onert_tflite::NegOptions *builtin_options_as_NegOptions() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_NegOptions + ? static_cast<const onert_tflite::NegOptions *>(builtin_options()) + : nullptr; + } + const onert_tflite::PadV2Options *builtin_options_as_PadV2Options() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_PadV2Options + ? static_cast<const onert_tflite::PadV2Options *>(builtin_options()) + : nullptr; + } + const onert_tflite::GreaterOptions *builtin_options_as_GreaterOptions() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_GreaterOptions + ? static_cast<const onert_tflite::GreaterOptions *>(builtin_options()) + : nullptr; + } + const onert_tflite::GreaterEqualOptions *builtin_options_as_GreaterEqualOptions() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_GreaterEqualOptions + ? 
static_cast<const onert_tflite::GreaterEqualOptions *>(builtin_options()) + : nullptr; + } + const onert_tflite::LessEqualOptions *builtin_options_as_LessEqualOptions() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_LessEqualOptions + ? static_cast<const onert_tflite::LessEqualOptions *>(builtin_options()) + : nullptr; + } + const onert_tflite::SelectOptions *builtin_options_as_SelectOptions() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_SelectOptions + ? static_cast<const onert_tflite::SelectOptions *>(builtin_options()) + : nullptr; + } + const onert_tflite::SliceOptions *builtin_options_as_SliceOptions() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_SliceOptions + ? static_cast<const onert_tflite::SliceOptions *>(builtin_options()) + : nullptr; + } + const onert_tflite::TransposeConvOptions *builtin_options_as_TransposeConvOptions() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_TransposeConvOptions + ? static_cast<const onert_tflite::TransposeConvOptions *>(builtin_options()) + : nullptr; + } + const onert_tflite::SparseToDenseOptions *builtin_options_as_SparseToDenseOptions() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_SparseToDenseOptions + ? static_cast<const onert_tflite::SparseToDenseOptions *>(builtin_options()) + : nullptr; + } + const onert_tflite::TileOptions *builtin_options_as_TileOptions() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_TileOptions + ? static_cast<const onert_tflite::TileOptions *>(builtin_options()) + : nullptr; + } + const onert_tflite::ExpandDimsOptions *builtin_options_as_ExpandDimsOptions() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_ExpandDimsOptions + ? 
static_cast<const onert_tflite::ExpandDimsOptions *>(builtin_options()) + : nullptr; + } + const onert_tflite::EqualOptions *builtin_options_as_EqualOptions() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_EqualOptions + ? static_cast<const onert_tflite::EqualOptions *>(builtin_options()) + : nullptr; + } + const onert_tflite::NotEqualOptions *builtin_options_as_NotEqualOptions() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_NotEqualOptions + ? static_cast<const onert_tflite::NotEqualOptions *>(builtin_options()) + : nullptr; + } + const onert_tflite::ShapeOptions *builtin_options_as_ShapeOptions() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_ShapeOptions + ? static_cast<const onert_tflite::ShapeOptions *>(builtin_options()) + : nullptr; + } + const onert_tflite::PowOptions *builtin_options_as_PowOptions() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_PowOptions + ? static_cast<const onert_tflite::PowOptions *>(builtin_options()) + : nullptr; + } + const onert_tflite::ArgMinOptions *builtin_options_as_ArgMinOptions() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_ArgMinOptions + ? static_cast<const onert_tflite::ArgMinOptions *>(builtin_options()) + : nullptr; + } + const onert_tflite::FakeQuantOptions *builtin_options_as_FakeQuantOptions() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_FakeQuantOptions + ? static_cast<const onert_tflite::FakeQuantOptions *>(builtin_options()) + : nullptr; + } + const onert_tflite::PackOptions *builtin_options_as_PackOptions() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_PackOptions + ? static_cast<const onert_tflite::PackOptions *>(builtin_options()) + : nullptr; + } + const onert_tflite::LogicalOrOptions *builtin_options_as_LogicalOrOptions() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_LogicalOrOptions + ? 
static_cast<const onert_tflite::LogicalOrOptions *>(builtin_options()) + : nullptr; + } + const onert_tflite::OneHotOptions *builtin_options_as_OneHotOptions() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_OneHotOptions + ? static_cast<const onert_tflite::OneHotOptions *>(builtin_options()) + : nullptr; + } + const onert_tflite::LogicalAndOptions *builtin_options_as_LogicalAndOptions() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_LogicalAndOptions + ? static_cast<const onert_tflite::LogicalAndOptions *>(builtin_options()) + : nullptr; + } + const onert_tflite::LogicalNotOptions *builtin_options_as_LogicalNotOptions() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_LogicalNotOptions + ? static_cast<const onert_tflite::LogicalNotOptions *>(builtin_options()) + : nullptr; + } + const onert_tflite::UnpackOptions *builtin_options_as_UnpackOptions() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_UnpackOptions + ? static_cast<const onert_tflite::UnpackOptions *>(builtin_options()) + : nullptr; + } + const onert_tflite::FloorDivOptions *builtin_options_as_FloorDivOptions() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_FloorDivOptions + ? static_cast<const onert_tflite::FloorDivOptions *>(builtin_options()) + : nullptr; + } + const onert_tflite::SquareOptions *builtin_options_as_SquareOptions() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_SquareOptions + ? static_cast<const onert_tflite::SquareOptions *>(builtin_options()) + : nullptr; + } + const onert_tflite::ZerosLikeOptions *builtin_options_as_ZerosLikeOptions() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_ZerosLikeOptions + ? 
static_cast<const onert_tflite::ZerosLikeOptions *>(builtin_options()) + : nullptr; + } + const onert_tflite::FillOptions *builtin_options_as_FillOptions() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_FillOptions + ? static_cast<const onert_tflite::FillOptions *>(builtin_options()) + : nullptr; + } + const onert_tflite::BidirectionalSequenceLSTMOptions * + builtin_options_as_BidirectionalSequenceLSTMOptions() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_BidirectionalSequenceLSTMOptions + ? static_cast<const onert_tflite::BidirectionalSequenceLSTMOptions *>( + builtin_options()) + : nullptr; + } + const onert_tflite::BidirectionalSequenceRNNOptions * + builtin_options_as_BidirectionalSequenceRNNOptions() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_BidirectionalSequenceRNNOptions + ? static_cast<const onert_tflite::BidirectionalSequenceRNNOptions *>(builtin_options()) + : nullptr; + } + const onert_tflite::UnidirectionalSequenceLSTMOptions * + builtin_options_as_UnidirectionalSequenceLSTMOptions() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_UnidirectionalSequenceLSTMOptions + ? static_cast<const onert_tflite::UnidirectionalSequenceLSTMOptions *>( + builtin_options()) + : nullptr; + } + const onert_tflite::FloorModOptions *builtin_options_as_FloorModOptions() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_FloorModOptions + ? static_cast<const onert_tflite::FloorModOptions *>(builtin_options()) + : nullptr; + } + const onert_tflite::RangeOptions *builtin_options_as_RangeOptions() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_RangeOptions + ? 
static_cast<const onert_tflite::RangeOptions *>(builtin_options()) + : nullptr; + } + const onert_tflite::ResizeNearestNeighborOptions * + builtin_options_as_ResizeNearestNeighborOptions() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_ResizeNearestNeighborOptions + ? static_cast<const onert_tflite::ResizeNearestNeighborOptions *>(builtin_options()) + : nullptr; + } + const onert_tflite::LeakyReluOptions *builtin_options_as_LeakyReluOptions() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_LeakyReluOptions + ? static_cast<const onert_tflite::LeakyReluOptions *>(builtin_options()) + : nullptr; + } + const onert_tflite::SquaredDifferenceOptions *builtin_options_as_SquaredDifferenceOptions() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_SquaredDifferenceOptions + ? static_cast<const onert_tflite::SquaredDifferenceOptions *>(builtin_options()) + : nullptr; + } + const onert_tflite::MirrorPadOptions *builtin_options_as_MirrorPadOptions() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_MirrorPadOptions + ? static_cast<const onert_tflite::MirrorPadOptions *>(builtin_options()) + : nullptr; + } + const onert_tflite::AbsOptions *builtin_options_as_AbsOptions() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_AbsOptions + ? static_cast<const onert_tflite::AbsOptions *>(builtin_options()) + : nullptr; + } + const onert_tflite::SplitVOptions *builtin_options_as_SplitVOptions() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_SplitVOptions + ? static_cast<const onert_tflite::SplitVOptions *>(builtin_options()) + : nullptr; + } + const onert_tflite::UniqueOptions *builtin_options_as_UniqueOptions() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_UniqueOptions + ? 
static_cast<const onert_tflite::UniqueOptions *>(builtin_options()) + : nullptr; + } + const onert_tflite::ReverseV2Options *builtin_options_as_ReverseV2Options() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_ReverseV2Options + ? static_cast<const onert_tflite::ReverseV2Options *>(builtin_options()) + : nullptr; + } + const onert_tflite::AddNOptions *builtin_options_as_AddNOptions() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_AddNOptions + ? static_cast<const onert_tflite::AddNOptions *>(builtin_options()) + : nullptr; + } + const onert_tflite::GatherNdOptions *builtin_options_as_GatherNdOptions() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_GatherNdOptions + ? static_cast<const onert_tflite::GatherNdOptions *>(builtin_options()) + : nullptr; + } + const onert_tflite::CosOptions *builtin_options_as_CosOptions() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_CosOptions + ? static_cast<const onert_tflite::CosOptions *>(builtin_options()) + : nullptr; + } + const onert_tflite::WhereOptions *builtin_options_as_WhereOptions() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_WhereOptions + ? static_cast<const onert_tflite::WhereOptions *>(builtin_options()) + : nullptr; + } + const onert_tflite::RankOptions *builtin_options_as_RankOptions() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_RankOptions + ? static_cast<const onert_tflite::RankOptions *>(builtin_options()) + : nullptr; + } + const onert_tflite::ReverseSequenceOptions *builtin_options_as_ReverseSequenceOptions() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_ReverseSequenceOptions + ? 
static_cast<const onert_tflite::ReverseSequenceOptions *>(builtin_options()) + : nullptr; + } + const onert_tflite::MatrixDiagOptions *builtin_options_as_MatrixDiagOptions() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_MatrixDiagOptions + ? static_cast<const onert_tflite::MatrixDiagOptions *>(builtin_options()) + : nullptr; + } + const onert_tflite::QuantizeOptions *builtin_options_as_QuantizeOptions() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_QuantizeOptions + ? static_cast<const onert_tflite::QuantizeOptions *>(builtin_options()) + : nullptr; + } + const onert_tflite::MatrixSetDiagOptions *builtin_options_as_MatrixSetDiagOptions() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_MatrixSetDiagOptions + ? static_cast<const onert_tflite::MatrixSetDiagOptions *>(builtin_options()) + : nullptr; + } + const onert_tflite::HardSwishOptions *builtin_options_as_HardSwishOptions() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_HardSwishOptions + ? static_cast<const onert_tflite::HardSwishOptions *>(builtin_options()) + : nullptr; + } + const onert_tflite::IfOptions *builtin_options_as_IfOptions() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_IfOptions + ? static_cast<const onert_tflite::IfOptions *>(builtin_options()) + : nullptr; + } + const onert_tflite::WhileOptions *builtin_options_as_WhileOptions() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_WhileOptions + ? static_cast<const onert_tflite::WhileOptions *>(builtin_options()) + : nullptr; + } + const onert_tflite::DepthToSpaceOptions *builtin_options_as_DepthToSpaceOptions() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_DepthToSpaceOptions + ? 
static_cast<const onert_tflite::DepthToSpaceOptions *>(builtin_options()) + : nullptr; + } + const onert_tflite::NonMaxSuppressionV4Options * + builtin_options_as_NonMaxSuppressionV4Options() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_NonMaxSuppressionV4Options + ? static_cast<const onert_tflite::NonMaxSuppressionV4Options *>(builtin_options()) + : nullptr; + } + const onert_tflite::NonMaxSuppressionV5Options * + builtin_options_as_NonMaxSuppressionV5Options() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_NonMaxSuppressionV5Options + ? static_cast<const onert_tflite::NonMaxSuppressionV5Options *>(builtin_options()) + : nullptr; + } + const onert_tflite::ScatterNdOptions *builtin_options_as_ScatterNdOptions() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_ScatterNdOptions + ? static_cast<const onert_tflite::ScatterNdOptions *>(builtin_options()) + : nullptr; + } + const onert_tflite::SelectV2Options *builtin_options_as_SelectV2Options() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_SelectV2Options + ? static_cast<const onert_tflite::SelectV2Options *>(builtin_options()) + : nullptr; + } + const onert_tflite::DensifyOptions *builtin_options_as_DensifyOptions() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_DensifyOptions + ? static_cast<const onert_tflite::DensifyOptions *>(builtin_options()) + : nullptr; + } + const onert_tflite::SegmentSumOptions *builtin_options_as_SegmentSumOptions() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_SegmentSumOptions + ? static_cast<const onert_tflite::SegmentSumOptions *>(builtin_options()) + : nullptr; + } + const onert_tflite::BatchMatMulOptions *builtin_options_as_BatchMatMulOptions() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_BatchMatMulOptions + ? 
static_cast<const onert_tflite::BatchMatMulOptions *>(builtin_options()) + : nullptr; + } + const onert_tflite::CumsumOptions *builtin_options_as_CumsumOptions() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_CumsumOptions + ? static_cast<const onert_tflite::CumsumOptions *>(builtin_options()) + : nullptr; + } + const onert_tflite::CallOnceOptions *builtin_options_as_CallOnceOptions() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_CallOnceOptions + ? static_cast<const onert_tflite::CallOnceOptions *>(builtin_options()) + : nullptr; + } + const onert_tflite::BroadcastToOptions *builtin_options_as_BroadcastToOptions() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_BroadcastToOptions + ? static_cast<const onert_tflite::BroadcastToOptions *>(builtin_options()) + : nullptr; + } + const onert_tflite::Rfft2dOptions *builtin_options_as_Rfft2dOptions() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_Rfft2dOptions + ? static_cast<const onert_tflite::Rfft2dOptions *>(builtin_options()) + : nullptr; + } + const onert_tflite::Conv3DOptions *builtin_options_as_Conv3DOptions() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_Conv3DOptions + ? static_cast<const onert_tflite::Conv3DOptions *>(builtin_options()) + : nullptr; + } + const onert_tflite::HashtableOptions *builtin_options_as_HashtableOptions() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_HashtableOptions + ? static_cast<const onert_tflite::HashtableOptions *>(builtin_options()) + : nullptr; + } + const onert_tflite::HashtableFindOptions *builtin_options_as_HashtableFindOptions() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_HashtableFindOptions + ? 
static_cast<const onert_tflite::HashtableFindOptions *>(builtin_options()) + : nullptr; + } + const onert_tflite::HashtableImportOptions *builtin_options_as_HashtableImportOptions() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_HashtableImportOptions + ? static_cast<const onert_tflite::HashtableImportOptions *>(builtin_options()) + : nullptr; + } + const onert_tflite::HashtableSizeOptions *builtin_options_as_HashtableSizeOptions() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_HashtableSizeOptions + ? static_cast<const onert_tflite::HashtableSizeOptions *>(builtin_options()) + : nullptr; + } + const onert_tflite::VarHandleOptions *builtin_options_as_VarHandleOptions() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_VarHandleOptions + ? static_cast<const onert_tflite::VarHandleOptions *>(builtin_options()) + : nullptr; + } + const onert_tflite::ReadVariableOptions *builtin_options_as_ReadVariableOptions() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_ReadVariableOptions + ? static_cast<const onert_tflite::ReadVariableOptions *>(builtin_options()) + : nullptr; + } + const onert_tflite::AssignVariableOptions *builtin_options_as_AssignVariableOptions() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_AssignVariableOptions + ? static_cast<const onert_tflite::AssignVariableOptions *>(builtin_options()) + : nullptr; + } + const onert_tflite::RandomOptions *builtin_options_as_RandomOptions() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_RandomOptions + ? static_cast<const onert_tflite::RandomOptions *>(builtin_options()) + : nullptr; + } + const onert_tflite::BucketizeOptions *builtin_options_as_BucketizeOptions() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_BucketizeOptions + ? 
static_cast<const onert_tflite::BucketizeOptions *>(builtin_options()) + : nullptr; + } + const onert_tflite::GeluOptions *builtin_options_as_GeluOptions() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_GeluOptions + ? static_cast<const onert_tflite::GeluOptions *>(builtin_options()) + : nullptr; + } + const onert_tflite::DynamicUpdateSliceOptions * + builtin_options_as_DynamicUpdateSliceOptions() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_DynamicUpdateSliceOptions + ? static_cast<const onert_tflite::DynamicUpdateSliceOptions *>(builtin_options()) + : nullptr; + } + const onert_tflite::UnsortedSegmentProdOptions * + builtin_options_as_UnsortedSegmentProdOptions() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_UnsortedSegmentProdOptions + ? static_cast<const onert_tflite::UnsortedSegmentProdOptions *>(builtin_options()) + : nullptr; + } + const onert_tflite::UnsortedSegmentMaxOptions * + builtin_options_as_UnsortedSegmentMaxOptions() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_UnsortedSegmentMaxOptions + ? static_cast<const onert_tflite::UnsortedSegmentMaxOptions *>(builtin_options()) + : nullptr; + } + const onert_tflite::UnsortedSegmentSumOptions * + builtin_options_as_UnsortedSegmentSumOptions() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_UnsortedSegmentSumOptions + ? static_cast<const onert_tflite::UnsortedSegmentSumOptions *>(builtin_options()) + : nullptr; + } + const onert_tflite::ATan2Options *builtin_options_as_ATan2Options() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_ATan2Options + ? 
static_cast<const onert_tflite::ATan2Options *>(builtin_options()) + : nullptr; + } + const ::flatbuffers::Vector<uint8_t> *custom_options() const + { + return GetPointer<const ::flatbuffers::Vector<uint8_t> *>(VT_CUSTOM_OPTIONS); + } + onert_tflite::CustomOptionsFormat custom_options_format() const + { + return static_cast<onert_tflite::CustomOptionsFormat>( + GetField<int8_t>(VT_CUSTOM_OPTIONS_FORMAT, 0)); + } + const ::flatbuffers::Vector<uint8_t> *mutating_variable_inputs() const + { + return GetPointer<const ::flatbuffers::Vector<uint8_t> *>(VT_MUTATING_VARIABLE_INPUTS); + } + const ::flatbuffers::Vector<int32_t> *intermediates() const + { + return GetPointer<const ::flatbuffers::Vector<int32_t> *>(VT_INTERMEDIATES); + } + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && VerifyField<uint32_t>(verifier, VT_OPCODE_INDEX, 4) && + VerifyOffset(verifier, VT_INPUTS) && verifier.VerifyVector(inputs()) && + VerifyOffset(verifier, VT_OUTPUTS) && verifier.VerifyVector(outputs()) && + VerifyField<uint8_t>(verifier, VT_BUILTIN_OPTIONS_TYPE, 1) && + VerifyOffset(verifier, VT_BUILTIN_OPTIONS) && + VerifyBuiltinOptions(verifier, builtin_options(), builtin_options_type()) && + VerifyOffset(verifier, VT_CUSTOM_OPTIONS) && verifier.VerifyVector(custom_options()) && + VerifyField<int8_t>(verifier, VT_CUSTOM_OPTIONS_FORMAT, 1) && + VerifyOffset(verifier, VT_MUTATING_VARIABLE_INPUTS) && + verifier.VerifyVector(mutating_variable_inputs()) && + VerifyOffset(verifier, VT_INTERMEDIATES) && verifier.VerifyVector(intermediates()) && + verifier.EndTable(); + } +}; + +template <> +inline const onert_tflite::Conv2DOptions * +Operator::builtin_options_as<onert_tflite::Conv2DOptions>() const +{ + return builtin_options_as_Conv2DOptions(); +} + +template <> +inline const onert_tflite::DepthwiseConv2DOptions * +Operator::builtin_options_as<onert_tflite::DepthwiseConv2DOptions>() const +{ + return builtin_options_as_DepthwiseConv2DOptions(); +} + 
+template <> +inline const onert_tflite::ConcatEmbeddingsOptions * +Operator::builtin_options_as<onert_tflite::ConcatEmbeddingsOptions>() const +{ + return builtin_options_as_ConcatEmbeddingsOptions(); +} + +template <> +inline const onert_tflite::LSHProjectionOptions * +Operator::builtin_options_as<onert_tflite::LSHProjectionOptions>() const +{ + return builtin_options_as_LSHProjectionOptions(); +} + +template <> +inline const onert_tflite::Pool2DOptions * +Operator::builtin_options_as<onert_tflite::Pool2DOptions>() const +{ + return builtin_options_as_Pool2DOptions(); +} + +template <> +inline const onert_tflite::SVDFOptions * +Operator::builtin_options_as<onert_tflite::SVDFOptions>() const +{ + return builtin_options_as_SVDFOptions(); +} + +template <> +inline const onert_tflite::RNNOptions * +Operator::builtin_options_as<onert_tflite::RNNOptions>() const +{ + return builtin_options_as_RNNOptions(); +} + +template <> +inline const onert_tflite::FullyConnectedOptions * +Operator::builtin_options_as<onert_tflite::FullyConnectedOptions>() const +{ + return builtin_options_as_FullyConnectedOptions(); +} + +template <> +inline const onert_tflite::SoftmaxOptions * +Operator::builtin_options_as<onert_tflite::SoftmaxOptions>() const +{ + return builtin_options_as_SoftmaxOptions(); +} + +template <> +inline const onert_tflite::ConcatenationOptions * +Operator::builtin_options_as<onert_tflite::ConcatenationOptions>() const +{ + return builtin_options_as_ConcatenationOptions(); +} + +template <> +inline const onert_tflite::AddOptions * +Operator::builtin_options_as<onert_tflite::AddOptions>() const +{ + return builtin_options_as_AddOptions(); +} + +template <> +inline const onert_tflite::L2NormOptions * +Operator::builtin_options_as<onert_tflite::L2NormOptions>() const +{ + return builtin_options_as_L2NormOptions(); +} + +template <> +inline const onert_tflite::LocalResponseNormalizationOptions * 
+Operator::builtin_options_as<onert_tflite::LocalResponseNormalizationOptions>() const +{ + return builtin_options_as_LocalResponseNormalizationOptions(); +} + +template <> +inline const onert_tflite::LSTMOptions * +Operator::builtin_options_as<onert_tflite::LSTMOptions>() const +{ + return builtin_options_as_LSTMOptions(); +} + +template <> +inline const onert_tflite::ResizeBilinearOptions * +Operator::builtin_options_as<onert_tflite::ResizeBilinearOptions>() const +{ + return builtin_options_as_ResizeBilinearOptions(); +} + +template <> +inline const onert_tflite::CallOptions * +Operator::builtin_options_as<onert_tflite::CallOptions>() const +{ + return builtin_options_as_CallOptions(); +} + +template <> +inline const onert_tflite::ReshapeOptions * +Operator::builtin_options_as<onert_tflite::ReshapeOptions>() const +{ + return builtin_options_as_ReshapeOptions(); +} + +template <> +inline const onert_tflite::SkipGramOptions * +Operator::builtin_options_as<onert_tflite::SkipGramOptions>() const +{ + return builtin_options_as_SkipGramOptions(); +} + +template <> +inline const onert_tflite::SpaceToDepthOptions * +Operator::builtin_options_as<onert_tflite::SpaceToDepthOptions>() const +{ + return builtin_options_as_SpaceToDepthOptions(); +} + +template <> +inline const onert_tflite::EmbeddingLookupSparseOptions * +Operator::builtin_options_as<onert_tflite::EmbeddingLookupSparseOptions>() const +{ + return builtin_options_as_EmbeddingLookupSparseOptions(); +} + +template <> +inline const onert_tflite::MulOptions * +Operator::builtin_options_as<onert_tflite::MulOptions>() const +{ + return builtin_options_as_MulOptions(); +} + +template <> +inline const onert_tflite::PadOptions * +Operator::builtin_options_as<onert_tflite::PadOptions>() const +{ + return builtin_options_as_PadOptions(); +} + +template <> +inline const onert_tflite::GatherOptions * +Operator::builtin_options_as<onert_tflite::GatherOptions>() const +{ + return builtin_options_as_GatherOptions(); +} + 
+template <> +inline const onert_tflite::BatchToSpaceNDOptions * +Operator::builtin_options_as<onert_tflite::BatchToSpaceNDOptions>() const +{ + return builtin_options_as_BatchToSpaceNDOptions(); +} + +template <> +inline const onert_tflite::SpaceToBatchNDOptions * +Operator::builtin_options_as<onert_tflite::SpaceToBatchNDOptions>() const +{ + return builtin_options_as_SpaceToBatchNDOptions(); +} + +template <> +inline const onert_tflite::TransposeOptions * +Operator::builtin_options_as<onert_tflite::TransposeOptions>() const +{ + return builtin_options_as_TransposeOptions(); +} + +template <> +inline const onert_tflite::ReducerOptions * +Operator::builtin_options_as<onert_tflite::ReducerOptions>() const +{ + return builtin_options_as_ReducerOptions(); +} + +template <> +inline const onert_tflite::SubOptions * +Operator::builtin_options_as<onert_tflite::SubOptions>() const +{ + return builtin_options_as_SubOptions(); +} + +template <> +inline const onert_tflite::DivOptions * +Operator::builtin_options_as<onert_tflite::DivOptions>() const +{ + return builtin_options_as_DivOptions(); +} + +template <> +inline const onert_tflite::SqueezeOptions * +Operator::builtin_options_as<onert_tflite::SqueezeOptions>() const +{ + return builtin_options_as_SqueezeOptions(); +} + +template <> +inline const onert_tflite::SequenceRNNOptions * +Operator::builtin_options_as<onert_tflite::SequenceRNNOptions>() const +{ + return builtin_options_as_SequenceRNNOptions(); +} + +template <> +inline const onert_tflite::StridedSliceOptions * +Operator::builtin_options_as<onert_tflite::StridedSliceOptions>() const +{ + return builtin_options_as_StridedSliceOptions(); +} + +template <> +inline const onert_tflite::ExpOptions * +Operator::builtin_options_as<onert_tflite::ExpOptions>() const +{ + return builtin_options_as_ExpOptions(); +} + +template <> +inline const onert_tflite::TopKV2Options * +Operator::builtin_options_as<onert_tflite::TopKV2Options>() const +{ + return 
builtin_options_as_TopKV2Options(); +} + +template <> +inline const onert_tflite::SplitOptions * +Operator::builtin_options_as<onert_tflite::SplitOptions>() const +{ + return builtin_options_as_SplitOptions(); +} + +template <> +inline const onert_tflite::LogSoftmaxOptions * +Operator::builtin_options_as<onert_tflite::LogSoftmaxOptions>() const +{ + return builtin_options_as_LogSoftmaxOptions(); +} + +template <> +inline const onert_tflite::CastOptions * +Operator::builtin_options_as<onert_tflite::CastOptions>() const +{ + return builtin_options_as_CastOptions(); +} + +template <> +inline const onert_tflite::DequantizeOptions * +Operator::builtin_options_as<onert_tflite::DequantizeOptions>() const +{ + return builtin_options_as_DequantizeOptions(); +} + +template <> +inline const onert_tflite::MaximumMinimumOptions * +Operator::builtin_options_as<onert_tflite::MaximumMinimumOptions>() const +{ + return builtin_options_as_MaximumMinimumOptions(); +} + +template <> +inline const onert_tflite::ArgMaxOptions * +Operator::builtin_options_as<onert_tflite::ArgMaxOptions>() const +{ + return builtin_options_as_ArgMaxOptions(); +} + +template <> +inline const onert_tflite::LessOptions * +Operator::builtin_options_as<onert_tflite::LessOptions>() const +{ + return builtin_options_as_LessOptions(); +} + +template <> +inline const onert_tflite::NegOptions * +Operator::builtin_options_as<onert_tflite::NegOptions>() const +{ + return builtin_options_as_NegOptions(); +} + +template <> +inline const onert_tflite::PadV2Options * +Operator::builtin_options_as<onert_tflite::PadV2Options>() const +{ + return builtin_options_as_PadV2Options(); +} + +template <> +inline const onert_tflite::GreaterOptions * +Operator::builtin_options_as<onert_tflite::GreaterOptions>() const +{ + return builtin_options_as_GreaterOptions(); +} + +template <> +inline const onert_tflite::GreaterEqualOptions * +Operator::builtin_options_as<onert_tflite::GreaterEqualOptions>() const +{ + return 
builtin_options_as_GreaterEqualOptions(); +} + +template <> +inline const onert_tflite::LessEqualOptions * +Operator::builtin_options_as<onert_tflite::LessEqualOptions>() const +{ + return builtin_options_as_LessEqualOptions(); +} + +template <> +inline const onert_tflite::SelectOptions * +Operator::builtin_options_as<onert_tflite::SelectOptions>() const +{ + return builtin_options_as_SelectOptions(); +} + +template <> +inline const onert_tflite::SliceOptions * +Operator::builtin_options_as<onert_tflite::SliceOptions>() const +{ + return builtin_options_as_SliceOptions(); +} + +template <> +inline const onert_tflite::TransposeConvOptions * +Operator::builtin_options_as<onert_tflite::TransposeConvOptions>() const +{ + return builtin_options_as_TransposeConvOptions(); +} + +template <> +inline const onert_tflite::SparseToDenseOptions * +Operator::builtin_options_as<onert_tflite::SparseToDenseOptions>() const +{ + return builtin_options_as_SparseToDenseOptions(); +} + +template <> +inline const onert_tflite::TileOptions * +Operator::builtin_options_as<onert_tflite::TileOptions>() const +{ + return builtin_options_as_TileOptions(); +} + +template <> +inline const onert_tflite::ExpandDimsOptions * +Operator::builtin_options_as<onert_tflite::ExpandDimsOptions>() const +{ + return builtin_options_as_ExpandDimsOptions(); +} + +template <> +inline const onert_tflite::EqualOptions * +Operator::builtin_options_as<onert_tflite::EqualOptions>() const +{ + return builtin_options_as_EqualOptions(); +} + +template <> +inline const onert_tflite::NotEqualOptions * +Operator::builtin_options_as<onert_tflite::NotEqualOptions>() const +{ + return builtin_options_as_NotEqualOptions(); +} + +template <> +inline const onert_tflite::ShapeOptions * +Operator::builtin_options_as<onert_tflite::ShapeOptions>() const +{ + return builtin_options_as_ShapeOptions(); +} + +template <> +inline const onert_tflite::PowOptions * +Operator::builtin_options_as<onert_tflite::PowOptions>() const +{ + 
return builtin_options_as_PowOptions(); +} + +template <> +inline const onert_tflite::ArgMinOptions * +Operator::builtin_options_as<onert_tflite::ArgMinOptions>() const +{ + return builtin_options_as_ArgMinOptions(); +} + +template <> +inline const onert_tflite::FakeQuantOptions * +Operator::builtin_options_as<onert_tflite::FakeQuantOptions>() const +{ + return builtin_options_as_FakeQuantOptions(); +} + +template <> +inline const onert_tflite::PackOptions * +Operator::builtin_options_as<onert_tflite::PackOptions>() const +{ + return builtin_options_as_PackOptions(); +} + +template <> +inline const onert_tflite::LogicalOrOptions * +Operator::builtin_options_as<onert_tflite::LogicalOrOptions>() const +{ + return builtin_options_as_LogicalOrOptions(); +} + +template <> +inline const onert_tflite::OneHotOptions * +Operator::builtin_options_as<onert_tflite::OneHotOptions>() const +{ + return builtin_options_as_OneHotOptions(); +} + +template <> +inline const onert_tflite::LogicalAndOptions * +Operator::builtin_options_as<onert_tflite::LogicalAndOptions>() const +{ + return builtin_options_as_LogicalAndOptions(); +} + +template <> +inline const onert_tflite::LogicalNotOptions * +Operator::builtin_options_as<onert_tflite::LogicalNotOptions>() const +{ + return builtin_options_as_LogicalNotOptions(); +} + +template <> +inline const onert_tflite::UnpackOptions * +Operator::builtin_options_as<onert_tflite::UnpackOptions>() const +{ + return builtin_options_as_UnpackOptions(); +} + +template <> +inline const onert_tflite::FloorDivOptions * +Operator::builtin_options_as<onert_tflite::FloorDivOptions>() const +{ + return builtin_options_as_FloorDivOptions(); +} + +template <> +inline const onert_tflite::SquareOptions * +Operator::builtin_options_as<onert_tflite::SquareOptions>() const +{ + return builtin_options_as_SquareOptions(); +} + +template <> +inline const onert_tflite::ZerosLikeOptions * +Operator::builtin_options_as<onert_tflite::ZerosLikeOptions>() const +{ + return 
builtin_options_as_ZerosLikeOptions(); +} + +template <> +inline const onert_tflite::FillOptions * +Operator::builtin_options_as<onert_tflite::FillOptions>() const +{ + return builtin_options_as_FillOptions(); +} + +template <> +inline const onert_tflite::BidirectionalSequenceLSTMOptions * +Operator::builtin_options_as<onert_tflite::BidirectionalSequenceLSTMOptions>() const +{ + return builtin_options_as_BidirectionalSequenceLSTMOptions(); +} + +template <> +inline const onert_tflite::BidirectionalSequenceRNNOptions * +Operator::builtin_options_as<onert_tflite::BidirectionalSequenceRNNOptions>() const +{ + return builtin_options_as_BidirectionalSequenceRNNOptions(); +} + +template <> +inline const onert_tflite::UnidirectionalSequenceLSTMOptions * +Operator::builtin_options_as<onert_tflite::UnidirectionalSequenceLSTMOptions>() const +{ + return builtin_options_as_UnidirectionalSequenceLSTMOptions(); +} + +template <> +inline const onert_tflite::FloorModOptions * +Operator::builtin_options_as<onert_tflite::FloorModOptions>() const +{ + return builtin_options_as_FloorModOptions(); +} + +template <> +inline const onert_tflite::RangeOptions * +Operator::builtin_options_as<onert_tflite::RangeOptions>() const +{ + return builtin_options_as_RangeOptions(); +} + +template <> +inline const onert_tflite::ResizeNearestNeighborOptions * +Operator::builtin_options_as<onert_tflite::ResizeNearestNeighborOptions>() const +{ + return builtin_options_as_ResizeNearestNeighborOptions(); +} + +template <> +inline const onert_tflite::LeakyReluOptions * +Operator::builtin_options_as<onert_tflite::LeakyReluOptions>() const +{ + return builtin_options_as_LeakyReluOptions(); +} + +template <> +inline const onert_tflite::SquaredDifferenceOptions * +Operator::builtin_options_as<onert_tflite::SquaredDifferenceOptions>() const +{ + return builtin_options_as_SquaredDifferenceOptions(); +} + +template <> +inline const onert_tflite::MirrorPadOptions * 
+Operator::builtin_options_as<onert_tflite::MirrorPadOptions>() const +{ + return builtin_options_as_MirrorPadOptions(); +} + +template <> +inline const onert_tflite::AbsOptions * +Operator::builtin_options_as<onert_tflite::AbsOptions>() const +{ + return builtin_options_as_AbsOptions(); +} + +template <> +inline const onert_tflite::SplitVOptions * +Operator::builtin_options_as<onert_tflite::SplitVOptions>() const +{ + return builtin_options_as_SplitVOptions(); +} + +template <> +inline const onert_tflite::UniqueOptions * +Operator::builtin_options_as<onert_tflite::UniqueOptions>() const +{ + return builtin_options_as_UniqueOptions(); +} + +template <> +inline const onert_tflite::ReverseV2Options * +Operator::builtin_options_as<onert_tflite::ReverseV2Options>() const +{ + return builtin_options_as_ReverseV2Options(); +} + +template <> +inline const onert_tflite::AddNOptions * +Operator::builtin_options_as<onert_tflite::AddNOptions>() const +{ + return builtin_options_as_AddNOptions(); +} + +template <> +inline const onert_tflite::GatherNdOptions * +Operator::builtin_options_as<onert_tflite::GatherNdOptions>() const +{ + return builtin_options_as_GatherNdOptions(); +} + +template <> +inline const onert_tflite::CosOptions * +Operator::builtin_options_as<onert_tflite::CosOptions>() const +{ + return builtin_options_as_CosOptions(); +} + +template <> +inline const onert_tflite::WhereOptions * +Operator::builtin_options_as<onert_tflite::WhereOptions>() const +{ + return builtin_options_as_WhereOptions(); +} + +template <> +inline const onert_tflite::RankOptions * +Operator::builtin_options_as<onert_tflite::RankOptions>() const +{ + return builtin_options_as_RankOptions(); +} + +template <> +inline const onert_tflite::ReverseSequenceOptions * +Operator::builtin_options_as<onert_tflite::ReverseSequenceOptions>() const +{ + return builtin_options_as_ReverseSequenceOptions(); +} + +template <> +inline const onert_tflite::MatrixDiagOptions * 
+Operator::builtin_options_as<onert_tflite::MatrixDiagOptions>() const +{ + return builtin_options_as_MatrixDiagOptions(); +} + +template <> +inline const onert_tflite::QuantizeOptions * +Operator::builtin_options_as<onert_tflite::QuantizeOptions>() const +{ + return builtin_options_as_QuantizeOptions(); +} + +template <> +inline const onert_tflite::MatrixSetDiagOptions * +Operator::builtin_options_as<onert_tflite::MatrixSetDiagOptions>() const +{ + return builtin_options_as_MatrixSetDiagOptions(); +} + +template <> +inline const onert_tflite::HardSwishOptions * +Operator::builtin_options_as<onert_tflite::HardSwishOptions>() const +{ + return builtin_options_as_HardSwishOptions(); +} + +template <> +inline const onert_tflite::IfOptions *Operator::builtin_options_as<onert_tflite::IfOptions>() const +{ + return builtin_options_as_IfOptions(); +} + +template <> +inline const onert_tflite::WhileOptions * +Operator::builtin_options_as<onert_tflite::WhileOptions>() const +{ + return builtin_options_as_WhileOptions(); +} + +template <> +inline const onert_tflite::DepthToSpaceOptions * +Operator::builtin_options_as<onert_tflite::DepthToSpaceOptions>() const +{ + return builtin_options_as_DepthToSpaceOptions(); +} + +template <> +inline const onert_tflite::NonMaxSuppressionV4Options * +Operator::builtin_options_as<onert_tflite::NonMaxSuppressionV4Options>() const +{ + return builtin_options_as_NonMaxSuppressionV4Options(); +} + +template <> +inline const onert_tflite::NonMaxSuppressionV5Options * +Operator::builtin_options_as<onert_tflite::NonMaxSuppressionV5Options>() const +{ + return builtin_options_as_NonMaxSuppressionV5Options(); +} + +template <> +inline const onert_tflite::ScatterNdOptions * +Operator::builtin_options_as<onert_tflite::ScatterNdOptions>() const +{ + return builtin_options_as_ScatterNdOptions(); +} + +template <> +inline const onert_tflite::SelectV2Options * +Operator::builtin_options_as<onert_tflite::SelectV2Options>() const +{ + return 
builtin_options_as_SelectV2Options(); +} + +template <> +inline const onert_tflite::DensifyOptions * +Operator::builtin_options_as<onert_tflite::DensifyOptions>() const +{ + return builtin_options_as_DensifyOptions(); +} + +template <> +inline const onert_tflite::SegmentSumOptions * +Operator::builtin_options_as<onert_tflite::SegmentSumOptions>() const +{ + return builtin_options_as_SegmentSumOptions(); +} + +template <> +inline const onert_tflite::BatchMatMulOptions * +Operator::builtin_options_as<onert_tflite::BatchMatMulOptions>() const +{ + return builtin_options_as_BatchMatMulOptions(); +} + +template <> +inline const onert_tflite::CumsumOptions * +Operator::builtin_options_as<onert_tflite::CumsumOptions>() const +{ + return builtin_options_as_CumsumOptions(); +} + +template <> +inline const onert_tflite::CallOnceOptions * +Operator::builtin_options_as<onert_tflite::CallOnceOptions>() const +{ + return builtin_options_as_CallOnceOptions(); +} + +template <> +inline const onert_tflite::BroadcastToOptions * +Operator::builtin_options_as<onert_tflite::BroadcastToOptions>() const +{ + return builtin_options_as_BroadcastToOptions(); +} + +template <> +inline const onert_tflite::Rfft2dOptions * +Operator::builtin_options_as<onert_tflite::Rfft2dOptions>() const +{ + return builtin_options_as_Rfft2dOptions(); +} + +template <> +inline const onert_tflite::Conv3DOptions * +Operator::builtin_options_as<onert_tflite::Conv3DOptions>() const +{ + return builtin_options_as_Conv3DOptions(); +} + +template <> +inline const onert_tflite::HashtableOptions * +Operator::builtin_options_as<onert_tflite::HashtableOptions>() const +{ + return builtin_options_as_HashtableOptions(); +} + +template <> +inline const onert_tflite::HashtableFindOptions * +Operator::builtin_options_as<onert_tflite::HashtableFindOptions>() const +{ + return builtin_options_as_HashtableFindOptions(); +} + +template <> +inline const onert_tflite::HashtableImportOptions * 
+Operator::builtin_options_as<onert_tflite::HashtableImportOptions>() const +{ + return builtin_options_as_HashtableImportOptions(); +} + +template <> +inline const onert_tflite::HashtableSizeOptions * +Operator::builtin_options_as<onert_tflite::HashtableSizeOptions>() const +{ + return builtin_options_as_HashtableSizeOptions(); +} + +template <> +inline const onert_tflite::VarHandleOptions * +Operator::builtin_options_as<onert_tflite::VarHandleOptions>() const +{ + return builtin_options_as_VarHandleOptions(); +} + +template <> +inline const onert_tflite::ReadVariableOptions * +Operator::builtin_options_as<onert_tflite::ReadVariableOptions>() const +{ + return builtin_options_as_ReadVariableOptions(); +} + +template <> +inline const onert_tflite::AssignVariableOptions * +Operator::builtin_options_as<onert_tflite::AssignVariableOptions>() const +{ + return builtin_options_as_AssignVariableOptions(); +} + +template <> +inline const onert_tflite::RandomOptions * +Operator::builtin_options_as<onert_tflite::RandomOptions>() const +{ + return builtin_options_as_RandomOptions(); +} + +template <> +inline const onert_tflite::BucketizeOptions * +Operator::builtin_options_as<onert_tflite::BucketizeOptions>() const +{ + return builtin_options_as_BucketizeOptions(); +} + +template <> +inline const onert_tflite::GeluOptions * +Operator::builtin_options_as<onert_tflite::GeluOptions>() const +{ + return builtin_options_as_GeluOptions(); +} + +template <> +inline const onert_tflite::DynamicUpdateSliceOptions * +Operator::builtin_options_as<onert_tflite::DynamicUpdateSliceOptions>() const +{ + return builtin_options_as_DynamicUpdateSliceOptions(); +} + +template <> +inline const onert_tflite::UnsortedSegmentProdOptions * +Operator::builtin_options_as<onert_tflite::UnsortedSegmentProdOptions>() const +{ + return builtin_options_as_UnsortedSegmentProdOptions(); +} + +template <> +inline const onert_tflite::UnsortedSegmentMaxOptions * 
+Operator::builtin_options_as<onert_tflite::UnsortedSegmentMaxOptions>() const +{ + return builtin_options_as_UnsortedSegmentMaxOptions(); +} + +template <> +inline const onert_tflite::UnsortedSegmentSumOptions * +Operator::builtin_options_as<onert_tflite::UnsortedSegmentSumOptions>() const +{ + return builtin_options_as_UnsortedSegmentSumOptions(); +} + +template <> +inline const onert_tflite::ATan2Options * +Operator::builtin_options_as<onert_tflite::ATan2Options>() const +{ + return builtin_options_as_ATan2Options(); +} + +struct OperatorBuilder +{ + typedef Operator Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_opcode_index(uint32_t opcode_index) + { + fbb_.AddElement<uint32_t>(Operator::VT_OPCODE_INDEX, opcode_index, 0); + } + void add_inputs(::flatbuffers::Offset<::flatbuffers::Vector<int32_t>> inputs) + { + fbb_.AddOffset(Operator::VT_INPUTS, inputs); + } + void add_outputs(::flatbuffers::Offset<::flatbuffers::Vector<int32_t>> outputs) + { + fbb_.AddOffset(Operator::VT_OUTPUTS, outputs); + } + void add_builtin_options_type(onert_tflite::BuiltinOptions builtin_options_type) + { + fbb_.AddElement<uint8_t>(Operator::VT_BUILTIN_OPTIONS_TYPE, + static_cast<uint8_t>(builtin_options_type), 0); + } + void add_builtin_options(::flatbuffers::Offset<void> builtin_options) + { + fbb_.AddOffset(Operator::VT_BUILTIN_OPTIONS, builtin_options); + } + void add_custom_options(::flatbuffers::Offset<::flatbuffers::Vector<uint8_t>> custom_options) + { + fbb_.AddOffset(Operator::VT_CUSTOM_OPTIONS, custom_options); + } + void add_custom_options_format(onert_tflite::CustomOptionsFormat custom_options_format) + { + fbb_.AddElement<int8_t>(Operator::VT_CUSTOM_OPTIONS_FORMAT, + static_cast<int8_t>(custom_options_format), 0); + } + void add_mutating_variable_inputs( + ::flatbuffers::Offset<::flatbuffers::Vector<uint8_t>> mutating_variable_inputs) + { + fbb_.AddOffset(Operator::VT_MUTATING_VARIABLE_INPUTS, mutating_variable_inputs); + } + 
void add_intermediates(::flatbuffers::Offset<::flatbuffers::Vector<int32_t>> intermediates) + { + fbb_.AddOffset(Operator::VT_INTERMEDIATES, intermediates); + } + explicit OperatorBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<Operator> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<Operator>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<Operator> CreateOperator( + ::flatbuffers::FlatBufferBuilder &_fbb, uint32_t opcode_index = 0, + ::flatbuffers::Offset<::flatbuffers::Vector<int32_t>> inputs = 0, + ::flatbuffers::Offset<::flatbuffers::Vector<int32_t>> outputs = 0, + onert_tflite::BuiltinOptions builtin_options_type = onert_tflite::BuiltinOptions_NONE, + ::flatbuffers::Offset<void> builtin_options = 0, + ::flatbuffers::Offset<::flatbuffers::Vector<uint8_t>> custom_options = 0, + onert_tflite::CustomOptionsFormat custom_options_format = + onert_tflite::CustomOptionsFormat_FLEXBUFFERS, + ::flatbuffers::Offset<::flatbuffers::Vector<uint8_t>> mutating_variable_inputs = 0, + ::flatbuffers::Offset<::flatbuffers::Vector<int32_t>> intermediates = 0) +{ + OperatorBuilder builder_(_fbb); + builder_.add_intermediates(intermediates); + builder_.add_mutating_variable_inputs(mutating_variable_inputs); + builder_.add_custom_options(custom_options); + builder_.add_builtin_options(builtin_options); + builder_.add_outputs(outputs); + builder_.add_inputs(inputs); + builder_.add_opcode_index(opcode_index); + builder_.add_custom_options_format(custom_options_format); + builder_.add_builtin_options_type(builtin_options_type); + return builder_.Finish(); +} + +inline ::flatbuffers::Offset<Operator> CreateOperatorDirect( + ::flatbuffers::FlatBufferBuilder &_fbb, uint32_t opcode_index = 0, + const std::vector<int32_t> *inputs = nullptr, const std::vector<int32_t> *outputs = nullptr, + onert_tflite::BuiltinOptions builtin_options_type = 
onert_tflite::BuiltinOptions_NONE, + ::flatbuffers::Offset<void> builtin_options = 0, + const std::vector<uint8_t> *custom_options = nullptr, + onert_tflite::CustomOptionsFormat custom_options_format = + onert_tflite::CustomOptionsFormat_FLEXBUFFERS, + const std::vector<uint8_t> *mutating_variable_inputs = nullptr, + const std::vector<int32_t> *intermediates = nullptr) +{ + auto inputs__ = inputs ? _fbb.CreateVector<int32_t>(*inputs) : 0; + auto outputs__ = outputs ? _fbb.CreateVector<int32_t>(*outputs) : 0; + auto custom_options__ = custom_options ? _fbb.CreateVector<uint8_t>(*custom_options) : 0; + auto mutating_variable_inputs__ = + mutating_variable_inputs ? _fbb.CreateVector<uint8_t>(*mutating_variable_inputs) : 0; + auto intermediates__ = intermediates ? _fbb.CreateVector<int32_t>(*intermediates) : 0; + return onert_tflite::CreateOperator(_fbb, opcode_index, inputs__, outputs__, builtin_options_type, + builtin_options, custom_options__, custom_options_format, + mutating_variable_inputs__, intermediates__); +} + +struct SubGraph FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef SubGraphBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE + { + VT_TENSORS = 4, + VT_INPUTS = 6, + VT_OUTPUTS = 8, + VT_OPERATORS = 10, + VT_NAME = 12 + }; + const ::flatbuffers::Vector<::flatbuffers::Offset<onert_tflite::Tensor>> *tensors() const + { + return GetPointer<const ::flatbuffers::Vector<::flatbuffers::Offset<onert_tflite::Tensor>> *>( + VT_TENSORS); + } + const ::flatbuffers::Vector<int32_t> *inputs() const + { + return GetPointer<const ::flatbuffers::Vector<int32_t> *>(VT_INPUTS); + } + const ::flatbuffers::Vector<int32_t> *outputs() const + { + return GetPointer<const ::flatbuffers::Vector<int32_t> *>(VT_OUTPUTS); + } + const ::flatbuffers::Vector<::flatbuffers::Offset<onert_tflite::Operator>> *operators() const + { + return GetPointer<const ::flatbuffers::Vector<::flatbuffers::Offset<onert_tflite::Operator>> *>( + 
VT_OPERATORS); + } + const ::flatbuffers::String *name() const + { + return GetPointer<const ::flatbuffers::String *>(VT_NAME); + } + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_TENSORS) && + verifier.VerifyVector(tensors()) && verifier.VerifyVectorOfTables(tensors()) && + VerifyOffset(verifier, VT_INPUTS) && verifier.VerifyVector(inputs()) && + VerifyOffset(verifier, VT_OUTPUTS) && verifier.VerifyVector(outputs()) && + VerifyOffset(verifier, VT_OPERATORS) && verifier.VerifyVector(operators()) && + verifier.VerifyVectorOfTables(operators()) && VerifyOffset(verifier, VT_NAME) && + verifier.VerifyString(name()) && verifier.EndTable(); + } +}; + +struct SubGraphBuilder +{ + typedef SubGraph Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_tensors( + ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset<onert_tflite::Tensor>>> + tensors) + { + fbb_.AddOffset(SubGraph::VT_TENSORS, tensors); + } + void add_inputs(::flatbuffers::Offset<::flatbuffers::Vector<int32_t>> inputs) + { + fbb_.AddOffset(SubGraph::VT_INPUTS, inputs); + } + void add_outputs(::flatbuffers::Offset<::flatbuffers::Vector<int32_t>> outputs) + { + fbb_.AddOffset(SubGraph::VT_OUTPUTS, outputs); + } + void add_operators( + ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset<onert_tflite::Operator>>> + operators) + { + fbb_.AddOffset(SubGraph::VT_OPERATORS, operators); + } + void add_name(::flatbuffers::Offset<::flatbuffers::String> name) + { + fbb_.AddOffset(SubGraph::VT_NAME, name); + } + explicit SubGraphBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<SubGraph> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<SubGraph>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<SubGraph> CreateSubGraph( + ::flatbuffers::FlatBufferBuilder &_fbb, + 
::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset<onert_tflite::Tensor>>> + tensors = 0, + ::flatbuffers::Offset<::flatbuffers::Vector<int32_t>> inputs = 0, + ::flatbuffers::Offset<::flatbuffers::Vector<int32_t>> outputs = 0, + ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset<onert_tflite::Operator>>> + operators = 0, + ::flatbuffers::Offset<::flatbuffers::String> name = 0) +{ + SubGraphBuilder builder_(_fbb); + builder_.add_name(name); + builder_.add_operators(operators); + builder_.add_outputs(outputs); + builder_.add_inputs(inputs); + builder_.add_tensors(tensors); + return builder_.Finish(); +} + +inline ::flatbuffers::Offset<SubGraph> CreateSubGraphDirect( + ::flatbuffers::FlatBufferBuilder &_fbb, + const std::vector<::flatbuffers::Offset<onert_tflite::Tensor>> *tensors = nullptr, + const std::vector<int32_t> *inputs = nullptr, const std::vector<int32_t> *outputs = nullptr, + const std::vector<::flatbuffers::Offset<onert_tflite::Operator>> *operators = nullptr, + const char *name = nullptr) +{ + auto tensors__ = + tensors ? _fbb.CreateVector<::flatbuffers::Offset<onert_tflite::Tensor>>(*tensors) : 0; + auto inputs__ = inputs ? _fbb.CreateVector<int32_t>(*inputs) : 0; + auto outputs__ = outputs ? _fbb.CreateVector<int32_t>(*outputs) : 0; + auto operators__ = + operators ? _fbb.CreateVector<::flatbuffers::Offset<onert_tflite::Operator>>(*operators) : 0; + auto name__ = name ? 
_fbb.CreateString(name) : 0; + return onert_tflite::CreateSubGraph(_fbb, tensors__, inputs__, outputs__, operators__, name__); +} + +struct Buffer FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef BufferBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE + { + VT_DATA = 4 + }; + const ::flatbuffers::Vector<uint8_t> *data() const + { + return GetPointer<const ::flatbuffers::Vector<uint8_t> *>(VT_DATA); + } + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_DATA) && + verifier.VerifyVector(data()) && verifier.EndTable(); + } +}; + +struct BufferBuilder +{ + typedef Buffer Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_data(::flatbuffers::Offset<::flatbuffers::Vector<uint8_t>> data) + { + fbb_.AddOffset(Buffer::VT_DATA, data); + } + explicit BufferBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<Buffer> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<Buffer>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<Buffer> +CreateBuffer(::flatbuffers::FlatBufferBuilder &_fbb, + ::flatbuffers::Offset<::flatbuffers::Vector<uint8_t>> data = 0) +{ + BufferBuilder builder_(_fbb); + builder_.add_data(data); + return builder_.Finish(); +} + +inline ::flatbuffers::Offset<Buffer> CreateBufferDirect(::flatbuffers::FlatBufferBuilder &_fbb, + const std::vector<uint8_t> *data = nullptr) +{ + if (data) + { + _fbb.ForceVectorAlignment(data->size(), sizeof(uint8_t), 16); + } + auto data__ = data ? 
_fbb.CreateVector<uint8_t>(*data) : 0; + return onert_tflite::CreateBuffer(_fbb, data__); +} + +struct Metadata FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef MetadataBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE + { + VT_NAME = 4, + VT_BUFFER = 6 + }; + const ::flatbuffers::String *name() const + { + return GetPointer<const ::flatbuffers::String *>(VT_NAME); + } + uint32_t buffer() const { return GetField<uint32_t>(VT_BUFFER, 0); } + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_NAME) && + verifier.VerifyString(name()) && VerifyField<uint32_t>(verifier, VT_BUFFER, 4) && + verifier.EndTable(); + } +}; + +struct MetadataBuilder +{ + typedef Metadata Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_name(::flatbuffers::Offset<::flatbuffers::String> name) + { + fbb_.AddOffset(Metadata::VT_NAME, name); + } + void add_buffer(uint32_t buffer) { fbb_.AddElement<uint32_t>(Metadata::VT_BUFFER, buffer, 0); } + explicit MetadataBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<Metadata> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<Metadata>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<Metadata> +CreateMetadata(::flatbuffers::FlatBufferBuilder &_fbb, + ::flatbuffers::Offset<::flatbuffers::String> name = 0, uint32_t buffer = 0) +{ + MetadataBuilder builder_(_fbb); + builder_.add_buffer(buffer); + builder_.add_name(name); + return builder_.Finish(); +} + +inline ::flatbuffers::Offset<Metadata> CreateMetadataDirect(::flatbuffers::FlatBufferBuilder &_fbb, + const char *name = nullptr, + uint32_t buffer = 0) +{ + auto name__ = name ? 
_fbb.CreateString(name) : 0; + return onert_tflite::CreateMetadata(_fbb, name__, buffer); +} + +struct TensorMap FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef TensorMapBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE + { + VT_NAME = 4, + VT_TENSOR_INDEX = 6 + }; + const ::flatbuffers::String *name() const + { + return GetPointer<const ::flatbuffers::String *>(VT_NAME); + } + uint32_t tensor_index() const { return GetField<uint32_t>(VT_TENSOR_INDEX, 0); } + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_NAME) && + verifier.VerifyString(name()) && VerifyField<uint32_t>(verifier, VT_TENSOR_INDEX, 4) && + verifier.EndTable(); + } +}; + +struct TensorMapBuilder +{ + typedef TensorMap Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_name(::flatbuffers::Offset<::flatbuffers::String> name) + { + fbb_.AddOffset(TensorMap::VT_NAME, name); + } + void add_tensor_index(uint32_t tensor_index) + { + fbb_.AddElement<uint32_t>(TensorMap::VT_TENSOR_INDEX, tensor_index, 0); + } + explicit TensorMapBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<TensorMap> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<TensorMap>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<TensorMap> +CreateTensorMap(::flatbuffers::FlatBufferBuilder &_fbb, + ::flatbuffers::Offset<::flatbuffers::String> name = 0, uint32_t tensor_index = 0) +{ + TensorMapBuilder builder_(_fbb); + builder_.add_tensor_index(tensor_index); + builder_.add_name(name); + return builder_.Finish(); +} + +inline ::flatbuffers::Offset<TensorMap> +CreateTensorMapDirect(::flatbuffers::FlatBufferBuilder &_fbb, const char *name = nullptr, + uint32_t tensor_index = 0) +{ + auto name__ = name ? 
_fbb.CreateString(name) : 0; + return onert_tflite::CreateTensorMap(_fbb, name__, tensor_index); +} + +struct SignatureDef FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef SignatureDefBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE + { + VT_INPUTS = 4, + VT_OUTPUTS = 6, + VT_SIGNATURE_KEY = 8, + VT_SUBGRAPH_INDEX = 12 + }; + const ::flatbuffers::Vector<::flatbuffers::Offset<onert_tflite::TensorMap>> *inputs() const + { + return GetPointer< + const ::flatbuffers::Vector<::flatbuffers::Offset<onert_tflite::TensorMap>> *>(VT_INPUTS); + } + const ::flatbuffers::Vector<::flatbuffers::Offset<onert_tflite::TensorMap>> *outputs() const + { + return GetPointer< + const ::flatbuffers::Vector<::flatbuffers::Offset<onert_tflite::TensorMap>> *>(VT_OUTPUTS); + } + const ::flatbuffers::String *signature_key() const + { + return GetPointer<const ::flatbuffers::String *>(VT_SIGNATURE_KEY); + } + uint32_t subgraph_index() const { return GetField<uint32_t>(VT_SUBGRAPH_INDEX, 0); } + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_INPUTS) && + verifier.VerifyVector(inputs()) && verifier.VerifyVectorOfTables(inputs()) && + VerifyOffset(verifier, VT_OUTPUTS) && verifier.VerifyVector(outputs()) && + verifier.VerifyVectorOfTables(outputs()) && VerifyOffset(verifier, VT_SIGNATURE_KEY) && + verifier.VerifyString(signature_key()) && + VerifyField<uint32_t>(verifier, VT_SUBGRAPH_INDEX, 4) && verifier.EndTable(); + } +}; + +struct SignatureDefBuilder +{ + typedef SignatureDef Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_inputs( + ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset<onert_tflite::TensorMap>>> + inputs) + { + fbb_.AddOffset(SignatureDef::VT_INPUTS, inputs); + } + void add_outputs( + ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset<onert_tflite::TensorMap>>> + outputs) + { + 
fbb_.AddOffset(SignatureDef::VT_OUTPUTS, outputs); + } + void add_signature_key(::flatbuffers::Offset<::flatbuffers::String> signature_key) + { + fbb_.AddOffset(SignatureDef::VT_SIGNATURE_KEY, signature_key); + } + void add_subgraph_index(uint32_t subgraph_index) + { + fbb_.AddElement<uint32_t>(SignatureDef::VT_SUBGRAPH_INDEX, subgraph_index, 0); + } + explicit SignatureDefBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<SignatureDef> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<SignatureDef>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<SignatureDef> CreateSignatureDef( + ::flatbuffers::FlatBufferBuilder &_fbb, + ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset<onert_tflite::TensorMap>>> + inputs = 0, + ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset<onert_tflite::TensorMap>>> + outputs = 0, + ::flatbuffers::Offset<::flatbuffers::String> signature_key = 0, uint32_t subgraph_index = 0) +{ + SignatureDefBuilder builder_(_fbb); + builder_.add_subgraph_index(subgraph_index); + builder_.add_signature_key(signature_key); + builder_.add_outputs(outputs); + builder_.add_inputs(inputs); + return builder_.Finish(); +} + +inline ::flatbuffers::Offset<SignatureDef> CreateSignatureDefDirect( + ::flatbuffers::FlatBufferBuilder &_fbb, + const std::vector<::flatbuffers::Offset<onert_tflite::TensorMap>> *inputs = nullptr, + const std::vector<::flatbuffers::Offset<onert_tflite::TensorMap>> *outputs = nullptr, + const char *signature_key = nullptr, uint32_t subgraph_index = 0) +{ + auto inputs__ = + inputs ? _fbb.CreateVector<::flatbuffers::Offset<onert_tflite::TensorMap>>(*inputs) : 0; + auto outputs__ = + outputs ? _fbb.CreateVector<::flatbuffers::Offset<onert_tflite::TensorMap>>(*outputs) : 0; + auto signature_key__ = signature_key ? 
_fbb.CreateString(signature_key) : 0; + return onert_tflite::CreateSignatureDef(_fbb, inputs__, outputs__, signature_key__, + subgraph_index); +} + +struct Model FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef ModelBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE + { + VT_VERSION = 4, + VT_OPERATOR_CODES = 6, + VT_SUBGRAPHS = 8, + VT_DESCRIPTION = 10, + VT_BUFFERS = 12, + VT_METADATA_BUFFER = 14, + VT_METADATA = 16, + VT_SIGNATURE_DEFS = 18 + }; + uint32_t version() const { return GetField<uint32_t>(VT_VERSION, 0); } + const ::flatbuffers::Vector<::flatbuffers::Offset<onert_tflite::OperatorCode>> * + operator_codes() const + { + return GetPointer< + const ::flatbuffers::Vector<::flatbuffers::Offset<onert_tflite::OperatorCode>> *>( + VT_OPERATOR_CODES); + } + const ::flatbuffers::Vector<::flatbuffers::Offset<onert_tflite::SubGraph>> *subgraphs() const + { + return GetPointer<const ::flatbuffers::Vector<::flatbuffers::Offset<onert_tflite::SubGraph>> *>( + VT_SUBGRAPHS); + } + const ::flatbuffers::String *description() const + { + return GetPointer<const ::flatbuffers::String *>(VT_DESCRIPTION); + } + const ::flatbuffers::Vector<::flatbuffers::Offset<onert_tflite::Buffer>> *buffers() const + { + return GetPointer<const ::flatbuffers::Vector<::flatbuffers::Offset<onert_tflite::Buffer>> *>( + VT_BUFFERS); + } + const ::flatbuffers::Vector<int32_t> *metadata_buffer() const + { + return GetPointer<const ::flatbuffers::Vector<int32_t> *>(VT_METADATA_BUFFER); + } + const ::flatbuffers::Vector<::flatbuffers::Offset<onert_tflite::Metadata>> *metadata() const + { + return GetPointer<const ::flatbuffers::Vector<::flatbuffers::Offset<onert_tflite::Metadata>> *>( + VT_METADATA); + } + const ::flatbuffers::Vector<::flatbuffers::Offset<onert_tflite::SignatureDef>> * + signature_defs() const + { + return GetPointer< + const ::flatbuffers::Vector<::flatbuffers::Offset<onert_tflite::SignatureDef>> *>( + VT_SIGNATURE_DEFS); + } + 
bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && VerifyField<uint32_t>(verifier, VT_VERSION, 4) && + VerifyOffset(verifier, VT_OPERATOR_CODES) && verifier.VerifyVector(operator_codes()) && + verifier.VerifyVectorOfTables(operator_codes()) && + VerifyOffset(verifier, VT_SUBGRAPHS) && verifier.VerifyVector(subgraphs()) && + verifier.VerifyVectorOfTables(subgraphs()) && VerifyOffset(verifier, VT_DESCRIPTION) && + verifier.VerifyString(description()) && VerifyOffset(verifier, VT_BUFFERS) && + verifier.VerifyVector(buffers()) && verifier.VerifyVectorOfTables(buffers()) && + VerifyOffset(verifier, VT_METADATA_BUFFER) && verifier.VerifyVector(metadata_buffer()) && + VerifyOffset(verifier, VT_METADATA) && verifier.VerifyVector(metadata()) && + verifier.VerifyVectorOfTables(metadata()) && VerifyOffset(verifier, VT_SIGNATURE_DEFS) && + verifier.VerifyVector(signature_defs()) && + verifier.VerifyVectorOfTables(signature_defs()) && verifier.EndTable(); + } +}; + +struct ModelBuilder +{ + typedef Model Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_version(uint32_t version) { fbb_.AddElement<uint32_t>(Model::VT_VERSION, version, 0); } + void add_operator_codes( + ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset<onert_tflite::OperatorCode>>> + operator_codes) + { + fbb_.AddOffset(Model::VT_OPERATOR_CODES, operator_codes); + } + void add_subgraphs( + ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset<onert_tflite::SubGraph>>> + subgraphs) + { + fbb_.AddOffset(Model::VT_SUBGRAPHS, subgraphs); + } + void add_description(::flatbuffers::Offset<::flatbuffers::String> description) + { + fbb_.AddOffset(Model::VT_DESCRIPTION, description); + } + void add_buffers( + ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset<onert_tflite::Buffer>>> + buffers) + { + fbb_.AddOffset(Model::VT_BUFFERS, buffers); + } + void 
add_metadata_buffer(::flatbuffers::Offset<::flatbuffers::Vector<int32_t>> metadata_buffer) + { + fbb_.AddOffset(Model::VT_METADATA_BUFFER, metadata_buffer); + } + void add_metadata( + ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset<onert_tflite::Metadata>>> + metadata) + { + fbb_.AddOffset(Model::VT_METADATA, metadata); + } + void add_signature_defs( + ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset<onert_tflite::SignatureDef>>> + signature_defs) + { + fbb_.AddOffset(Model::VT_SIGNATURE_DEFS, signature_defs); + } + explicit ModelBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<Model> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<Model>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<Model> CreateModel( + ::flatbuffers::FlatBufferBuilder &_fbb, uint32_t version = 0, + ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset<onert_tflite::OperatorCode>>> + operator_codes = 0, + ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset<onert_tflite::SubGraph>>> + subgraphs = 0, + ::flatbuffers::Offset<::flatbuffers::String> description = 0, + ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset<onert_tflite::Buffer>>> + buffers = 0, + ::flatbuffers::Offset<::flatbuffers::Vector<int32_t>> metadata_buffer = 0, + ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset<onert_tflite::Metadata>>> + metadata = 0, + ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset<onert_tflite::SignatureDef>>> + signature_defs = 0) +{ + ModelBuilder builder_(_fbb); + builder_.add_signature_defs(signature_defs); + builder_.add_metadata(metadata); + builder_.add_metadata_buffer(metadata_buffer); + builder_.add_buffers(buffers); + builder_.add_description(description); + builder_.add_subgraphs(subgraphs); + builder_.add_operator_codes(operator_codes); + 
builder_.add_version(version); + return builder_.Finish(); +} + +inline ::flatbuffers::Offset<Model> CreateModelDirect( + ::flatbuffers::FlatBufferBuilder &_fbb, uint32_t version = 0, + const std::vector<::flatbuffers::Offset<onert_tflite::OperatorCode>> *operator_codes = nullptr, + const std::vector<::flatbuffers::Offset<onert_tflite::SubGraph>> *subgraphs = nullptr, + const char *description = nullptr, + const std::vector<::flatbuffers::Offset<onert_tflite::Buffer>> *buffers = nullptr, + const std::vector<int32_t> *metadata_buffer = nullptr, + const std::vector<::flatbuffers::Offset<onert_tflite::Metadata>> *metadata = nullptr, + const std::vector<::flatbuffers::Offset<onert_tflite::SignatureDef>> *signature_defs = nullptr) +{ + auto operator_codes__ = + operator_codes + ? _fbb.CreateVector<::flatbuffers::Offset<onert_tflite::OperatorCode>>(*operator_codes) + : 0; + auto subgraphs__ = + subgraphs ? _fbb.CreateVector<::flatbuffers::Offset<onert_tflite::SubGraph>>(*subgraphs) : 0; + auto description__ = description ? _fbb.CreateString(description) : 0; + auto buffers__ = + buffers ? _fbb.CreateVector<::flatbuffers::Offset<onert_tflite::Buffer>>(*buffers) : 0; + auto metadata_buffer__ = metadata_buffer ? _fbb.CreateVector<int32_t>(*metadata_buffer) : 0; + auto metadata__ = + metadata ? _fbb.CreateVector<::flatbuffers::Offset<onert_tflite::Metadata>>(*metadata) : 0; + auto signature_defs__ = + signature_defs + ? 
_fbb.CreateVector<::flatbuffers::Offset<onert_tflite::SignatureDef>>(*signature_defs) + : 0; + return onert_tflite::CreateModel(_fbb, version, operator_codes__, subgraphs__, description__, + buffers__, metadata_buffer__, metadata__, signature_defs__); +} + +inline bool VerifyQuantizationDetails(::flatbuffers::Verifier &verifier, const void *obj, + QuantizationDetails type) +{ + switch (type) + { + case QuantizationDetails_NONE: + { + return true; + } + case QuantizationDetails_CustomQuantization: + { + auto ptr = reinterpret_cast<const onert_tflite::CustomQuantization *>(obj); + return verifier.VerifyTable(ptr); + } + default: + return true; + } +} + +inline bool +VerifyQuantizationDetailsVector(::flatbuffers::Verifier &verifier, + const ::flatbuffers::Vector<::flatbuffers::Offset<void>> *values, + const ::flatbuffers::Vector<uint8_t> *types) +{ + if (!values || !types) + return !values && !types; + if (values->size() != types->size()) + return false; + for (::flatbuffers::uoffset_t i = 0; i < values->size(); ++i) + { + if (!VerifyQuantizationDetails(verifier, values->Get(i), + types->GetEnum<QuantizationDetails>(i))) + { + return false; + } + } + return true; +} + +inline bool VerifySparseIndexVector(::flatbuffers::Verifier &verifier, const void *obj, + SparseIndexVector type) +{ + switch (type) + { + case SparseIndexVector_NONE: + { + return true; + } + case SparseIndexVector_Int32Vector: + { + auto ptr = reinterpret_cast<const onert_tflite::Int32Vector *>(obj); + return verifier.VerifyTable(ptr); + } + case SparseIndexVector_Uint16Vector: + { + auto ptr = reinterpret_cast<const onert_tflite::Uint16Vector *>(obj); + return verifier.VerifyTable(ptr); + } + case SparseIndexVector_Uint8Vector: + { + auto ptr = reinterpret_cast<const onert_tflite::Uint8Vector *>(obj); + return verifier.VerifyTable(ptr); + } + default: + return true; + } +} + +inline bool +VerifySparseIndexVectorVector(::flatbuffers::Verifier &verifier, + const 
::flatbuffers::Vector<::flatbuffers::Offset<void>> *values, + const ::flatbuffers::Vector<uint8_t> *types) +{ + if (!values || !types) + return !values && !types; + if (values->size() != types->size()) + return false; + for (::flatbuffers::uoffset_t i = 0; i < values->size(); ++i) + { + if (!VerifySparseIndexVector(verifier, values->Get(i), types->GetEnum<SparseIndexVector>(i))) + { + return false; + } + } + return true; +} + +inline bool VerifyBuiltinOptions(::flatbuffers::Verifier &verifier, const void *obj, + BuiltinOptions type) +{ + switch (type) + { + case BuiltinOptions_NONE: + { + return true; + } + case BuiltinOptions_Conv2DOptions: + { + auto ptr = reinterpret_cast<const onert_tflite::Conv2DOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_DepthwiseConv2DOptions: + { + auto ptr = reinterpret_cast<const onert_tflite::DepthwiseConv2DOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_ConcatEmbeddingsOptions: + { + auto ptr = reinterpret_cast<const onert_tflite::ConcatEmbeddingsOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_LSHProjectionOptions: + { + auto ptr = reinterpret_cast<const onert_tflite::LSHProjectionOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_Pool2DOptions: + { + auto ptr = reinterpret_cast<const onert_tflite::Pool2DOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_SVDFOptions: + { + auto ptr = reinterpret_cast<const onert_tflite::SVDFOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_RNNOptions: + { + auto ptr = reinterpret_cast<const onert_tflite::RNNOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_FullyConnectedOptions: + { + auto ptr = reinterpret_cast<const onert_tflite::FullyConnectedOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_SoftmaxOptions: + { + auto ptr = reinterpret_cast<const 
onert_tflite::SoftmaxOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_ConcatenationOptions: + { + auto ptr = reinterpret_cast<const onert_tflite::ConcatenationOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_AddOptions: + { + auto ptr = reinterpret_cast<const onert_tflite::AddOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_L2NormOptions: + { + auto ptr = reinterpret_cast<const onert_tflite::L2NormOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_LocalResponseNormalizationOptions: + { + auto ptr = reinterpret_cast<const onert_tflite::LocalResponseNormalizationOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_LSTMOptions: + { + auto ptr = reinterpret_cast<const onert_tflite::LSTMOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_ResizeBilinearOptions: + { + auto ptr = reinterpret_cast<const onert_tflite::ResizeBilinearOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_CallOptions: + { + auto ptr = reinterpret_cast<const onert_tflite::CallOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_ReshapeOptions: + { + auto ptr = reinterpret_cast<const onert_tflite::ReshapeOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_SkipGramOptions: + { + auto ptr = reinterpret_cast<const onert_tflite::SkipGramOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_SpaceToDepthOptions: + { + auto ptr = reinterpret_cast<const onert_tflite::SpaceToDepthOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_EmbeddingLookupSparseOptions: + { + auto ptr = reinterpret_cast<const onert_tflite::EmbeddingLookupSparseOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_MulOptions: + { + auto ptr = reinterpret_cast<const onert_tflite::MulOptions *>(obj); + return 
verifier.VerifyTable(ptr); + } + case BuiltinOptions_PadOptions: + { + auto ptr = reinterpret_cast<const onert_tflite::PadOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_GatherOptions: + { + auto ptr = reinterpret_cast<const onert_tflite::GatherOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_BatchToSpaceNDOptions: + { + auto ptr = reinterpret_cast<const onert_tflite::BatchToSpaceNDOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_SpaceToBatchNDOptions: + { + auto ptr = reinterpret_cast<const onert_tflite::SpaceToBatchNDOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_TransposeOptions: + { + auto ptr = reinterpret_cast<const onert_tflite::TransposeOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_ReducerOptions: + { + auto ptr = reinterpret_cast<const onert_tflite::ReducerOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_SubOptions: + { + auto ptr = reinterpret_cast<const onert_tflite::SubOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_DivOptions: + { + auto ptr = reinterpret_cast<const onert_tflite::DivOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_SqueezeOptions: + { + auto ptr = reinterpret_cast<const onert_tflite::SqueezeOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_SequenceRNNOptions: + { + auto ptr = reinterpret_cast<const onert_tflite::SequenceRNNOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_StridedSliceOptions: + { + auto ptr = reinterpret_cast<const onert_tflite::StridedSliceOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_ExpOptions: + { + auto ptr = reinterpret_cast<const onert_tflite::ExpOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_TopKV2Options: + { + auto ptr = reinterpret_cast<const 
onert_tflite::TopKV2Options *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_SplitOptions: + { + auto ptr = reinterpret_cast<const onert_tflite::SplitOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_LogSoftmaxOptions: + { + auto ptr = reinterpret_cast<const onert_tflite::LogSoftmaxOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_CastOptions: + { + auto ptr = reinterpret_cast<const onert_tflite::CastOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_DequantizeOptions: + { + auto ptr = reinterpret_cast<const onert_tflite::DequantizeOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_MaximumMinimumOptions: + { + auto ptr = reinterpret_cast<const onert_tflite::MaximumMinimumOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_ArgMaxOptions: + { + auto ptr = reinterpret_cast<const onert_tflite::ArgMaxOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_LessOptions: + { + auto ptr = reinterpret_cast<const onert_tflite::LessOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_NegOptions: + { + auto ptr = reinterpret_cast<const onert_tflite::NegOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_PadV2Options: + { + auto ptr = reinterpret_cast<const onert_tflite::PadV2Options *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_GreaterOptions: + { + auto ptr = reinterpret_cast<const onert_tflite::GreaterOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_GreaterEqualOptions: + { + auto ptr = reinterpret_cast<const onert_tflite::GreaterEqualOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_LessEqualOptions: + { + auto ptr = reinterpret_cast<const onert_tflite::LessEqualOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_SelectOptions: + { 
+ auto ptr = reinterpret_cast<const onert_tflite::SelectOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_SliceOptions: + { + auto ptr = reinterpret_cast<const onert_tflite::SliceOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_TransposeConvOptions: + { + auto ptr = reinterpret_cast<const onert_tflite::TransposeConvOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_SparseToDenseOptions: + { + auto ptr = reinterpret_cast<const onert_tflite::SparseToDenseOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_TileOptions: + { + auto ptr = reinterpret_cast<const onert_tflite::TileOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_ExpandDimsOptions: + { + auto ptr = reinterpret_cast<const onert_tflite::ExpandDimsOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_EqualOptions: + { + auto ptr = reinterpret_cast<const onert_tflite::EqualOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_NotEqualOptions: + { + auto ptr = reinterpret_cast<const onert_tflite::NotEqualOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_ShapeOptions: + { + auto ptr = reinterpret_cast<const onert_tflite::ShapeOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_PowOptions: + { + auto ptr = reinterpret_cast<const onert_tflite::PowOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_ArgMinOptions: + { + auto ptr = reinterpret_cast<const onert_tflite::ArgMinOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_FakeQuantOptions: + { + auto ptr = reinterpret_cast<const onert_tflite::FakeQuantOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_PackOptions: + { + auto ptr = reinterpret_cast<const onert_tflite::PackOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case 
BuiltinOptions_LogicalOrOptions: + { + auto ptr = reinterpret_cast<const onert_tflite::LogicalOrOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_OneHotOptions: + { + auto ptr = reinterpret_cast<const onert_tflite::OneHotOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_LogicalAndOptions: + { + auto ptr = reinterpret_cast<const onert_tflite::LogicalAndOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_LogicalNotOptions: + { + auto ptr = reinterpret_cast<const onert_tflite::LogicalNotOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_UnpackOptions: + { + auto ptr = reinterpret_cast<const onert_tflite::UnpackOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_FloorDivOptions: + { + auto ptr = reinterpret_cast<const onert_tflite::FloorDivOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_SquareOptions: + { + auto ptr = reinterpret_cast<const onert_tflite::SquareOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_ZerosLikeOptions: + { + auto ptr = reinterpret_cast<const onert_tflite::ZerosLikeOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_FillOptions: + { + auto ptr = reinterpret_cast<const onert_tflite::FillOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_BidirectionalSequenceLSTMOptions: + { + auto ptr = reinterpret_cast<const onert_tflite::BidirectionalSequenceLSTMOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_BidirectionalSequenceRNNOptions: + { + auto ptr = reinterpret_cast<const onert_tflite::BidirectionalSequenceRNNOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_UnidirectionalSequenceLSTMOptions: + { + auto ptr = reinterpret_cast<const onert_tflite::UnidirectionalSequenceLSTMOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case 
BuiltinOptions_FloorModOptions: + { + auto ptr = reinterpret_cast<const onert_tflite::FloorModOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_RangeOptions: + { + auto ptr = reinterpret_cast<const onert_tflite::RangeOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_ResizeNearestNeighborOptions: + { + auto ptr = reinterpret_cast<const onert_tflite::ResizeNearestNeighborOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_LeakyReluOptions: + { + auto ptr = reinterpret_cast<const onert_tflite::LeakyReluOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_SquaredDifferenceOptions: + { + auto ptr = reinterpret_cast<const onert_tflite::SquaredDifferenceOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_MirrorPadOptions: + { + auto ptr = reinterpret_cast<const onert_tflite::MirrorPadOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_AbsOptions: + { + auto ptr = reinterpret_cast<const onert_tflite::AbsOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_SplitVOptions: + { + auto ptr = reinterpret_cast<const onert_tflite::SplitVOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_UniqueOptions: + { + auto ptr = reinterpret_cast<const onert_tflite::UniqueOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_ReverseV2Options: + { + auto ptr = reinterpret_cast<const onert_tflite::ReverseV2Options *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_AddNOptions: + { + auto ptr = reinterpret_cast<const onert_tflite::AddNOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_GatherNdOptions: + { + auto ptr = reinterpret_cast<const onert_tflite::GatherNdOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_CosOptions: + { + auto ptr = reinterpret_cast<const 
onert_tflite::CosOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_WhereOptions: + { + auto ptr = reinterpret_cast<const onert_tflite::WhereOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_RankOptions: + { + auto ptr = reinterpret_cast<const onert_tflite::RankOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_ReverseSequenceOptions: + { + auto ptr = reinterpret_cast<const onert_tflite::ReverseSequenceOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_MatrixDiagOptions: + { + auto ptr = reinterpret_cast<const onert_tflite::MatrixDiagOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_QuantizeOptions: + { + auto ptr = reinterpret_cast<const onert_tflite::QuantizeOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_MatrixSetDiagOptions: + { + auto ptr = reinterpret_cast<const onert_tflite::MatrixSetDiagOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_HardSwishOptions: + { + auto ptr = reinterpret_cast<const onert_tflite::HardSwishOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_IfOptions: + { + auto ptr = reinterpret_cast<const onert_tflite::IfOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_WhileOptions: + { + auto ptr = reinterpret_cast<const onert_tflite::WhileOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_DepthToSpaceOptions: + { + auto ptr = reinterpret_cast<const onert_tflite::DepthToSpaceOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_NonMaxSuppressionV4Options: + { + auto ptr = reinterpret_cast<const onert_tflite::NonMaxSuppressionV4Options *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_NonMaxSuppressionV5Options: + { + auto ptr = reinterpret_cast<const onert_tflite::NonMaxSuppressionV5Options *>(obj); + return 
verifier.VerifyTable(ptr); + } + case BuiltinOptions_ScatterNdOptions: + { + auto ptr = reinterpret_cast<const onert_tflite::ScatterNdOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_SelectV2Options: + { + auto ptr = reinterpret_cast<const onert_tflite::SelectV2Options *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_DensifyOptions: + { + auto ptr = reinterpret_cast<const onert_tflite::DensifyOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_SegmentSumOptions: + { + auto ptr = reinterpret_cast<const onert_tflite::SegmentSumOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_BatchMatMulOptions: + { + auto ptr = reinterpret_cast<const onert_tflite::BatchMatMulOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_CumsumOptions: + { + auto ptr = reinterpret_cast<const onert_tflite::CumsumOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_CallOnceOptions: + { + auto ptr = reinterpret_cast<const onert_tflite::CallOnceOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_BroadcastToOptions: + { + auto ptr = reinterpret_cast<const onert_tflite::BroadcastToOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_Rfft2dOptions: + { + auto ptr = reinterpret_cast<const onert_tflite::Rfft2dOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_Conv3DOptions: + { + auto ptr = reinterpret_cast<const onert_tflite::Conv3DOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_HashtableOptions: + { + auto ptr = reinterpret_cast<const onert_tflite::HashtableOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_HashtableFindOptions: + { + auto ptr = reinterpret_cast<const onert_tflite::HashtableFindOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_HashtableImportOptions: + { + auto 
ptr = reinterpret_cast<const onert_tflite::HashtableImportOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_HashtableSizeOptions: + { + auto ptr = reinterpret_cast<const onert_tflite::HashtableSizeOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_VarHandleOptions: + { + auto ptr = reinterpret_cast<const onert_tflite::VarHandleOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_ReadVariableOptions: + { + auto ptr = reinterpret_cast<const onert_tflite::ReadVariableOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_AssignVariableOptions: + { + auto ptr = reinterpret_cast<const onert_tflite::AssignVariableOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_RandomOptions: + { + auto ptr = reinterpret_cast<const onert_tflite::RandomOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_BucketizeOptions: + { + auto ptr = reinterpret_cast<const onert_tflite::BucketizeOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_GeluOptions: + { + auto ptr = reinterpret_cast<const onert_tflite::GeluOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_DynamicUpdateSliceOptions: + { + auto ptr = reinterpret_cast<const onert_tflite::DynamicUpdateSliceOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_UnsortedSegmentProdOptions: + { + auto ptr = reinterpret_cast<const onert_tflite::UnsortedSegmentProdOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_UnsortedSegmentMaxOptions: + { + auto ptr = reinterpret_cast<const onert_tflite::UnsortedSegmentMaxOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_UnsortedSegmentSumOptions: + { + auto ptr = reinterpret_cast<const onert_tflite::UnsortedSegmentSumOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_ATan2Options: + { + 
auto ptr = reinterpret_cast<const onert_tflite::ATan2Options *>(obj); + return verifier.VerifyTable(ptr); + } + default: + return true; + } +} + +inline bool +VerifyBuiltinOptionsVector(::flatbuffers::Verifier &verifier, + const ::flatbuffers::Vector<::flatbuffers::Offset<void>> *values, + const ::flatbuffers::Vector<uint8_t> *types) +{ + if (!values || !types) + return !values && !types; + if (values->size() != types->size()) + return false; + for (::flatbuffers::uoffset_t i = 0; i < values->size(); ++i) + { + if (!VerifyBuiltinOptions(verifier, values->Get(i), types->GetEnum<BuiltinOptions>(i))) + { + return false; + } + } + return true; +} + +inline const onert_tflite::Model *GetModel(const void *buf) +{ + return ::flatbuffers::GetRoot<onert_tflite::Model>(buf); +} + +inline const onert_tflite::Model *GetSizePrefixedModel(const void *buf) +{ + return ::flatbuffers::GetSizePrefixedRoot<onert_tflite::Model>(buf); +} + +inline const char *ModelIdentifier() { return "TFL3"; } + +inline bool ModelBufferHasIdentifier(const void *buf) +{ + return ::flatbuffers::BufferHasIdentifier(buf, ModelIdentifier()); +} + +inline bool SizePrefixedModelBufferHasIdentifier(const void *buf) +{ + return ::flatbuffers::BufferHasIdentifier(buf, ModelIdentifier(), true); +} + +inline bool VerifyModelBuffer(::flatbuffers::Verifier &verifier) +{ + return verifier.VerifyBuffer<onert_tflite::Model>(ModelIdentifier()); +} + +inline bool VerifySizePrefixedModelBuffer(::flatbuffers::Verifier &verifier) +{ + return verifier.VerifySizePrefixedBuffer<onert_tflite::Model>(ModelIdentifier()); +} + +inline const char *ModelExtension() { return "tflite"; } + +inline void FinishModelBuffer(::flatbuffers::FlatBufferBuilder &fbb, + ::flatbuffers::Offset<onert_tflite::Model> root) +{ + fbb.Finish(root, ModelIdentifier()); +} + +inline void FinishSizePrefixedModelBuffer(::flatbuffers::FlatBufferBuilder &fbb, + ::flatbuffers::Offset<onert_tflite::Model> root) +{ + fbb.FinishSizePrefixed(root, 
ModelIdentifier()); +} + +} // namespace onert_tflite + +#endif // FLATBUFFERS_GENERATED_TFLITESCHEMA_ONERT_TFLITE_H_ diff --git a/runtime/onert/core/src/odc/CodegenLoader.cc b/runtime/onert/core/src/odc/CodegenLoader.cc new file mode 100644 index 000000000..764074fe3 --- /dev/null +++ b/runtime/onert/core/src/odc/CodegenLoader.cc @@ -0,0 +1,91 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CodegenLoader.h" + +#include <dlfcn.h> +#include <iostream> +#include <memory> + +static const char *SHARED_LIB_EXT = +#if defined(__APPLE__) && defined(__MACH__) + ".dylib"; +#else + ".so"; +#endif + +namespace onert +{ +namespace odc +{ + +CodegenLoader &CodegenLoader::instance() +{ + static CodegenLoader singleton; + return singleton; +} + +void CodegenLoader::loadLibrary(const char *target) +{ + if (get() != nullptr) + return; + + const std::string codegen_so = "lib" + std::string{target} + SHARED_LIB_EXT; +#ifdef __ANDROID__ + void *handle = dlopen(codegen_so.c_str(), RTLD_LAZY | RTLD_LOCAL); +#else + void *handle = dlmopen(LM_ID_NEWLM, codegen_so.c_str(), RTLD_LAZY | RTLD_LOCAL); +#endif + if (handle == nullptr) + { + throw std::runtime_error("CodegenLoader: " + std::string{dlerror()}); + } + + const auto factory = (factory_t)dlsym(handle, "create_codegen"); + if (factory == nullptr) + { + const std::string dlerror_msg = dlerror(); + dlclose(handle); + throw 
std::runtime_error("CodegenLoader: " + dlerror_msg); + } + + const auto destroyer = (codegen_destory_t)dlsym(handle, "destroy_codegen"); + _codegen = std::unique_ptr<ICodegen, codegen_destory_t>(factory(), destroyer); + if (_codegen == nullptr) + { + dlclose(handle); + throw std::runtime_error("CodegenLoader: unable to create codegen"); + } + + // Save backend handle (avoid warning by handle lost without dlclose()) + _dlhandle = std::unique_ptr<void, dlhandle_destroy_t>{ + handle, [filename = codegen_so](void *h) { + if (dlclose(h)) + throw std::runtime_error("CodegenLoader: Failed to unload backend " + filename); + }}; +} + +void CodegenLoader::unloadLibrary() +{ + if (get() == nullptr) + return; + + _codegen.reset(nullptr); + _dlhandle.reset(nullptr); +} + +} // namespace odc +} // namespace onert diff --git a/runtime/onert/core/src/odc/CodegenLoader.h b/runtime/onert/core/src/odc/CodegenLoader.h new file mode 100644 index 000000000..397256058 --- /dev/null +++ b/runtime/onert/core/src/odc/CodegenLoader.h @@ -0,0 +1,96 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef __ONERT_ODC_CODEGEN_LOADER_H__ +#define __ONERT_ODC_CODEGEN_LOADER_H__ + +#include "odc/ICodegen.h" + +#include <functional> +#include <memory> + +namespace onert +{ +namespace odc +{ + +/** + * @brief Class to manage loading and unloading of dynamic library containing + * implementation of ICodegen interface. + */ +class CodegenLoader +{ +public: + /** + * @brief Typedef for function pointer to destroy loaded library handle + */ + using dlhandle_destroy_t = std::function<void(void *)>; + /** + * @brief Typedef for function pointer to create instance of ICodegen + */ + using factory_t = ICodegen *(*)(); + /** + * @brief Typedef for function pointer to destroy instance of ICodegen + */ + using codegen_destory_t = void (*)(ICodegen *); + + /** + * @brief Get singleton instance of CodegenLoader + * @return Reference to singleton instance of CodegenLoader + */ + static CodegenLoader &instance(); + + // delete copy constructor and assignment operator + CodegenLoader(CodegenLoader const &) = delete; + CodegenLoader &operator=(CodegenLoader const &) = delete; + +private: + // cannot create instance of CodegenLoader outside of this class + CodegenLoader() = default; + ~CodegenLoader() = default; + +public: + /** + * @brief Load dynamic library containing implementation of ICodegen + * @param[in] target Target backend name + * This target string will be used to find a backend library. + * The name of target backend library should follow the following rules: + * 'lib' + {backend extension} + '-gen' + {lib extension} + * And the target string should be a name except 'lib' and {lib extension}. + * For example, if the backend extension is 'aaa', the backend library name + * should be 'libaaa-gen.so', and the target string should be 'aaa-gen'. 
+ */ + void loadLibrary(const char *target); + /** + * @brief Unload dynamic library containing implementation of ICodegen + */ + void unloadLibrary(); + /** + * @brief Get instance of ICodegen created through factory method + * @return Pointer to instance of ICodegen + */ + const ICodegen *get() const { return _codegen.get(); } + +private: + // Note: Keep handle to avoid svace warning of "handle lost without dlclose()" + std::unique_ptr<void, dlhandle_destroy_t> _dlhandle; + std::unique_ptr<ICodegen, codegen_destory_t> _codegen{nullptr, nullptr}; +}; + +} // namespace odc +} // namespace onert + +#endif // __ONERT_ODC_CODEGEN_LOADER_H__ diff --git a/runtime/onert/core/src/odc/CodegenManager.cc b/runtime/onert/core/src/odc/CodegenManager.cc new file mode 100644 index 000000000..45f10a69d --- /dev/null +++ b/runtime/onert/core/src/odc/CodegenManager.cc @@ -0,0 +1,56 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "CodegenLoader.h" +#include "odc/CodegenManager.h" +#include "util/Utils.h" + +#include <mutex> + +namespace onert +{ +namespace odc +{ + +bool CodegenManager::codegen(const std::string &model_path, const char *target, + CodegenPreference pref) +{ + if (target == nullptr) + throw std::runtime_error("Target string is not set"); + + if (_export_model_path.empty()) + throw std::runtime_error("Export model path is not set"); + + if (model_path.empty()) + throw std::runtime_error("Model path does not exist"); + + // codegen function is thread-unsafe + static std::mutex lock; + std::lock_guard<std::mutex> guard(lock); + + auto &codegen_loader = CodegenLoader::instance(); + codegen_loader.loadLibrary(target); + const auto code_generator = codegen_loader.get(); + // TODO Use compile preference + UNUSED_RELEASE(pref); + const auto result = code_generator->codegen(model_path.c_str(), _export_model_path.c_str()); + codegen_loader.unloadLibrary(); + + return (result == 0); +} + +} // namespace odc +} // namespace onert diff --git a/runtime/onert/core/src/odc/QuantizeManager.cc b/runtime/onert/core/src/odc/QuantizeManager.cc new file mode 100644 index 000000000..fc5725b91 --- /dev/null +++ b/runtime/onert/core/src/odc/QuantizeManager.cc @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "QuantizerLoader.h" +#include "odc/QuantizeManager.h" + +#include <iostream> +#include <mutex> + +namespace onert +{ +namespace odc +{ + +bool QuantizeManager::quantize(const std::string &model_path) +{ + if (model_path.empty() || _export_model_path.empty()) + return false; + + // Compile function is thread-unsafe + static std::mutex lock; + std::lock_guard<std::mutex> guard(lock); + + auto &quantize_loader = QuantizerLoader::instance(); + if (quantize_loader.loadLibrary() != 0) + return false; + + auto quantizer = quantize_loader.get(); + auto result = quantizer->quantize(model_path.c_str(), _export_model_path.c_str(), _qtype); + + // TODO Unload quantize library to reduce memory usage + + return (result == 0); +} + +} // namespace odc +} // namespace onert diff --git a/runtime/onert/core/src/odc/QuantizeManager.test.cc b/runtime/onert/core/src/odc/QuantizeManager.test.cc new file mode 100644 index 000000000..3c9f45c6e --- /dev/null +++ b/runtime/onert/core/src/odc/QuantizeManager.test.cc @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "odc/QuantizeManager.h" + +#include <gtest/gtest.h> + +using namespace onert::odc; + +// Test export model path is not set +TEST(odc_QuantizeManager, neg_export_model_path_not_set) +{ + QuantizeManager manager; + manager.quantizeType(ODC_QTYPE_WO_I8_SYM); + ASSERT_EQ(manager.quantize("model_path"), false); +} + +// Test invalid model path +TEST(odc_QuantizeManager, neg_invalid_model_path) +{ + QuantizeManager manager; + manager.exportModelPath("export_model_path.circle"); + manager.quantizeType(ODC_QTYPE_WO_I8_SYM); + ASSERT_EQ(manager.quantize("invalid_model_path.circle"), false); +} diff --git a/runtime/onert/core/src/odc/QuantizerLoader.cc b/runtime/onert/core/src/odc/QuantizerLoader.cc new file mode 100644 index 000000000..8a972e97e --- /dev/null +++ b/runtime/onert/core/src/odc/QuantizerLoader.cc @@ -0,0 +1,104 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "QuantizerLoader.h" + +#include <dlfcn.h> +#include <iostream> +#include <string> + +static const char *SHARED_LIB_EXT = +#if defined(__APPLE__) && defined(__MACH__) + ".dylib"; +#else + ".so"; +#endif + +namespace onert +{ +namespace odc +{ + +QuantizerLoader &QuantizerLoader::instance() +{ + static QuantizerLoader singleton; + return singleton; +} + +int32_t QuantizerLoader::loadLibrary() +{ + if (get() != nullptr) + return 0; + + const std::string quantize_so = std::string("libonert_odc") + SHARED_LIB_EXT; + void *handle = dlopen(quantize_so.c_str(), RTLD_LAZY | RTLD_LOCAL); + auto dlerror_msg = dlerror(); + + if (handle == nullptr) + { + std::cerr << "Failed to load " << quantize_so << std::endl; + std::cerr << dlerror_msg << std::endl; + return 1; + } + + { + const char *factory_name = "create_quantizer"; + auto factory = (factory_t)dlsym(handle, factory_name); + dlerror_msg = dlerror(); + + if (factory == nullptr) + { + std::cerr << "QuantizerLoader: unable to find function " << factory_name << dlerror_msg + << std::endl; + dlclose(handle); + return 1; + } + + auto destroyer = (quantizer_destory_t)dlsym(handle, "destroy_quantizer"); + _quantizer = std::unique_ptr<IQuantizer, quantizer_destory_t>(factory(), destroyer); + + if (_quantizer == nullptr) + { + std::cerr << "QuantizerLoader: unable to create quantizer" << std::endl; + dlclose(handle); + return 1; + } + } + + // Save quantize library handle (avoid warning by handle lost without dlclose()) + // clang-format off + _dlhandle = std::unique_ptr<void, dlhandle_destroy_t>{handle, [filename = quantize_so](void *h) { + if (dlclose(h) != 0) + std::cerr << "Failed to unload backend " << filename << std::endl; + }}; + // clang-format on + + return 0; +} + +int32_t QuantizerLoader::unloadLibrary() +{ + if (get() == nullptr) + return 0; + + _quantizer.reset(nullptr); + _dlhandle.reset(nullptr); + + return 0; +} + +} // namespace odc +} // namespace onert diff --git 
a/runtime/onert/core/src/odc/QuantizerLoader.h b/runtime/onert/core/src/odc/QuantizerLoader.h new file mode 100644 index 000000000..36a9f2996 --- /dev/null +++ b/runtime/onert/core/src/odc/QuantizerLoader.h @@ -0,0 +1,89 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __ONERT_ODC_QUANTIZER_LOADER_H__ +#define __ONERT_ODC_QUANTIZER_LOADER_H__ + +#include "odc/IQuantizer.h" + +#include <functional> +#include <memory> + +namespace onert +{ +namespace odc +{ + +/** + * @brief Class to manage loading and unloading of dynamic library containing + * implementation of IQuantizer interface + */ +class QuantizerLoader +{ +public: + /** + * @brief Typedef for function pointer to destroy loaded library handle + */ + using dlhandle_destroy_t = std::function<void(void *)>; + /** + * @brief Typedef for function pointer to create instance of IQuantizer + */ + using factory_t = IQuantizer *(*)(); + /** + * @brief Typedef for function pointer to destroy instance of IQuantizer + */ + using quantizer_destory_t = void (*)(IQuantizer *); + + /** + * @brief Get singleton instance of QuantizerLoader + * @return Reference to singleton instance of QuantizerLoader + */ + static QuantizerLoader &instance(); + +private: + // Cannot create instance of QuantizerLoader outside of this class + QuantizerLoader() = default; + QuantizerLoader(QuantizerLoader const &) = delete; + QuantizerLoader 
&operator=(QuantizerLoader const &) = delete; + ~QuantizerLoader() = default; + +public: + /** + * @brief Load dynamic library containing implementation of IQuantizer + * @return 0 if success, otherwise errno value + */ + int32_t loadLibrary(); + /** + * @brief Unload dynamic library containing implementation of IQuantizer + * @return 0 if success, otherwise errno value + */ + int32_t unloadLibrary(); + /** + * @brief Get instance of IQuantizer created through factory method + * @return Pointer to instance of IQuantizer + */ + IQuantizer *get() const { return _quantizer.get(); } + +private: + // Note: Keep handle to avoid svace warning of "handle lost without dlclose()" + std::unique_ptr<void, dlhandle_destroy_t> _dlhandle; + std::unique_ptr<IQuantizer, quantizer_destory_t> _quantizer{nullptr, nullptr}; +}; + +} // namespace odc +} // namespace onert + +#endif // __ONERT_ODC_QUANTIZER_LOADER_H__ diff --git a/runtime/onert/core/src/odc/QuantizerLoader.test.cc b/runtime/onert/core/src/odc/QuantizerLoader.test.cc new file mode 100644 index 000000000..112e65b27 --- /dev/null +++ b/runtime/onert/core/src/odc/QuantizerLoader.test.cc @@ -0,0 +1,63 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "QuantizerLoader.h" + +#include <gtest/gtest.h> + +using namespace onert::odc; + +// Test QuantizerLoader singleton +TEST(odc_QuantizerLoader, singleton) +{ + QuantizerLoader &loader1 = QuantizerLoader::instance(); + QuantizerLoader &loader2 = QuantizerLoader::instance(); + ASSERT_EQ(&loader1, &loader2); +} + +// Test load quantizer library +TEST(odc_QuantizerLoader, load) +{ + QuantizerLoader &loader = QuantizerLoader::instance(); + // Unload because it may be loaded on previous tests + ASSERT_EQ(loader.unloadLibrary(), 0); + + if (loader.loadLibrary() == 0) + { + // Load twice to check if it is thread-safe + ASSERT_EQ(loader.loadLibrary(), 0); + } +} + +// Get quantizer function without loading quantizer library +TEST(odc_QuantizerLoader, neg_get) +{ + QuantizerLoader &loader = QuantizerLoader::instance(); + // Unload because it may be loaded on previous tests + ASSERT_EQ(loader.unloadLibrary(), 0); + ASSERT_EQ(loader.get(), nullptr); +} + +// Check quantizer function pointer when QuantizerLoader is unloaded +TEST(odc_QuantizerLoader, neg_unload) +{ + QuantizerLoader &loader = QuantizerLoader::instance(); + if (loader.loadLibrary() == 0) + ASSERT_NE(loader.get(), nullptr); + + ASSERT_EQ(loader.unloadLibrary(), 0); + ASSERT_EQ(loader.get(), nullptr); +} diff --git a/runtime/onert/core/src/util/ChromeTracingEventWriter.cc b/runtime/onert/core/src/util/ChromeTracingEventWriter.cc new file mode 100644 index 000000000..c3f5179df --- /dev/null +++ b/runtime/onert/core/src/util/ChromeTracingEventWriter.cc @@ -0,0 +1,195 @@ +/* + * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "EventWriter.h" + +#include <cassert> +#include <sstream> +#include <utility> +#include <vector> + +// json type for ChromeTracingWriter +namespace +{ + +std::string quote(const std::string &value) +{ + std::stringstream ss; + ss << '"' << value << '"'; + return ss.str(); +} + +std::string field(const std::string &k, const std::string &v) +{ + std::stringstream ss; + ss << quote(k) << " : " << quote(v); + return ss.str(); +} + +struct Content // One Entry in Chrome Event Trace +{ + std::vector<std::pair<std::string, std::string>> flds; + std::vector<std::pair<std::string, std::string>> args; +}; + +std::string object(const Content &content) +{ + std::stringstream ss; + + ss << "{ "; + + ss << field(content.flds[0].first, content.flds[0].second); + + for (uint32_t n = 1; n < content.flds.size(); ++n) + { + ss << ", " << field(content.flds.at(n).first, content.flds.at(n).second); + } + + if (content.args.size() > 0) + { + ss << ", " << quote("args") << " : { "; + ss << field(content.args.at(0).first, content.args.at(0).second); + + for (uint32_t n = 1; n < content.args.size(); ++n) + { + ss << ", " << field(content.args.at(n).first, content.args.at(n).second); + } + + ss << "}"; + } + + ss << " }"; + + return ss.str(); +} + +void fill(Content &content, const DurationEvent &evt, const std::string &name, + const std::string &tid) +{ + content.flds.emplace_back("name", name); + content.flds.emplace_back("pid", "0"); + content.flds.emplace_back("tid", tid); + content.flds.emplace_back("ph", evt.ph); + content.flds.emplace_back("ts", 
evt.ts); + content.args = evt.args; +} + +void fill(Content &content, const CounterEvent &evt) +{ + assert(evt.name != ""); + + content.flds.emplace_back("name", evt.name); + content.flds.emplace_back("pid", "0"); + content.flds.emplace_back("tid", evt.tid); + content.flds.emplace_back("ph", evt.ph); + content.flds.emplace_back("ts", evt.ts); + content.args = evt.args; +} + +std::string object(const DurationEvent &evt, const std::string &name, const std::string &tid) +{ + Content content; + + fill(content, evt, name, tid); + + return ::object(content); +} + +std::string object(const CounterEvent &evt) +{ + Content content; + + fill(content, evt); + + for (auto it = evt.values.begin(); it != evt.values.end(); ++it) + { + content.args.emplace_back(it->first, it->second); + } + + return ::object(content); +} + +std::string getSessionLabel(const DurationEvent &evt) +{ + return "$" + std::to_string(evt.session_index) + " sess"; +} + +std::string getSubgLabel(const DurationEvent &evt) +{ + return "$" + std::to_string(evt.subg_index) + " subg"; +} + +std::string getOpLabel(const OpSeqDurationEvent &evt) +{ + return "@" + std::to_string(evt.op_index) + " " + evt.op_name; +} + +std::string getLabel(const DurationEvent &evt) +{ + if (auto evt_ptr = dynamic_cast<const OpSeqDurationEvent *>(&evt)) + { + return getOpLabel(*evt_ptr); + } + else // SubgDurationEvent + { + return getSubgLabel(evt); + } +} + +std::string getTid(const DurationEvent &evt) +{ + if (auto evt_ptr = dynamic_cast<const OpSeqDurationEvent *>(&evt)) + { + return getSessionLabel(*evt_ptr) + ", " + getSubgLabel(*evt_ptr) + ", " + evt_ptr->backend; + } + else // SubgDurationEvent + { + return getSessionLabel(evt) + ", " + getSubgLabel(evt); + } +} + +} // namespace + +void ChromeTracingWriter::flush(const std::vector<std::unique_ptr<EventRecorder>> &recorders) +{ + _os << "{\n"; + _os << " " << quote("traceEvents") << ": [\n"; + + for (const auto &recorder : recorders) + { + flushOneRecord(*recorder); + } + + 
_os << " { }\n"; + _os << " ]\n"; + _os << "}\n"; +} + +void ChromeTracingWriter::flushOneRecord(const EventRecorder &recorder) +{ + for (const auto &evt : recorder.duration_events()) + { + const std::string name = getLabel(*evt); + const std::string tid = getTid(*evt); + + _os << " " << object(*evt, name, tid) << ",\n"; + } + + for (const auto &evt : recorder.counter_events()) + { + _os << " " << object(evt) << ",\n"; + } +} diff --git a/runtime/onert/core/src/util/ConfigSource.cc b/runtime/onert/core/src/util/ConfigSource.cc index 45cce662e..b7fcefc7a 100644 --- a/runtime/onert/core/src/util/ConfigSource.cc +++ b/runtime/onert/core/src/util/ConfigSource.cc @@ -15,13 +15,15 @@ */ #include "util/ConfigSource.h" -#include "util/GeneralConfigSource.h" -#include "util/EnvConfigSource.h" +#include "util/logging.h" + +#include <misc/EnvConfigSource.h> +#include <misc/GeneralConfigSource.h> +#include <misc/IConfigSource.h> -#include <array> #include <algorithm> +#include <array> #include <cassert> - #include <memory> namespace onert @@ -29,9 +31,26 @@ namespace onert namespace util { +using namespace nnfw::misc; + static std::unique_ptr<IConfigSource> _source; +static std::unique_ptr<IConfigSource> _source_ext; void config_source(std::unique_ptr<IConfigSource> &&source) { _source = std::move(source); } +void config_source_ext(std::unique_ptr<IConfigSource> &&source) { _source_ext = std::move(source); } + +void setConfigKeyValues(const CfgKeyValues &keyValues) +{ + auto configsrc = std::make_unique<GeneralConfigSource>(); + + for (auto it = keyValues.begin(); it != keyValues.end(); ++it) + { + VERBOSE(NNPKG_CONFIGS) << "(" << it->first << ") = (" << it->second << ")" << std::endl; + configsrc->set(it->first, it->second); + } + + onert::util::config_source_ext(std::move(configsrc)); +} static IConfigSource *config_source() { @@ -67,6 +86,15 @@ static std::string getConfigOrDefault(const std::string &key) auto ret = config_source()->get(key); if (ret.empty()) { + // if env 
is not set, search from external + if (_source_ext.get()) + { + ret = _source_ext.get()->get(key); + } + } + // if not found search from defaults + if (ret.empty()) + { auto itr = defaults.find(key); if (itr != defaults.end()) { diff --git a/runtime/onert/core/src/util/EventCollector.cc b/runtime/onert/core/src/util/EventCollector.cc index de37276bf..c1b9c4315 100644 --- a/runtime/onert/core/src/util/EventCollector.cc +++ b/runtime/onert/core/src/util/EventCollector.cc @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "util/EventCollector.h" +#include "EventCollector.h" // C++ standard libraries #include <chrono> @@ -30,24 +30,62 @@ std::string timestamp(void) { auto now = std::chrono::steady_clock::now(); return std::to_string( - std::chrono::duration_cast<std::chrono::microseconds>(now.time_since_epoch()).count()); + std::chrono::duration_cast<std::chrono::microseconds>(now.time_since_epoch()).count()); } -class DurationEventBuilder +class DurationEventBuilder : public EventCollector::EventVisitor { public: DurationEventBuilder(const std::string &ts) : _ts{ts} {} - DurationEvent build(const std::string &tid, const std::string &name, const std::string &ph) const + std::unique_ptr<SubgDurationEvent> build(const EventCollector::SubgEvent &evt_collected, + const std::string &ph) const { - DurationEvent evt; + auto dur_evt = std::make_unique<SubgDurationEvent>(); - evt.name = name; - evt.tid = tid; - evt.ph = ph; - evt.ts = _ts; + // The following will be set by a child of EventsWriter: + // dur_evt.name, dur_evt.tid + dur_evt->ph = ph; + dur_evt->ts = _ts; + dur_evt->tracing_ctx = evt_collected.tracing_ctx; - return evt; + dur_evt->session_index = evt_collected.session_index; + dur_evt->subg_index = evt_collected.subg_index; + + dur_evt->args = evt_collected.userData; + { + dur_evt->args.emplace_back("session", std::to_string(evt_collected.session_index)); + dur_evt->args.emplace_back("subgraph", std::to_string(evt_collected.subg_index)); + } + + return 
dur_evt; + } + + std::unique_ptr<OpSeqDurationEvent> build(const EventCollector::OpSeqEvent &evt_collected, + const std::string &ph) const + { + auto dur_evt = std::make_unique<OpSeqDurationEvent>(); + + // The following will be set by a child of EventsWriter: + // dur_evt.name, dur_evt.tid + dur_evt->ph = ph; + dur_evt->ts = _ts; + dur_evt->tracing_ctx = evt_collected.tracing_ctx; + + dur_evt->session_index = evt_collected.session_index; + dur_evt->subg_index = evt_collected.subg_index; + + dur_evt->backend = evt_collected.backend; + dur_evt->op_index = evt_collected.op_index; + dur_evt->op_name = evt_collected.op_name; + + dur_evt->args = evt_collected.userData; + { + dur_evt->args.emplace_back("session", std::to_string(evt_collected.session_index)); + dur_evt->args.emplace_back("subgraph", std::to_string(evt_collected.subg_index)); + } + + return dur_evt; } private: @@ -86,19 +124,26 @@ inline void emit_rusage(EventRecorder *rec, const std::string &ts) } // namespace -void EventCollector::onEvent(const Event &event) +template <typename EventT> void EventCollector::onEvent(const EventT &event) { auto ts = timestamp(); + DurationEventBuilder builder(ts); + switch (event.edge) { case Edge::BEGIN: - _rec->emit(DurationEventBuilder(ts).build(event.backend, event.label, "B")); + { + auto duration_evt = builder.build(event, "B"); + _rec->emit(std::move(duration_evt)); break; - + } case Edge::END: - _rec->emit(DurationEventBuilder(ts).build(event.backend, event.label, "E")); + { + auto duration_evt = builder.build(event, "E"); + _rec->emit(std::move(duration_evt)); break; + } } // TODO: Add resurece measurement(e.g. 
RSS) @@ -107,3 +152,7 @@ void EventCollector::onEvent(const Event &event) emit_rusage(_rec, ts); #endif } + +// template instantiation +template void EventCollector::onEvent<EventCollector::SubgEvent>(const SubgEvent &event); +template void EventCollector::onEvent<EventCollector::OpSeqEvent>(const OpSeqEvent &event); diff --git a/runtime/onert/core/src/util/EventCollector.h b/runtime/onert/core/src/util/EventCollector.h index 8154be592..effb72373 100644 --- a/runtime/onert/core/src/util/EventCollector.h +++ b/runtime/onert/core/src/util/EventCollector.h @@ -17,7 +17,13 @@ #ifndef __ONERT_UTIL_EVENT_COLLECTOR_H__ #define __ONERT_UTIL_EVENT_COLLECTOR_H__ -#include "util/EventRecorder.h" +#include "EventRecorder.h" + +#include "util/TracingCtx.h" + +#include <string> +#include <utility> +#include <vector> class EventCollector { @@ -28,11 +34,69 @@ public: END }; + struct SubgEvent; + struct OpEvent; + + class EventVisitor + { + public: + virtual ~EventVisitor() = default; + + virtual std::unique_ptr<DurationEvent> visit(const SubgEvent &, const std::string &) const + { + throw std::runtime_error("Please implement"); + } + virtual std::unique_ptr<DurationEvent> visit(const OpEvent &, const std::string &) const + { + throw std::runtime_error("Please implement"); + } + }; + struct Event { + const onert::util::TracingCtx *tracing_ctx; + Edge edge; + uint32_t session_index; + uint32_t subg_index; + + // user-defined data: pairs of (key, value) + std::vector<std::pair<std::string, std::string>> userData; + + protected: + Event(const onert::util::TracingCtx *a_tracing_ctx, Edge a_edge, uint32_t a_subg_index) + : tracing_ctx(a_tracing_ctx), edge(a_edge), session_index(tracing_ctx->getSessionId()), + subg_index(a_subg_index) + { /* empty */ + } + + virtual ~Event() = default; + }; + + struct SubgEvent : public Event + { + // constructor for subgraph start and end event + SubgEvent(const onert::util::TracingCtx *a_tracing_ctx, Edge a_edge, uint32_t a_subg_index) + : 
Event(a_tracing_ctx, a_edge, a_subg_index) + { /* empty */ + } + }; + + // TODO Rename this to OperationEvent + struct OpSeqEvent : public Event + { std::string backend; - std::string label; + uint32_t op_index; + std::string op_name; + + OpSeqEvent(const onert::util::TracingCtx *a_tracing_ctx, Edge a_edge, uint32_t a_subg_index, + const std::string a_backend, uint32_t a_op_index, const std::string a_op_name) + : Event(a_tracing_ctx, a_edge, a_subg_index) + { + backend.assign(a_backend); + op_index = a_op_index; + op_name.assign(a_op_name); + } }; public: @@ -42,7 +106,7 @@ public: } public: - void onEvent(const Event &event); + template <typename EventT> void onEvent(const EventT &event); protected: EventRecorder *_rec; diff --git a/runtime/onert/core/src/util/EventCollectorGlobal.cc b/runtime/onert/core/src/util/EventCollectorGlobal.cc deleted file mode 100644 index d09b95210..000000000 --- a/runtime/onert/core/src/util/EventCollectorGlobal.cc +++ /dev/null @@ -1,93 +0,0 @@ -/* - * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "util/EventCollectorGlobal.h" - -#include <cassert> -#include <fstream> -#include <iostream> - -#include "util/ConfigSource.h" - -namespace onert -{ -namespace util -{ - -EventCollectorGlobal::EventCollectorGlobal() : _recorder{}, _collector{&_recorder} -{ - // DO NOTHING -} - -EventCollectorGlobal::~EventCollectorGlobal() -{ - if (!_recorder.empty()) - { - try - { - // TODO Need better way for saved file path than the hardcoded path - std::ofstream ofs{"trace.global.json"}; - _recorder.writeToFile(ofs); - } - catch (const std::exception &e) - { - std::cerr << "E: Fail to record event in EventCollectorGlobal: " << e.what() << std::endl; - } - } -} - -EventCollectorGlobal &EventCollectorGlobal::get() -{ - static EventCollectorGlobal instance; - return instance; -} - -EventDurationBlock::EventDurationBlock(const std::string &tag) : _tag{tag} -{ - auto &glob = EventCollectorGlobal::get(); - glob.collector().onEvent(EventCollector::Event{EventCollector::Edge::BEGIN, "0", _tag}); -} -EventDurationBlock::~EventDurationBlock() -{ - auto &glob = EventCollectorGlobal::get(); - glob.collector().onEvent(EventCollector::Event{EventCollector::Edge::END, "0", _tag}); -} - -EventDurationManual::EventDurationManual(const std::string &tag) : _tag{tag}, _pair{true} {} - -EventDurationManual::~EventDurationManual() -{ - // Check if it has called begin-end pair - assert(_pair); -} - -void EventDurationManual::begin() -{ - _pair = false; - auto &glob = EventCollectorGlobal::get(); - glob.collector().onEvent(EventCollector::Event{EventCollector::Edge::BEGIN, "0", _tag}); -} - -void EventDurationManual::end() -{ - assert(!_pair); - _pair = true; - auto &glob = EventCollectorGlobal::get(); - glob.collector().onEvent(EventCollector::Event{EventCollector::Edge::END, "0", _tag}); -} - -} // namespace util -} // namespace onert diff --git a/runtime/onert/core/src/util/EventCollectorGlobal.h b/runtime/onert/core/src/util/EventCollectorGlobal.h deleted file mode 100644 index 
1027ec84d..000000000 --- a/runtime/onert/core/src/util/EventCollectorGlobal.h +++ /dev/null @@ -1,155 +0,0 @@ -/* - * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef __ONERT_UTIL_EVENT_COLLECTOR_GLOBAL_H__ -#define __ONERT_UTIL_EVENT_COLLECTOR_GLOBAL_H__ - -#include "util/EventRecorder.h" -#include "util/EventCollector.h" - -namespace onert -{ -namespace util -{ - -/** - * @brief Singleton class for event collection from anywhere in code - * - */ -class EventCollectorGlobal -{ -public: - /** - * @brief Get the singleton object of this class - * - * @return EventCollectorGlobal& Singleton object - */ - static EventCollectorGlobal &get(); - -public: - /** - * @brief Getter for event collector object - * - * @return EventCollector& Collector object - */ - EventCollector &collector() { return _collector; } - -private: - EventCollectorGlobal(); - ~EventCollectorGlobal(); - -private: - EventRecorder _recorder; - EventCollector _collector; -}; - -/** - * @brief Helper class for emitting duration event which is handled automatically with ctor/dtor - * - */ -class EventDurationBlock -{ -public: - /** - * @brief Raise a duration event with type of BEGIN - * - * @param tag A label for the duration event - */ - EventDurationBlock(const std::string &tag); - /** - * @brief Raise a duration event with type of END - * - */ - ~EventDurationBlock(); - -private: - std::string _tag; -}; - -/** 
- * @brief Helper class for emitting duration event which is handled manually - * - * Usage: - * { - * ... - * EventDurationManual duration("some tag"); - * duration.begin(); - * ... - * ... // Code for duration - * ... - * duration.end(); - * } - * - */ -class EventDurationManual -{ -public: - /** - * @brief Construct a new Event Duration Manual object - * - * @param tag A label for the duration object - */ - EventDurationManual(const std::string &tag); - /** - * @brief Destroy the Event Duration Manual object - * - */ - ~EventDurationManual(); - - /** - * @brief Raise a duration event with type of BEGIN - * - */ - void begin(); - /** - * @brief Raise a duration event with type of END - * - */ - void end(); - -private: - std::string _tag; - bool _pair; -}; - -} // namespace util -} // namespace onert - -/** - * Helper Macro Definitions - * - * HOW TO USE - * - * void f(args) - * { - * EVENT_DURATION_FUNCTION(); - * ... - * if(cond) - * { - * EVENT_DURATION_REGION("if branch"); - * ... - * } - * ... - * } - */ - -#define EVENT_DURATION_FUNCTION() \ - ::onert::util::EventDurationBlock __event_duration__##__LINE__ { __FUNCTION__ } - -#define EVENT_DURATION_REGION(tag) \ - ::onert::util::EventDurationBlock __event_duration__##__LINE__ { tag } - -#endif // __ONERT_UTIL_EVENT_COLLECTOR_GLOBAL_H__ diff --git a/runtime/onert/core/src/util/EventRecorder.cc b/runtime/onert/core/src/util/EventRecorder.cc index 13a599bed..85a588d38 100644 --- a/runtime/onert/core/src/util/EventRecorder.cc +++ b/runtime/onert/core/src/util/EventRecorder.cc @@ -14,396 +14,13 @@ * limitations under the License. 
*/ -#include "util/EventRecorder.h" +#include "EventRecorder.h" -#include <sstream> -#include <vector> -#include <unordered_map> -#include <json/json.h> -#include <assert.h> -#include <utility> -#include <map> -#include <set> -#include <stdint.h> - -// json type for Chrome Event Trace -namespace -{ - -std::string quote(const std::string &value) -{ - std::stringstream ss; - ss << '"' << value << '"'; - return ss.str(); -} - -std::string field(const std::string &k, const std::string &v) -{ - std::stringstream ss; - ss << quote(k) << " : " << quote(v); - return ss.str(); -} - -struct Content // One Entry in Chrome Event Trace -{ - std::vector<std::pair<std::string, std::string>> flds; - std::vector<std::pair<std::string, std::string>> args; -}; - -std::string object(const Content &content) -{ - std::stringstream ss; - - ss << "{ "; - - ss << field(content.flds[0].first, content.flds[0].second); - - for (uint32_t n = 1; n < content.flds.size(); ++n) - { - ss << ", " << field(content.flds.at(n).first, content.flds.at(n).second); - } - - if (content.args.size() > 0) - { - ss << ", " << quote("args") << " : { "; - ss << field(content.args.at(0).first, content.args.at(0).second); - - for (uint32_t n = 1; n < content.args.size(); ++n) - { - ss << ", " << field(content.args.at(n).first, content.args.at(n).second); - } - - ss << "}"; - } - - ss << " }"; - - return ss.str(); -} - -void fill(Content &content, const Event &evt) -{ - content.flds.emplace_back("name", evt.name); - content.flds.emplace_back("pid", "0"); - content.flds.emplace_back("tid", evt.tid); - content.flds.emplace_back("ph", evt.ph); - content.flds.emplace_back("ts", evt.ts); -} - -std::string object(const DurationEvent &evt) -{ - Content content; - - fill(content, evt); - - return ::object(content); -} - -std::string object(const CounterEvent &evt) -{ - Content content; - - fill(content, evt); - - for (auto it = evt.values.begin(); it != evt.values.end(); ++it) - { - content.args.emplace_back(it->first, 
it->second); - } - - return ::object(content); -} - -} // namespace - -// md table type -namespace -{ - -void writeMDTableRow(std::ostream &os, const std::vector<std::string> &list) -{ - os << "| "; - for (auto &key : list) - { - os << key << " | "; - } - os << "\n"; -} - -struct MDContent -{ - std::string name; - uint64_t begin_ts; - uint64_t end_ts; - uint32_t min_rss; - uint32_t max_rss; - uint32_t min_page_reclaims; - uint32_t max_page_reclaims; - - MDContent() - : begin_ts(0), end_ts(0), min_rss(UINT32_MAX), max_rss(0), min_page_reclaims(UINT32_MAX), - max_page_reclaims(0) - { - // DO NOTHING - } - - virtual ~MDContent() = default; - - void updateRss(uint32_t rss) - { - if (min_rss == UINT32_MAX) - min_rss = rss; - if (max_rss == 0) - max_rss = rss; - - if (min_rss > rss) - min_rss = rss; - else if (max_rss < rss) - max_rss = rss; - } - - void updateMinflt(uint32_t minflt) - { - if (min_page_reclaims == UINT32_MAX) - min_page_reclaims = minflt; - if (max_page_reclaims == 0) - max_page_reclaims = minflt; - - if (min_page_reclaims > minflt) - min_page_reclaims = minflt; - else if (max_page_reclaims < minflt) - max_page_reclaims = minflt; - } - - virtual void write(std::ostream &os) const = 0; -}; - -struct OpSeq : public MDContent -{ - std::string backend; - uint64_t graph_latency; - - struct OpSeqCmp - { - bool operator()(const OpSeq &lhs, const OpSeq &rhs) const - { - return lhs.begin_ts < rhs.begin_ts; - } - bool operator()(const OpSeq &lhs, const OpSeq &rhs) { return lhs.begin_ts < rhs.begin_ts; } - bool operator()(OpSeq &lhs, OpSeq &rhs) { return lhs.begin_ts < rhs.begin_ts; } - }; - - void write(std::ostream &os) const override - { - uint64_t opseq_latency = end_ts - begin_ts; - double opseq_per = static_cast<double>(opseq_latency) / graph_latency * 100.0; - writeMDTableRow(os, {name, backend, std::to_string(opseq_latency), std::to_string(opseq_per), - std::to_string(min_rss), std::to_string(max_rss), - std::to_string(min_page_reclaims), 
std::to_string(max_page_reclaims)}); - } -}; - -struct Graph : public MDContent -{ - std::set<OpSeq, OpSeq::OpSeqCmp> opseqs; - - void setOpSeqs(const std::map<std::string, OpSeq> &name_to_opseq) - { - uint64_t graph_latency = end_ts - begin_ts; - for (auto it : name_to_opseq) - { - auto opseq = it.second; - opseq.graph_latency = graph_latency; - - opseqs.insert(opseq); - - updateRss(opseq.min_rss); - updateRss(opseq.max_rss); - updateMinflt(opseq.min_page_reclaims); - updateMinflt(opseq.max_page_reclaims); - } - } - - void write(std::ostream &os) const override - { - static std::vector<std::string> graph_headers{"latency(us)", "rss_min(kb)", "rss_max(kb)", - "page_reclaims_min", "page_reclaims_max"}; - - static std::vector<std::string> graph_headers_line{"-----------", "-------", "-------", - "-----------------", "-----------------"}; - - // Graph's Header - writeMDTableRow(os, graph_headers); - writeMDTableRow(os, graph_headers_line); - - // Graph's contents - writeMDTableRow(os, {std::to_string(end_ts - begin_ts), std::to_string(min_rss), - std::to_string(max_rss), std::to_string(min_page_reclaims), - std::to_string(max_page_reclaims)}); - - os << "\n"; - - static std::vector<std::string> opseq_headers{ - "OpSeq name", "backend", "latency(us)", "latency(%)", - "rss_min(kb)", "rss_max(kb)", "page_reclaims_min", "page_reclaims_max"}; - - static std::vector<std::string> opseq_headers_line{ - "----------", "-------", "-----------", "-----------", - "-------", "-------", "-----------------", "-----------------"}; - - os << "## OpSequences \n"; - - // OpSeq's Header - writeMDTableRow(os, opseq_headers); - writeMDTableRow(os, opseq_headers_line); - - // OpSeq's contents - for (auto opseq : opseqs) - { - opseq.write(os); - } - - os << "\n"; - } -}; - -struct MDTableBuilder -{ - MDTableBuilder(const std::vector<DurationEvent> &duration_events, - const std::vector<CounterEvent> &counter_events) - : _duration_events(duration_events), _counter_events(counter_events) - { - 
for (const auto &evt : _counter_events) - { - uint64_t ts = std::stoull(evt.ts); - auto &name = evt.name; - assert(name.compare("maxrss") == 0 || name.compare("minflt") == 0); - assert(evt.values.size() == 1); - auto &val = evt.values.begin()->second; - if (_ts_to_values.find(ts) == _ts_to_values.end()) - { - std::pair<uint32_t, uint32_t> values; - if (name.compare("maxrss") == 0) - values.first = std::stoul(val); - else - values.second = std::stoul(val); - _ts_to_values.insert({ts, values}); - } - else - { - auto &values = _ts_to_values.at(ts); - if (name.compare("maxrss") == 0) - values.first = std::stoul(val); - else - values.second = std::stoul(val); - } - } - } - - MDTableBuilder &build() - { - for (auto &it : divideGraph()) - { - size_t begin_idx = it.first; - size_t end_idx = it.second; - std::map<std::string, OpSeq> name_to_opseq; - for (size_t i = begin_idx + 1; i < end_idx; ++i) - { - const auto &evt = _duration_events[i]; - assert(evt.name.compare("Graph") != 0); - assert(evt.ph.compare("B") == 0 || evt.ph.compare("E") == 0); - if (evt.ph.compare("B") == 0) - { - assert(name_to_opseq.find(evt.name) == name_to_opseq.end()); - name_to_opseq.insert({evt.name, makeOpSeq(evt)}); - } - else - { - assert(name_to_opseq.find(evt.name) != name_to_opseq.end()); - auto &opseq = name_to_opseq.at(evt.name); - updateOpSeq(opseq, evt); - } - } - - _graphs.emplace_back(makeGraph(begin_idx, end_idx, name_to_opseq)); - } - - return *this; - } - - std::vector<std::pair<size_t, size_t>> divideGraph() - { - std::vector<std::pair<size_t, size_t>> graph_idx_list; // pair<begin_idx, end_idx> - for (size_t i = 0, begin_idx = 0; i < _duration_events.size(); ++i) - { - const auto &evt = _duration_events.at(i); - if (evt.name.compare("Graph") == 0) - { - if (evt.ph.compare("B") == 0) - begin_idx = i; - else - graph_idx_list.emplace_back(begin_idx, i); - } - } - return graph_idx_list; - } - - OpSeq makeOpSeq(const DurationEvent &evt) - { - OpSeq opseq; - opseq.name = evt.name; - 
opseq.begin_ts = std::stoull(evt.ts); - opseq.updateRss(_ts_to_values.at(opseq.begin_ts).first); - opseq.updateMinflt(_ts_to_values.at(opseq.begin_ts).second); - opseq.backend = evt.tid; - return opseq; - } - - void updateOpSeq(OpSeq &opseq, const DurationEvent &evt) - { - opseq.end_ts = std::stoull(evt.ts); - opseq.updateRss(_ts_to_values.at(opseq.end_ts).first); - opseq.updateMinflt(_ts_to_values.at(opseq.end_ts).second); - } - - Graph makeGraph(size_t begin_idx, size_t end_idx, - const std::map<std::string, OpSeq> &name_to_opseq) - { - Graph graph; - graph.name = "Graph"; - graph.begin_ts = std::stoull(_duration_events[begin_idx].ts); - graph.updateRss(_ts_to_values.at(graph.begin_ts).first); - graph.updateMinflt(_ts_to_values.at(graph.begin_ts).second); - graph.end_ts = std::stoull(_duration_events[end_idx].ts); - graph.updateRss(_ts_to_values.at(graph.end_ts).first); - graph.updateMinflt(_ts_to_values.at(graph.end_ts).second); - graph.setOpSeqs(name_to_opseq); - return graph; - } - - void write(std::ostream &os) - { - // Write contents - for (size_t i = 0; i < _graphs.size(); ++i) - { - os << "# Graph " << i << "\n"; - _graphs.at(i).write(os); - } - } - - const std::vector<DurationEvent> &_duration_events; - const std::vector<CounterEvent> &_counter_events; - // timestamp to std::pair<maxrss, minflt> - std::unordered_map<uint64_t, std::pair<uint32_t, uint32_t>> _ts_to_values; - std::vector<Graph> _graphs; -}; - -} // namespace - -void EventRecorder::emit(const DurationEvent &evt) +void EventRecorder::emit(std::unique_ptr<DurationEvent> &&evt) { std::lock_guard<std::mutex> lock{_mu}; - _duration_events.push_back(evt); + _duration_events.push_back(std::move(evt)); } void EventRecorder::emit(const CounterEvent &evt) @@ -412,146 +29,3 @@ void EventRecorder::emit(const CounterEvent &evt) _counter_events.push_back(evt); } - -void EventRecorder::writeToFile(std::ostream &os) -{ - std::lock_guard<std::mutex> lock{_mu}; - - switch (_write_format) - { - case 
WriteFormat::CHROME_TRACING: - writeChromeTrace(os); - break; - case WriteFormat::SNPE_BENCHMARK: - writeSNPEBenchmark(os); - break; - case WriteFormat::MD_TABLE: - writeMDTable(os); - break; - default: - assert(!"Invalid value"); - break; - } -} - -void EventRecorder::writeSNPEBenchmark(std::ostream &os) -{ - Json::Value root; - auto &exec_data = root["Execution_Data"] = Json::Value{Json::objectValue}; - - struct Stat - { - uint64_t sum = 0; - uint64_t count = 0; - uint64_t max = 0; - uint64_t min = std::numeric_limits<uint64_t>::max(); - - void accumulate(uint64_t val) - { - sum += val; - count++; - max = std::max(max, val); - min = std::min(min, val); - } - }; - - // Memory - { - std::unordered_map<std::string, Stat> mem_stats; - for (auto &evt : _counter_events) - { - auto &mem_stat = mem_stats[evt.name]; - uint64_t val = std::stoull(evt.values["value"]); - mem_stat.accumulate(val); - } - - auto &mem = exec_data["memory"] = Json::Value{Json::objectValue}; - for (auto &kv : mem_stats) - { - auto &key = kv.first; - auto &val = kv.second; - mem[key]["Avg_Size"] = val.sum / val.count; - mem[key]["Max_Size"] = val.max; - mem[key]["Min_Size"] = val.min; - mem[key]["Runtime"] = "NA"; - } - } - - // Operation Execution Time - { - // NOTE This assumes _duration_events is sorted by "ts" ascending - - // 2D keys : stats[tid][name] - std::unordered_map<std::string, std::unordered_map<std::string, Stat>> stats; - std::unordered_map<std::string, std::unordered_map<std::string, uint64_t>> begin_timestamps; - for (auto &evt : _duration_events) - { - auto &stat = stats[evt.tid][evt.name]; - auto &begin_ts = begin_timestamps[evt.tid][evt.name]; - uint64_t timestamp = std::stoull(evt.ts); - if (evt.ph == "B") - { - if (begin_ts != 0) - throw std::runtime_error{"Invalid Data"}; - begin_ts = timestamp; - } - else if (evt.ph == "E") - { - if (begin_ts == 0 || timestamp < begin_ts) - throw std::runtime_error{"Invalid Data"}; - stat.accumulate(timestamp - begin_ts); - begin_ts = 0; - 
} - else - throw std::runtime_error{"Invalid Data - invalid value for \"ph\" : \"" + evt.ph + "\""}; - } - - for (auto &kv : begin_timestamps) - for (auto &kv2 : kv.second) - if (kv2.second != 0) - throw std::runtime_error{"Invalid Data - B and E pair does not match."}; - - for (auto &kv : stats) - { - auto &tid = kv.first; - auto &map = kv.second; - auto &json_tid = exec_data[tid] = Json::Value{Json::objectValue}; - for (auto &kv : map) - { - auto &name = kv.first; - auto &val = kv.second; - json_tid[name]["Avg_Time"] = val.sum / val.count; - json_tid[name]["Max_Time"] = val.max; - json_tid[name]["Min_Time"] = val.min; - json_tid[name]["Runtime"] = tid; - } - } - } - - os << root; -} - -void EventRecorder::writeChromeTrace(std::ostream &os) -{ - os << "{\n"; - os << " " << quote("traceEvents") << ": [\n"; - - for (auto &evt : _duration_events) - { - os << " " << object(evt) << ",\n"; - } - - for (auto &evt : _counter_events) - { - os << " " << object(evt) << ",\n"; - } - - os << " { }\n"; - os << " ]\n"; - os << "}\n"; -} - -void EventRecorder::writeMDTable(std::ostream &os) -{ - MDTableBuilder(_duration_events, _counter_events).build().write(os); -} diff --git a/runtime/onert/core/src/util/EventRecorder.h b/runtime/onert/core/src/util/EventRecorder.h index 37ec1a0f1..5cf03d8ac 100644 --- a/runtime/onert/core/src/util/EventRecorder.h +++ b/runtime/onert/core/src/util/EventRecorder.h @@ -17,28 +17,52 @@ #ifndef __ONERT_UTIL_EVENT_RECORDER_H__ #define __ONERT_UTIL_EVENT_RECORDER_H__ +#include "util/TracingCtx.h" + #include <map> #include <memory> #include <mutex> -#include <ostream> #include <vector> +// refer to https://docs.google.com/document/d/1CvAClvFfyA5R-PhYUmn5OOQtYMH4h6I0nSsKchNAySU/edit# struct Event { - std::string name; - std::string tid; - std::string ph; /* REQUIRED */ - std::string ts; /* REQUIRED */ + const onert::util::TracingCtx *tracing_ctx; + + std::string ph; // Event type. 
+ std::string ts; // tracing clock of timestamp of this event + std::vector<std::pair<std::string, std::string>> args; // user-defined data: pairs of (key, value) + + virtual ~Event() = default; }; struct DurationEvent : public Event { - // TO BE FILLED + uint32_t session_index = 0; + uint32_t subg_index = 0; + +protected: + DurationEvent() = default; +}; + +struct SubgDurationEvent : public DurationEvent +{ /* same with DurationEvent */ +}; + +// TODO Rename it to OperationDurationEvent +struct OpSeqDurationEvent : public DurationEvent +{ + // Note: DurationEvent's name and tid will be set by EventWriter + std::string backend; + uint32_t op_index; + std::string op_name; }; struct CounterEvent : public Event { + std::string name; // name of event + std::string tid; // thread ID std::map<std::string, std::string> values; }; @@ -50,35 +74,22 @@ struct CounterEvent : public Event class EventRecorder { public: - enum class WriteFormat - { - CHROME_TRACING, - SNPE_BENCHMARK, - MD_TABLE, - }; - -public: EventRecorder() = default; public: - void emit(const DurationEvent &evt); + void emit(std::unique_ptr<DurationEvent> &&evt); void emit(const CounterEvent &evt); public: - bool empty() { return _duration_events.empty() && _counter_events.empty(); } - void writeToFile(std::ostream &os); - void setWriteFormat(WriteFormat write_format) { _write_format = write_format; } - -private: - void writeSNPEBenchmark(std::ostream &os); - void writeChromeTrace(std::ostream &os); - void writeMDTable(std::ostream &os); + const std::vector<std::unique_ptr<DurationEvent>> &duration_events() const + { + return _duration_events; + } + const std::vector<CounterEvent> &counter_events() const { return _counter_events; } private: std::mutex _mu; - // TODO: Allow user to control write_format - WriteFormat _write_format{WriteFormat::SNPE_BENCHMARK}; - std::vector<DurationEvent> _duration_events; + std::vector<std::unique_ptr<DurationEvent>> _duration_events; std::vector<CounterEvent> 
_counter_events; }; diff --git a/runtime/onert/core/src/util/EventWriter.cc b/runtime/onert/core/src/util/EventWriter.cc new file mode 100644 index 000000000..ca4bd302e --- /dev/null +++ b/runtime/onert/core/src/util/EventWriter.cc @@ -0,0 +1,49 @@ +/* + * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "EventWriter.h" + +#include <cassert> + +// initialization +std::mutex EventWriter::_mutex; + +void EventWriter::readyToFlush(std::unique_ptr<EventRecorder> &&recorder) +{ + { + std::unique_lock<std::mutex> lock{_mutex}; + + _recorders.emplace_back(std::move(recorder)); + + if (--_ref_count > 0) + return; + } + // The caller of this method is the last instance that uses EventWriter. + // Let's write log files. + + // Note. According to an internal issue, let snpe json as just file name not '.snpe.json' + flush(WriteFormat::SNPE_BENCHMARK); + flush(WriteFormat::CHROME_TRACING); + flush(WriteFormat::MD_TABLE); +} + +void EventWriter::flush(WriteFormat write_format) +{ + auto *writer = _actual_writers[write_format].get(); + assert(writer); + + writer->flush(_recorders); +} diff --git a/runtime/onert/core/src/util/EventWriter.h b/runtime/onert/core/src/util/EventWriter.h new file mode 100644 index 000000000..672820aa9 --- /dev/null +++ b/runtime/onert/core/src/util/EventWriter.h @@ -0,0 +1,144 @@ +/* + * Copyright (c) 2020 Samsung Electronics Co., Ltd. 
All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __ONERT_UTIL_EVENT_WRITER_H__ +#define __ONERT_UTIL_EVENT_WRITER_H__ + +#include "EventRecorder.h" + +#include <string> +#include <vector> +#include <unordered_map> +#include <mutex> +#include <fstream> + +class EventFormatWriter +{ +public: + EventFormatWriter(const std::string &filepath) : _os{filepath, std::ofstream::out} {} + virtual ~EventFormatWriter() + { /* empty */ + } + + virtual void flush(const std::vector<std::unique_ptr<EventRecorder>> &) = 0; + +protected: + std::ofstream _os; +}; + +class SNPEWriter : public EventFormatWriter +{ +public: + SNPEWriter(const std::string &filepath) : EventFormatWriter(filepath) + { /* empty */ + } + ~SNPEWriter() {} + + void flush(const std::vector<std::unique_ptr<EventRecorder>> &) override; +}; + +class ChromeTracingWriter : public EventFormatWriter +{ +public: + ChromeTracingWriter(const std::string &filepath) : EventFormatWriter(filepath) + { /* empty */ + } + ~ChromeTracingWriter() {} + + void flush(const std::vector<std::unique_ptr<EventRecorder>> &) override; + +private: + void flushOneRecord(const EventRecorder &); +}; + +class MDTableWriter : public EventFormatWriter +{ +public: + MDTableWriter(const std::string &filepath) : EventFormatWriter(filepath) + { /* empty */ + } + ~MDTableWriter() {} + + void flush(const std::vector<std::unique_ptr<EventRecorder>> &) override; +}; + +#include <mutex> + +class EventWriter +{ 
+public: + enum class WriteFormat + { + CHROME_TRACING, + SNPE_BENCHMARK, + MD_TABLE, + }; + + /** + * @brief Retuens a singleton object + */ + static EventWriter *get(const std::string &workspace_dir) + { + std::unique_lock<std::mutex> lock{_mutex}; + + static EventWriter singleton(workspace_dir); + return &singleton; + } + + /** + * @brief Call this when observer which use EventWriter starts + */ + void startToUse() + { + std::unique_lock<std::mutex> lock{_mutex}; + _ref_count++; + } + + /** + * @brief Call this when observer which use EventWriter finishes. + * After multiple observers calls this method, the reference count will eventually be 0. + * Then, EventWriter will write profiling result file. + */ + void readyToFlush(std::unique_ptr<EventRecorder> &&recorder); + +private: + EventWriter(const std::string &workspace_dir) : _ref_count(0) + { + std::string snpe_log_name(workspace_dir + "/trace.json"); + std::string chrome_tracing_log_name(workspace_dir + "/trace.chrome.json"); + std::string md_table_log_name(workspace_dir + "/trace.table.md"); + + _actual_writers[WriteFormat::SNPE_BENCHMARK] = std::make_unique<SNPEWriter>(snpe_log_name); + _actual_writers[WriteFormat::CHROME_TRACING] = + std::make_unique<ChromeTracingWriter>(chrome_tracing_log_name); + _actual_writers[WriteFormat::MD_TABLE] = std::make_unique<MDTableWriter>(md_table_log_name); + }; + + void flush(WriteFormat write_format); + +private: + static std::mutex _mutex; + + // number of observer of an executor that want to write profiling data + int32_t _ref_count; + + // one recorder object per executor + std::vector<std::unique_ptr<EventRecorder>> _recorders; + + std::unordered_map<WriteFormat, std::unique_ptr<EventFormatWriter>> _actual_writers; +}; + +#endif // __ONERT_UTIL_EVENT_WRITER_H__ diff --git a/runtime/onert/core/src/util/GeneralConfigSource.cc b/runtime/onert/core/src/util/GeneralConfigSource.cc deleted file mode 100644 index 7d2757e58..000000000 --- 
a/runtime/onert/core/src/util/GeneralConfigSource.cc +++ /dev/null @@ -1,45 +0,0 @@ -/* - * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "util/GeneralConfigSource.h" -#include "util/logging.h" - -namespace onert -{ -namespace util -{ - -std::string GeneralConfigSource::get(const std::string &key) const -{ - auto itr = _map.find(key); - if (itr == _map.end()) - { - return ""; - } - else - { - return itr->second; - } -} - -void GeneralConfigSource::set(const std::string &key, const std::string &val) -{ - VERBOSE(GeneralConfigSource) << key << " : " << val << std::endl; - _map[key] = val; -} - -} // namespace util -} // namespace onert diff --git a/runtime/onert/core/src/util/EnvConfigSource.cc b/runtime/onert/core/src/util/Index.test.cc index 0d25b7353..ff73e5e59 100644 --- a/runtime/onert/core/src/util/EnvConfigSource.cc +++ b/runtime/onert/core/src/util/Index.test.cc @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,27 +14,21 @@ * limitations under the License. 
*/ -#include "util/EnvConfigSource.h" +#include "util/Index.h" -#include <cstdlib> +#include <gtest/gtest.h> -namespace onert -{ -namespace util -{ +using Index = ::onert::util::Index<uint32_t, struct TestTag>; -std::string EnvConfigSource::get(const std::string &key) const +TEST(Index, neg_index_test) { - const char *value = std::getenv(key.c_str()); - if (value != nullptr) - { - return value; - } - else - { - return GeneralConfigSource::get(key); - } -} + Index idx1{1u}; + Index idx2{2u}; + Index idx3{idx1}; -} // namespace util -} // namespace onert + ASSERT_EQ(idx1, 1); + ASSERT_EQ(idx1, 1u); + ASSERT_EQ(idx1.value(), 1u); + ASSERT_NE(idx1, idx2); + ASSERT_EQ(idx1, idx3); +} diff --git a/runtime/onert/core/src/util/MDTableEventWriter.cc b/runtime/onert/core/src/util/MDTableEventWriter.cc new file mode 100644 index 000000000..e7d90eec4 --- /dev/null +++ b/runtime/onert/core/src/util/MDTableEventWriter.cc @@ -0,0 +1,365 @@ +/* + * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "EventWriter.h" + +#include <cassert> +#include <map> +#include <set> +#include <sstream> +#include <stdint.h> +#include <unordered_map> +#include <utility> +#include <vector> + +// md table type +namespace +{ + +void writeMDTableRow(std::ostream &os, const std::vector<std::string> &list) +{ + os << "| "; + for (const auto &key : list) + { + os << key << " | "; + } + os << "\n"; +} + +struct MDContent +{ + std::string name; + uint64_t begin_ts; + uint64_t end_ts; + uint32_t min_rss; + uint32_t max_rss; + uint32_t min_page_reclaims; + uint32_t max_page_reclaims; + + MDContent() + : begin_ts(0), end_ts(0), min_rss(UINT32_MAX), max_rss(0), min_page_reclaims(UINT32_MAX), + max_page_reclaims(0) + { + // DO NOTHING + } + + virtual ~MDContent() = default; + + void updateRss(uint32_t rss) + { + if (min_rss == UINT32_MAX) + min_rss = rss; + if (max_rss == 0) + max_rss = rss; + + if (min_rss > rss) + min_rss = rss; + else if (max_rss < rss) + max_rss = rss; + } + + void updateMinflt(uint32_t minflt) + { + if (min_page_reclaims == UINT32_MAX) + min_page_reclaims = minflt; + if (max_page_reclaims == 0) + max_page_reclaims = minflt; + + if (min_page_reclaims > minflt) + min_page_reclaims = minflt; + else if (max_page_reclaims < minflt) + max_page_reclaims = minflt; + } + + virtual void write(std::ostream &os) const = 0; +}; + +struct Operation : public MDContent +{ + std::string backend; + uint64_t graph_latency; + + struct OperationCmp + { + bool operator()(const Operation &lhs, const Operation &rhs) const + { + return lhs.begin_ts < rhs.begin_ts; + } + bool operator()(const Operation &lhs, const Operation &rhs) + { + return lhs.begin_ts < rhs.begin_ts; + } + bool operator()(Operation &lhs, Operation &rhs) { return lhs.begin_ts < rhs.begin_ts; } + }; + + void write(std::ostream &os) const override + { + uint64_t op_latency = end_ts - begin_ts; + double op_per = static_cast<double>(op_latency) / graph_latency * 100.0; + writeMDTableRow(os, {name, backend, 
std::to_string(op_latency), std::to_string(op_per), + std::to_string(min_rss), std::to_string(max_rss), + std::to_string(min_page_reclaims), std::to_string(max_page_reclaims)}); + } +}; + +struct Graph : public MDContent +{ + std::set<Operation, Operation::OperationCmp> ops; + std::string session_index; + std::string subgraph_index; + + void setOperations(const std::map<std::string, Operation> &name_to_op) + { + uint64_t graph_latency = end_ts - begin_ts; + for (auto &&it : name_to_op) + { + auto op = it.second; + op.graph_latency = graph_latency; + + ops.insert(op); + + updateRss(op.min_rss); + updateRss(op.max_rss); + updateMinflt(op.min_page_reclaims); + updateMinflt(op.max_page_reclaims); + } + } + + void write(std::ostream &os) const override + { + static std::vector<std::string> graph_headers{"latency(us)", "rss_min(kb)", "rss_max(kb)", + "page_reclaims_min", "page_reclaims_max"}; + + static std::vector<std::string> graph_headers_line{"-----------", "-------", "-------", + "-----------------", "-----------------"}; + + // Graph's Header + writeMDTableRow(os, graph_headers); + writeMDTableRow(os, graph_headers_line); + + // Graph's contents + writeMDTableRow(os, {std::to_string(end_ts - begin_ts), std::to_string(min_rss), + std::to_string(max_rss), std::to_string(min_page_reclaims), + std::to_string(max_page_reclaims)}); + + os << "\n"; + + static std::vector<std::string> op_headers{ + "Op name", "backend", "latency(us)", "latency(%)", + "rss_min(kb)", "rss_max(kb)", "page_reclaims_min", "page_reclaims_max"}; + + static std::vector<std::string> op_headers_line{ + "-------", "-------", "-----------", "-----------", + "-------", "-------", "-----------------", "-----------------"}; + + os << "## Op \n"; + + // Operation's Header + writeMDTableRow(os, op_headers); + writeMDTableRow(os, op_headers_line); + + // Operation's contents + for (auto &&op : ops) + { + op.write(os); + } + + os << "\n"; + } +}; + +std::string getLabel(const OpSeqDurationEvent &evt) +{ + 
std::string subg_label("$" + std::to_string(evt.subg_index) + " subgraph"); + std::string op_label("@" + std::to_string(evt.op_index) + " " + evt.op_name); + + return subg_label + " " + op_label; +} + +struct MDTableBuilder +{ + MDTableBuilder(const std::vector<std::unique_ptr<DurationEvent>> &duration_events, + const std::vector<CounterEvent> &counter_events) + : _duration_events(duration_events), _counter_events(counter_events) + { +// when ready with low overhead in release build +#ifdef DEBUG + for (const auto &evt : _counter_events) + { + uint64_t ts = std::stoull(evt.ts); + auto &name = evt.name; + assert(name.compare("maxrss") == 0 || name.compare("minflt") == 0); + assert(evt.values.size() == 1); + auto &val = evt.values.begin()->second; + if (_ts_to_values.find(ts) == _ts_to_values.end()) + { + std::pair<uint32_t, uint32_t> values; + if (name.compare("maxrss") == 0) + values.first = std::stoul(val); + else + values.second = std::stoul(val); + _ts_to_values.insert({ts, values}); + } + else + { + auto &values = _ts_to_values.at(ts); + if (name.compare("maxrss") == 0) + values.first = std::stoul(val); + else + values.second = std::stoul(val); + } + } +#endif + } + + MDTableBuilder &build() + { + for (const auto &it : divideGraph()) + { + size_t begin_idx = it.first; + size_t end_idx = it.second; + std::map<std::string, Operation> name_to_op; + for (size_t i = begin_idx + 1; i < end_idx; ++i) + { + const auto *evt = dynamic_cast<const OpSeqDurationEvent *>(_duration_events[i].get()); + if (evt == nullptr) + continue; + + const std::string evt_name = getLabel(*evt); + assert(evt->ph.compare("B") == 0 || evt->ph.compare("E") == 0); + if (evt->ph.compare("B") == 0) + { + assert(name_to_op.find(evt_name) == name_to_op.end()); + name_to_op.insert({evt_name, makeOperation(*evt)}); + } + else + { + assert(name_to_op.find(evt_name) != name_to_op.end()); + auto &op = name_to_op.at(evt_name); + updateOperation(op, *evt); + } + } + + 
_graphs.emplace_back(makeGraph(begin_idx, end_idx, name_to_op)); + } + + return *this; + } + + std::vector<std::pair<size_t, size_t>> divideGraph() + { + std::vector<std::pair<size_t, size_t>> graph_idx_list; // pair<begin_idx, end_idx> + for (size_t i = 0, begin_idx = 0; i < _duration_events.size(); ++i) + { + const auto subg_evt = dynamic_cast<const SubgDurationEvent *>(_duration_events.at(i).get()); + if (subg_evt == nullptr) + continue; + + if (subg_evt->ph.compare("B") == 0) + begin_idx = i; + else + graph_idx_list.emplace_back(begin_idx, i); + } + return graph_idx_list; + } + + Operation makeOperation(const OpSeqDurationEvent &evt) + { + Operation op; + const std::string &evt_name = getLabel(evt); + op.name = evt_name; + op.begin_ts = std::stoull(evt.ts); + op.backend = evt.backend; +#ifdef DEBUG + op.updateRss(_ts_to_values.at(op.begin_ts).first); + op.updateMinflt(_ts_to_values.at(op.begin_ts).second); +#else + op.updateRss(0); + op.updateMinflt(0); +#endif + return op; + } + + void updateOperation(Operation &op, const DurationEvent &evt) + { + op.end_ts = std::stoull(evt.ts); +#ifdef DEBUG + op.updateRss(_ts_to_values.at(op.end_ts).first); + op.updateMinflt(_ts_to_values.at(op.end_ts).second); +#else + op.updateRss(0); + op.updateMinflt(0); +#endif + } + + Graph makeGraph(size_t begin_idx, size_t end_idx, + const std::map<std::string, Operation> &name_to_op) + { + Graph graph; + graph.name = "Subgraph"; + graph.begin_ts = std::stoull(_duration_events[begin_idx]->ts); + graph.end_ts = std::stoull(_duration_events[end_idx]->ts); + graph.setOperations(name_to_op); + + for (const auto &arg : _duration_events[end_idx]->args) + { + if (arg.first == "session") + graph.session_index = arg.second; + if (arg.first == "subgraph") + graph.subgraph_index = arg.second; + } + +#ifdef DEBUG + graph.updateRss(_ts_to_values.at(graph.begin_ts).first); + graph.updateMinflt(_ts_to_values.at(graph.begin_ts).second); + graph.updateRss(_ts_to_values.at(graph.end_ts).first); + 
graph.updateMinflt(_ts_to_values.at(graph.end_ts).second); +#else + graph.updateRss(0); + graph.updateMinflt(0); +#endif + return graph; + } + + void write(std::ostream &os) + { + // Write contents + for (size_t i = 0; i < _graphs.size(); ++i) + { + auto &graph = _graphs.at(i); + os << "# Session: " << graph.session_index << ", Subgraph: " << graph.subgraph_index + << ", Running count: " << i << "\n"; + _graphs.at(i).write(os); + } + } + + const std::vector<std::unique_ptr<DurationEvent>> &_duration_events; + const std::vector<CounterEvent> &_counter_events; + + // timestamp to std::pair<maxrss, minflt> + std::unordered_map<uint64_t, std::pair<uint32_t, uint32_t>> _ts_to_values; + std::vector<Graph> _graphs; +}; + +} // namespace + +void MDTableWriter::flush(const std::vector<std::unique_ptr<EventRecorder>> &records) +{ + for (const auto &recorder : records) + { + MDTableBuilder(recorder->duration_events(), recorder->counter_events()).build().write(_os); + } +} diff --git a/runtime/onert/core/src/util/ObjectManager.test.cc b/runtime/onert/core/src/util/ObjectManager.test.cc new file mode 100644 index 000000000..3fe735732 --- /dev/null +++ b/runtime/onert/core/src/util/ObjectManager.test.cc @@ -0,0 +1,211 @@ +/* + * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "util/Index.h" +#include "util/ObjectManager.h" + +#include <gtest/gtest.h> + +using namespace onert; + +struct TestTag; +using Index = typename util::Index<uint32_t, TestTag>; + +TEST(ObjectManager, emplace) +{ + util::ObjectManager<Index, int> man; + + auto index = man.emplace(100); + ASSERT_EQ(man.at(index), 100); +} + +TEST(ObjectManager, neg_remove_1) +{ + util::ObjectManager<Index, int> man; + + Index index = man.emplace(100); + ASSERT_TRUE(man.exist(index)); + ASSERT_EQ(man.at(index), 100); + + man.remove(index); + ASSERT_FALSE(man.exist(index)); +} + +TEST(ObjectManager, neg_remove_2) +{ + util::ObjectManager<Index, int> man; + + auto index0 = man.emplace(100); + auto index1 = man.emplace(200); + ASSERT_TRUE(man.exist(index0)); + ASSERT_EQ(man.at(index0), 100); + ASSERT_TRUE(man.exist(index1)); + ASSERT_EQ(man.at(index1), 200); + + man.remove(index0); + ASSERT_FALSE(man.exist(index0)); + ASSERT_TRUE(man.exist(index1)); + ASSERT_EQ(man.at(index1), 200); +} + +TEST(ObjectManager, push) +{ + util::ObjectManager<Index, int> man; + + // Not specify index + auto index = man.push(std::make_unique<int>(100)); + ASSERT_EQ(man.at(index), 100); + + // Specify index + auto index2 = man.push(std::make_unique<int>(200), Index{33}); + ASSERT_EQ(index2.value(), 33); + ASSERT_EQ(man.at(index2), 200); + + auto index3 = man.push(std::make_unique<int>(300)); + // NOTE auto-generated index number is always (biggest index in the ObjectManager + 1) + ASSERT_EQ(index3.value(), 34); + ASSERT_EQ(man.at(index3), 300); + + auto index4 = man.push(std::make_unique<int>(400), Index{22}); + ASSERT_EQ(index4.value(), 22); + ASSERT_EQ(man.at(index4), 400); + + auto index5 = man.push(std::make_unique<int>(500)); + // NOTE auto-generated index number is always (biggest index in the ObjectManager + 1) + ASSERT_EQ(index5.value(), 35); + ASSERT_EQ(man.at(index5), 500); +} + +TEST(ObjectManager, neg_push) +{ + util::ObjectManager<Index, int> man; + + // Specify index + auto index 
= man.push(std::make_unique<int>(100), Index{55}); + ASSERT_EQ(index.value(), 55); + ASSERT_EQ(man.at(index), 100); + + // Specify the same index + auto index2 = man.push(std::make_unique<int>(200), Index{55}); + ASSERT_FALSE(index2.valid()); +} + +static const uint32_t kMaxUInt32 = std::numeric_limits<uint32_t>::max(); + +TEST(ObjectManager, neg_push_undefined_index) +{ + util::ObjectManager<Index, int> man; + + // Try inserting invalid(undefined) index + auto index = man.push(std::make_unique<int>(100), Index{kMaxUInt32}); + ASSERT_FALSE(index.valid()); + ASSERT_EQ(man.size(), 0); +} + +TEST(ObjectManager, neg_push_max_index) +{ + util::ObjectManager<Index, int> man; + + // Insert an object with maximum valid index + auto index = man.push(std::make_unique<int>(100), Index{kMaxUInt32 - 1}); + ASSERT_EQ(index.value(), kMaxUInt32 - 1); + ASSERT_EQ(man.at(index), 100); + ASSERT_EQ(man.size(), 1); + + // Reached to the final index so next push/emplace must fail + auto index2 = man.push(std::make_unique<int>(200)); + ASSERT_EQ(man.size(), 1); + ASSERT_FALSE(index2.valid()); +} + +TEST(ObjectManager, neg_emplace_max_index) +{ + util::ObjectManager<Index, int> man; + + // Insert an object with maximum valid index + auto index = man.push(std::make_unique<int>(100), Index{kMaxUInt32 - 1}); + ASSERT_EQ(index.value(), kMaxUInt32 - 1); + ASSERT_EQ(man.at(index), 100); + ASSERT_EQ(man.size(), 1); + + // Reached to the final index so next push/emplace must fail + auto index3 = man.emplace(200); + ASSERT_EQ(man.size(), 1); + ASSERT_FALSE(index3.valid()); +} + +TEST(ObjectManager, const_iterate) +{ + util::ObjectManager<Index, int> man; + + auto index0 = man.emplace(100); + auto index1 = man.emplace(200); + auto index2 = man.emplace(300); + + int sum = 0; + man.iterate([&](const Index &index, const int &val) { sum += val; }); + ASSERT_EQ(sum, 600); +} + +TEST(ObjectManager, non_const_iterate) +{ + util::ObjectManager<Index, int> man; + + auto index0 = man.emplace(100); + auto 
index1 = man.emplace(200); + auto index2 = man.emplace(300); + + man.iterate([&](const Index &index, int &val) { val += 1; }); + ASSERT_EQ(man.at(index0), 101); + ASSERT_EQ(man.at(index1), 201); + ASSERT_EQ(man.at(index2), 301); +} + +TEST(ObjectManager, set) +{ + util::ObjectManager<Index, int> man; + auto index = man.set(Index{1}, std::make_unique<int>(100)); // Insert + ASSERT_EQ(index, Index{1}); + auto index2 = man.set(index, std::make_unique<int>(200)); // Overwrite + ASSERT_EQ(index2, index); + ASSERT_EQ(man.at(index2), 200); +} + +TEST(ObjectManager, neg_set) +{ + auto v = std::make_unique<int>(100); + util::ObjectManager<Index, int> man; + auto index = man.set(Index{}, std::move(v)); // Try set with an invalid index + ASSERT_EQ(index, Index{}); + ASSERT_FALSE(index.valid()); + ASSERT_NE(v, nullptr); // v must be kept when failure +} + +TEST(ObjectManager, getRawPtr) +{ + auto v = std::make_unique<int>(100); + auto v_ptr = v.get(); + util::ObjectManager<Index, int> man; + auto index = man.push(std::move(v)); + ASSERT_EQ(v_ptr, man.getRawPtr(index)); +} + +TEST(ObjectManager, neg_getRawPtr) +{ + util::ObjectManager<Index, int> man; + auto ptr = man.getRawPtr(Index{1}); + ASSERT_EQ(ptr, nullptr); +} diff --git a/runtime/onert/core/src/util/SNPEEventWriter.cc b/runtime/onert/core/src/util/SNPEEventWriter.cc new file mode 100644 index 000000000..87bbfc662 --- /dev/null +++ b/runtime/onert/core/src/util/SNPEEventWriter.cc @@ -0,0 +1,186 @@ +/* + * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "EventWriter.h" + +#include <json/json.h> + +#include <cassert> +#include <unordered_map> +#include <utility> + +/** + * @brief Version of SNPE format + * In version 1 + * - There is no "version" field in Json + * - Only one subgraph is supported + * - Operation name is a form of "$3 ADD" + * + * In version 2, + * - "version" : "2" was added in Json + * - Multiple session and multiple subgraphs are supported + * - When there is only one session, operation name is a form of "$2 subgraph $3 ADD", + * meaning ADD op whose operation index 3 in a subgraph whose index is 2 + * - When there are two or more sessions, operation name is a form of + * "$1 session $2 subgraph $3 ADD", meaning ADD op whose operation index 3 + * in a subgraph whose index is 2, which was run in 1st session. 
+ */ +#define SNPE_JSON_SCHEMA_VERSION "2" + +namespace +{ + +std::string getLabel(const DurationEvent &evt) +{ + if (auto evt_ptr = dynamic_cast<const OpSeqDurationEvent *>(&evt)) + { + std::string subg_label("$" + std::to_string(evt_ptr->subg_index) + " subgraph"); + std::string op_label("$" + std::to_string(evt_ptr->op_index) + " " + evt_ptr->op_name); + + // Note : At this moment, there is only one thread running for EventWriter + if (evt_ptr->tracing_ctx->hasMultipleSessions()) + { + std::string session_label("$" + std::to_string(evt_ptr->session_index) + " session"); + return session_label + " " + subg_label + " " + op_label; + } + else + { + // When there is only one session, do not include session info + // Refer to https://github.sec.samsung.net/STAR/nnfw/issues/11436#issuecomment-930332 + return subg_label + " " + op_label; + } + } + else // SubgEvent + return "Graph"; +} + +std::string getBackend(const DurationEvent &evt) +{ + if (auto evt_ptr = dynamic_cast<const OpSeqDurationEvent *>(&evt)) + return evt_ptr->backend; + else // SubbEvent + return "runtime"; +} + +} // namespace + +void SNPEWriter::flush(const std::vector<std::unique_ptr<EventRecorder>> &recorders) +{ + struct Stat + { + uint64_t sum = 0; + uint64_t count = 0; + uint64_t max = 0; + uint64_t min = std::numeric_limits<uint64_t>::max(); + + void accumulate(uint64_t val) + { + sum += val; + count++; + max = std::max(max, val); + min = std::min(min, val); + } + }; + + Json::Value root; + root["version"] = SNPE_JSON_SCHEMA_VERSION; + + auto &exec_data = root["Execution_Data"] = Json::Value{Json::objectValue}; + + // Memory + { + std::unordered_map<std::string, Stat> mem_stats; + for (const auto &recorder : recorders) + { + for (const auto &evt : recorder->counter_events()) + { + auto &mem_stat = mem_stats[evt.name]; + uint64_t val = std::stoull(evt.values.at("value")); + mem_stat.accumulate(val); + } + } + + auto &mem = exec_data["memory"] = Json::Value{Json::objectValue}; + for (const auto 
&kv : mem_stats) + { + auto &key = kv.first; + auto &val = kv.second; + mem[key]["Avg_Size"] = val.sum / val.count; + mem[key]["Max_Size"] = val.max; + mem[key]["Min_Size"] = val.min; + mem[key]["Runtime"] = "NA"; + } + } + + // Operation Execution Time + { + // NOTE This assumes _duration_events is sorted by "ts" ascending + + // 2D keys : stats[tid][name] + std::unordered_map<std::string, std::unordered_map<std::string, Stat>> stats; + std::unordered_map<std::string, std::unordered_map<std::string, uint64_t>> begin_timestamps; + for (const auto &recorder : recorders) + { + for (const auto &evt : recorder->duration_events()) + { + std::string evt_name = getLabel(*evt); + std::string evt_tid = getBackend(*evt); + + auto &stat = stats[evt_tid][evt_name]; + auto &begin_ts = begin_timestamps[evt_tid][evt_name]; + uint64_t timestamp = std::stoull(evt->ts); + if (evt->ph == "B") + { + if (begin_ts != 0) + throw std::runtime_error{"Invalid Data"}; + begin_ts = timestamp; + } + else if (evt->ph == "E") + { + if (begin_ts == 0 || timestamp < begin_ts) + throw std::runtime_error{"Invalid Data"}; + stat.accumulate(timestamp - begin_ts); + begin_ts = 0; + } + else + throw std::runtime_error{"Invalid Data - invalid value for \"ph\" : \"" + evt->ph + "\""}; + } + } + + for (const auto &kv : begin_timestamps) + for (const auto &kv2 : kv.second) + if (kv2.second != 0) + throw std::runtime_error{"Invalid Data - B and E pair does not match."}; + + for (const auto &kv : stats) + { + const auto &tid = kv.first; + const auto &map = kv.second; + auto &json_tid = exec_data[tid] = Json::Value{Json::objectValue}; + for (const auto &kv : map) + { + auto &name = kv.first; + auto &val = kv.second; + json_tid[name]["Avg_Time"] = val.sum / val.count; + json_tid[name]["Max_Time"] = val.max; + json_tid[name]["Min_Time"] = val.min; + json_tid[name]["Runtime"] = tid; + } + } + } + + _os << root; +} diff --git a/runtime/onert/core/src/util/ShapeInference.cc 
b/runtime/onert/core/src/util/ShapeInference.cc index 95c15049d..2a6fde45b 100644 --- a/runtime/onert/core/src/util/ShapeInference.cc +++ b/runtime/onert/core/src/util/ShapeInference.cc @@ -22,6 +22,7 @@ #include "util/logging.h" #include <cassert> +#include <numeric> #include <sstream> #include <cmath> @@ -72,6 +73,19 @@ ir::Shape broadcastShapes(const ir::Shape &lhs_shape, const ir::Shape &rhs_shape } // namespace +namespace bcq +{ +inline int getOutputSize(const ir::Shape &cluster_shape, const int32_t *cluster_buf) +{ + int size = 0; + for (int idx = 0; idx < cluster_shape.dim(0); idx++) + { + size += cluster_buf[idx * 2 + 1]; + } + return size; +} +} // namespace bcq + // // Shape inference // @@ -97,10 +111,9 @@ std::pair<int, int> calcConvLikeHeightAndWidth(const int in_h, const int in_w, c break; case ir::PaddingType::EXPLICIT: out_h = - (in_h + pad.param.top + pad.param.bottom - effective_filter_h_size) / stride.vertical + 1; + (in_h + pad.param.top + pad.param.bottom - effective_filter_h_size) / stride.vertical + 1; out_w = - (in_w + pad.param.left + pad.param.right - effective_filter_w_size) / stride.horizontal + - 1; + (in_w + pad.param.left + pad.param.right - effective_filter_w_size) / stride.horizontal + 1; break; default: assert(false); @@ -114,8 +127,13 @@ ir::Shape inferEltwiseShape(const ir::Shape &lhs_shape, const ir::Shape &rhs_sha return broadcastShapes(lhs_shape, rhs_shape); } -ir::Shape inferArgMaxShape(const ir::Shape &input_shape, int axis, int rank) +ir::Shape inferArgMinMaxShape(const ir::Shape &input_shape, int axis, int rank) { + if (axis < 0 || axis >= rank) + { + throw std::runtime_error("ArgMinMax shape inference: Wrong axis value " + std::to_string(axis)); + } + ir::Shape out_shape; for (int idx = 0; idx < rank; ++idx) { @@ -167,15 +185,15 @@ ir::Shape inferReduceShape(const ir::Shape &input_shape, const std::vector<int> else { // Calculates size of reducing axis. 
- int num_reduce_axis = num_axis; for (int i = 0; i < num_axis; ++i) { int current = axes[i]; + if (!(-input_num_dims <= current && current < input_num_dims)) + throw std::runtime_error{"Invalid dim value " + std::to_string(current)}; if (current < 0) { current += input_num_dims; } - assert(0 <= current && current < input_num_dims); for (int j = 0; j < i; ++j) { int previous = axes[j]; @@ -185,14 +203,12 @@ ir::Shape inferReduceShape(const ir::Shape &input_shape, const std::vector<int> } if (current == previous) { - --num_reduce_axis; break; } } } // Determines output dimensions. ir::Shape out_shape; - int num_skip_axis = 0; for (int idx = 0; idx < input_num_dims; ++idx) { bool is_axis = false; @@ -200,7 +216,6 @@ ir::Shape inferReduceShape(const ir::Shape &input_shape, const std::vector<int> { if (axes[axis_idx] == idx || axes[axis_idx] + input_num_dims == idx) { - ++num_skip_axis; is_axis = true; break; } @@ -259,19 +274,24 @@ ir::Shape inferBatchMatMulShape(const ir::Shape &lhs_shape, const ir::Shape &rhs return output_shape; } -ir::Shape inferBroadcastToShape(const ir::Shape wshape, const int32_t *shape_buffer) +/* + * shp_shape : SHAPE input tensor's shape + * shp_buf : SHAPE input tensor's buffer + */ +ir::Shape inferBroadcastToShape(const ir::Shape shp_shape, const int32_t *shp_buf) { - const int num_elements = wshape.num_elements(); + + const int num_elements = shp_shape.num_elements(); assert(num_elements != 0); - assert(shape_buffer); + assert(shp_buf); ir::Shape new_shape(num_elements); for (int i = 0; i < num_elements; ++i) { - assert(shape_buffer[i] != 0); // It shouldn't be 0. - new_shape.dim(i) = shape_buffer[i]; + assert(shp_buf[i] != 0); // It shouldn't be 0. 
+ new_shape.dim(i) = shp_buf[i]; } return new_shape; @@ -305,6 +325,9 @@ ir::Shape inferConcatShape(const Shapes &in_shapes, const ir::operation::Concat: ir::Shape inferConv2DShape(const ir::Shape &in_shape, const ir::Shape &ker_shape, const ir::operation::Conv2D::Param ¶m, ir::Layout layout) { + if (param.stride.horizontal == 0 || param.stride.vertical == 0) + throw std::runtime_error{"Conv2D: stride values must be positive"}; + auto ifm_shape = in_shape.asFeature(layout); // Kernel format is [depth_out, kernel_height, kernel_width, depth_in] @@ -321,6 +344,9 @@ ir::Shape inferDepthwiseConv2DShape(const ir::Shape &in_shape, const ir::Shape & const ir::operation::DepthwiseConv2D::Param ¶m, ir::Layout layout) { + if (param.stride.horizontal == 0 || param.stride.vertical == 0) + throw std::runtime_error{"DepthwiseConv2D: stride values must be positive"}; + assert(layout == ir::Layout::NHWC); auto ifm_shape = in_shape.asFeature(layout); @@ -330,7 +356,7 @@ ir::Shape inferDepthwiseConv2DShape(const ir::Shape &in_shape, const ir::Shape & assert(kf_shape.N == 1); const auto out_h_w = calcConvLikeHeightAndWidth(ifm_shape.H, ifm_shape.W, kf_shape.H, kf_shape.W, - param.padding, param.stride); + param.padding, param.stride, param.dilation); return ir::Shape{ifm_shape.N, out_h_w.first, out_h_w.second, kf_shape.C}; } @@ -354,18 +380,22 @@ ir::Shape inferExpandDimsShape(const ir::Shape &in_shape, int32_t axis) return out_shape; } -ir::Shape inferFillShape(const ir::Shape &in_shape, const int32_t *buffer) +template <typename T> ir::Shape inferFillShape(const ir::Shape &fill_shape, const T *shape_buf) { - ir::Shape out_shape(in_shape.dim(0)); + ir::Shape out_shape(fill_shape.dim(0)); for (int out_x = 0; out_x < out_shape.rank(); ++out_x) { - out_shape.dim(out_x) = buffer[out_x]; + out_shape.dim(out_x) = static_cast<int32_t>(shape_buf[out_x]); } return out_shape; } +// template instantiation +template ir::Shape inferFillShape(const ir::Shape &fill_shape, const int32_t 
*shape_buf); +template ir::Shape inferFillShape(const ir::Shape &fill_shape, const int64_t *shape_buf); + ir::Shape inferFullyConnectedShape(const ir::Shape &in_shape, const ir::Shape &ker_shape) { assert(in_shape.rank() >= 2); @@ -380,11 +410,60 @@ ir::Shape inferFullyConnectedShape(const ir::Shape &in_shape, const ir::Shape &k return {ir::Shape({static_cast<int32_t>(batch_size), num_units})}; } +ir::Shape inferBCQFullyConnectedShape(const ir::Shape &in_shape, const ir::Shape &cluster_shape, + const int32_t *cluster_buf) +{ + assert(cluster_shape.rank() == 2); + assert(cluster_shape.dim(1) == 2); + + const auto input_size = in_shape.dim(1); + const auto output_size = bcq::getOutputSize(cluster_shape, cluster_buf); + + return {ir::Shape({output_size, input_size})}; +} + +ir::Shape inferBCQGatherShape(const ir::Shape &indices_shape, const ir::Shape &cluster_shape, + const int32_t *cluster_buf, int rank, + const ir::operation::BCQGather::Param ¶m) +{ + ir::Shape out_shape; + ir::Shape in_original_shape; + + assert(cluster_shape.rank() == 2); + assert(cluster_shape.dim(1) == 2); + + auto hidden_size = param.input_hidden_size; + auto axis = param.axis; + + in_original_shape.append(bcq::getOutputSize(cluster_shape, cluster_buf)); + in_original_shape.append(hidden_size); + + const int indices_rank = indices_shape.rank(); + for (int idx = 0; idx < rank; ++idx) + { + if (idx == (int)axis) + { + for (int indices_idx = 0; indices_idx < indices_rank; indices_idx++) + { + out_shape.append(indices_shape.dim(indices_idx)); + } + } + else + { + out_shape.append(in_original_shape.dim(idx)); + } + } + + return out_shape; +} + ir::Shape inferGatherShape(const ir::Shape &input_shape, const ir::Shape &indices_shape, int axis, int rank) { ir::Shape out_shape; + const int indices_rank = indices_shape.rank(); + for (int idx = 0; idx < rank; ++idx) { if (idx == axis) @@ -470,6 +549,9 @@ ir::Shape inferPadShape(const ir::Shape &in_shape, const int32_t *pad_buf, const ir::Shape 
inferPoolShape(const ir::Shape &in_shape, const ir::operation::Pool2D::Param ¶m, const ir::Layout layout) { + if (param.stride.horizontal == 0 || param.stride.vertical == 0) + throw std::runtime_error{"Pool2D: stride values must be positive"}; + assert(layout == ir::Layout::NHWC); auto ifm_shape = in_shape.asFeature(layout); const auto out_h_w = calcConvLikeHeightAndWidth(ifm_shape.H, ifm_shape.W, param.kh, param.kw, @@ -482,6 +564,17 @@ ir::Shape inferResizeBilinearShape(const ir::Shape &in_shape, const int32_t outp const int32_t output_width) { assert(in_shape.rank() == 4); + if (output_height < 0) + { + throw std::runtime_error{"ResizeBilinear: size value must be positive value, output_height = " + + std::to_string(output_height)}; + } + if (output_width < 0) + { + throw std::runtime_error{"ResizeBilinear: size value must be positive value, output_width = " + + std::to_string(output_width)}; + } + ir::Shape ret(in_shape.rank()); ret.dim(0) = in_shape.dim(0); @@ -497,9 +590,9 @@ template <typename T> ir::Shape inferRangeShape(T start_val, T limit_val, T delt ir::Shape out_shape(static_cast<int>(1)); out_shape.dim(0) = - (std::is_integral<T>::value - ? ((std::abs(start_val - limit_val) + std::abs(delta_val) - 1) / std::abs(delta_val)) - : std::ceil(std::abs((start_val - limit_val) / delta_val))); + (std::is_integral<T>::value + ? 
((std::abs(start_val - limit_val) + std::abs(delta_val) - 1) / std::abs(delta_val)) + : std::ceil(std::abs((start_val - limit_val) / delta_val))); return out_shape; } @@ -507,16 +600,17 @@ template <typename T> ir::Shape inferRangeShape(T start_val, T limit_val, T delt template ir::Shape inferRangeShape(int start_val, int limit_val, int delta_val); template ir::Shape inferRangeShape(float start_val, float limit_val, float delta_val); -ir::Shape inferReshapeShape(const int32_t *shape_buf, const int32_t shape_num_elements, - const size_t total_num_elements) +ir::Shape inferReshapeShape(const ir::Shape &input_shape, const int32_t *shape_buf, + const int32_t shape_num_elements) { ir::Shape ret(shape_num_elements); - int32_t flatten_dim = ir::Shape::UNSPECIFIED_DIM; + int32_t flatten_dim = ir::Shape::kUnspecifiedDim; + auto total_num_elements = input_shape.num_elements(); for (int32_t i = 0; i < shape_num_elements; ++i) { if (shape_buf[i] < 0) { - if (flatten_dim != ir::Shape::UNSPECIFIED_DIM) + if (flatten_dim != ir::Shape::kUnspecifiedDim) throw std::runtime_error("Reshape: 2nd param has special dim(for flatten) more than twice"); flatten_dim = i; ret.dim(i) = 1; @@ -526,12 +620,20 @@ ir::Shape inferReshapeShape(const int32_t *shape_buf, const int32_t shape_num_el ret.dim(i) = shape_buf[i]; } } - if (flatten_dim != ir::Shape::UNSPECIFIED_DIM) + if (flatten_dim != ir::Shape::kUnspecifiedDim) ret.dim(flatten_dim) = total_num_elements / ret.num_elements(); // Check reshapable if (total_num_elements != static_cast<size_t>(ret.num_elements())) - throw std::runtime_error("Reshape: 2nd param is not compatible with the shape of input"); + { + // Multi batch case + // TODO Handle multi batch case more precisely on runtime level + if ((ret.dim(0) == 1) && + (total_num_elements == static_cast<size_t>(ret.num_elements() * input_shape.dim(0)))) + ret.dim(0) = input_shape.dim(0); + else + throw std::runtime_error("Reshape: 2nd param is not compatible with the shape of input"); + } 
return ret; } @@ -566,9 +668,9 @@ ir::Shape inferSelectShape(const ir::Shape &input_cond_shape, const ir::Shape &i ir::Shape true_shape = input_true_shape; ir::Shape false_shape = input_false_shape; int most_rank = - (cond_shape.rank() >= true_shape.rank()) && (cond_shape.rank() >= false_shape.rank()) - ? cond_shape.rank() - : (false_shape.rank() >= true_shape.rank() ? false_shape.rank() : true_shape.rank()); + (cond_shape.rank() >= true_shape.rank()) && (cond_shape.rank() >= false_shape.rank()) + ? cond_shape.rank() + : (false_shape.rank() >= true_shape.rank() ? false_shape.rank() : true_shape.rank()); ir::Shape calculate_shape(most_rank); @@ -579,9 +681,9 @@ ir::Shape inferSelectShape(const ir::Shape &input_cond_shape, const ir::Shape &i for (int i = 0; i < most_rank; ++i) { calculate_shape.dim(i) = - (cond_shape.dim(i) >= true_shape.dim(i)) && (cond_shape.dim(i) >= false_shape.dim(i)) - ? cond_shape.dim(i) - : (false_shape.dim(i) >= true_shape.dim(i) ? false_shape.dim(i) : true_shape.dim(i)); + (cond_shape.dim(i) >= true_shape.dim(i)) && (cond_shape.dim(i) >= false_shape.dim(i)) + ? cond_shape.dim(i) + : (false_shape.dim(i) >= true_shape.dim(i) ? 
false_shape.dim(i) : true_shape.dim(i)); if ((cond_shape.dim(i) != calculate_shape.dim(i) && cond_shape.dim(i) != 1) || (true_shape.dim(i) != calculate_shape.dim(i) && true_shape.dim(i) != 1) || @@ -613,7 +715,8 @@ ir::Shape inferSelectShape(const ir::Shape &input_cond_shape, const ir::Shape &i return new_shape; } -ir::Shape inferSliceShape(const ir::Shape &input_shape, const int32_t *begins, const int32_t *sizes) +template <typename T> +ir::Shape inferSliceShape(const ir::Shape &input_shape, const T *begins_buf, const T *sizes_buf) { const uint32_t rank = input_shape.rank(); ir::Shape out_shape(rank); @@ -623,12 +726,12 @@ ir::Shape inferSliceShape(const ir::Shape &input_shape, const int32_t *begins, c const auto input_dim = input_shape.dim(idx); // begin is zero-based - auto begin = begins[idx]; + auto begin = begins_buf[idx]; if (begin < 0) throw std::runtime_error("shape inference Slice: Invalid begin."); // size is one-based - auto size = sizes[idx]; + auto size = sizes_buf[idx]; if (size < -1) throw std::runtime_error("shape inference Slice: Invalid size."); @@ -638,18 +741,23 @@ ir::Shape inferSliceShape(const ir::Shape &input_shape, const int32_t *begins, c } else { - if (input_dim < begin + size) + if (input_dim < static_cast<int32_t>(begin + size)) throw std::runtime_error("shape inference Slice: Invalid begin and size."); } - out_shape.dim(idx) = size; + out_shape.dim(idx) = static_cast<int32_t>(size); } return out_shape; } +// template instantiation +template ir::Shape inferSliceShape(const ir::Shape &input_shape, const int32_t *begins_buf, + const int32_t *sizes_buf); +template ir::Shape inferSliceShape(const ir::Shape &input_shape, const int64_t *begins_buf, + const int64_t *sizes_buf); ir::Shape inferSpaceToBatchNDShape(const ir::Shape &input_shape, const ir::Shape &block_shape_shape, - const ir::Shape &padding_shape, const int32_t *block_shape_data, - const int32_t *padding_data) + const ir::Shape &padding_shape, const int32_t *block_shape_buf, + 
const int32_t *padding_buf) { const uint32_t rank = input_shape.rank(); ir::Shape out_shape(rank); @@ -677,14 +785,14 @@ ir::Shape inferSpaceToBatchNDShape(const ir::Shape &input_shape, const ir::Shape for (int dim = 0; dim < kSpatialDimensionNum; ++dim) { int final_dim_size = - (input_shape.dim(dim + 1) + padding_data[dim * 2] + padding_data[dim * 2 + 1]); + (input_shape.dim(dim + 1) + padding_buf[dim * 2] + padding_buf[dim * 2 + 1]); - assert(final_dim_size % block_shape_data[dim] == 0); + assert(final_dim_size % block_shape_buf[dim] == 0); - out_shape.dim(dim + 1) = final_dim_size / block_shape_data[dim]; + out_shape.dim(dim + 1) = final_dim_size / block_shape_buf[dim]; } - const int output_batch_size = input_shape.dim(0) * block_shape_data[0] * block_shape_data[1]; + const int output_batch_size = input_shape.dim(0) * block_shape_buf[0] * block_shape_buf[1]; const int output_channel_size = input_shape.dim(3); out_shape.dim(0) = output_batch_size; @@ -740,7 +848,7 @@ ir::Shape inferSqueezeShape(const ir::Shape &in_shape, const ir::operation::Sque if (!(current >= 0 && current < shape_rank && in_shape.dim(current) == 1)) { throw std::runtime_error( - "The following conditions must be met: 0 <= dim < Shape rank, dim == 1"); + "The following conditions must be met: 0 <= dim < Shape rank, dim == 1"); } if (!should_squeeze[current]) @@ -948,35 +1056,71 @@ ir::Shape inferStridedSliceShape(const ir::Shape &input_shape, const StridedSlic return out_shape; } -ir::Shape inferTileShape(const ir::Shape &in_shape, const int32_t *multiplier) +ir::Shape inferTileShape(const ir::Shape &in_shape, const int32_t *multiplier_buf, + const int32_t multiplier_size) { - // assert(in_shape.rank() == multiplier.rank()); + if (multiplier_size != in_shape.rank()) + { + throw std::runtime_error( + "inferTileShape failed, input rank: " + std::to_string(in_shape.rank()) + + ", bad multipliers size: " + std::to_string(multiplier_size) + ""); + } ir::Shape new_Shape(in_shape.rank()); for (int i 
= 0; i < in_shape.rank(); ++i) { - assert(multiplier[i]); // multiplier[i] shuld not be 0. - new_Shape.dim(i) = in_shape.dim(i) * multiplier[i]; + assert(multiplier_buf[i]); // multiplier_buf[i] shuld not be 0. + new_Shape.dim(i) = in_shape.dim(i) * multiplier_buf[i]; } return new_Shape; } -ir::Shape inferTransposeShape(const ir::Shape &in_shape, const std::vector<int> &perm) +ir::Shape inferTransposeShape(const ir::Shape &in_shape, const int32_t *perm_buf, + const int32_t perm_size) { - if (static_cast<int>(perm.size()) > in_shape.rank()) + const auto rank = in_shape.rank(); + if (perm_size > rank) + { + throw std::runtime_error("inferTransposeShape failed, bad permutation size: " + + std::to_string(perm_size)); + } + + const int32_t *perm_data = perm_buf; + std::vector<int32_t> regular_perm_vec; + if (perm_size == 0) { - throw std::runtime_error("inferTransposeShape failed, bad rank size: " + - std::to_string(static_cast<int>(perm.size()))); + // perm_data will be set to (n-1...0) + regular_perm_vec.resize(rank); + std::iota(regular_perm_vec.begin(), regular_perm_vec.end(), 0); + std::reverse(regular_perm_vec.begin(), regular_perm_vec.end()); + perm_data = regular_perm_vec.data(); } - ir::Shape out_shape(static_cast<int>(perm.size())); - for (int idx = 0; idx < static_cast<int>(perm.size()); idx++) + else + { + assert(rank == perm_size); + } + + ir::Shape out_shape(rank); + std::vector<bool> visit_perms(rank, false); + for (int idx = 0; idx < rank; idx++) { - if (perm[idx] < 0 || perm[idx] >= static_cast<int>(perm.size())) + const auto perm_val = perm_data[idx]; + // Check invalid permutation value + if (perm_val < 0 || perm_val >= rank) { - throw std::runtime_error("inferTransposeShape failed, bad perm value: " + - std::to_string(perm[idx])); + throw std::runtime_error("inferTransposeShape failed, bad permutation value: " + + std::to_string(perm_val)); } - out_shape.dim(idx) = in_shape.dim(perm[idx]); + + // Check duplicated permutation value + if 
(visit_perms.at(perm_val)) + { + throw std::runtime_error("inferTransposeShape failed, duplicated permutation value: " + + std::to_string(perm_val)); + } + visit_perms.at(perm_val) = true; + + out_shape.dim(idx) = in_shape.dim(perm_val); } return out_shape; } diff --git a/runtime/onert/core/src/util/ShapeInference.test.cc b/runtime/onert/core/src/util/ShapeInference.test.cc new file mode 100644 index 000000000..96579bfa2 --- /dev/null +++ b/runtime/onert/core/src/util/ShapeInference.test.cc @@ -0,0 +1,544 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "util/ShapeInference.h" + +#include <gtest/gtest.h> + +using namespace onert::ir; + +TEST(ShapeInference, Elementwise) +{ + Shape lhs_shape{1, 299, 299, 3}; + Shape rhs_shape{3}; + auto infered_out_shape = onert::shape_inference::inferEltwiseShape(lhs_shape, rhs_shape); + + ASSERT_EQ(infered_out_shape.rank(), 4); + ASSERT_EQ(infered_out_shape.dim(0), 1); + ASSERT_EQ(infered_out_shape.dim(1), 299); + ASSERT_EQ(infered_out_shape.dim(2), 299); + ASSERT_EQ(infered_out_shape.dim(3), 3); +} + +TEST(ShapeInference, neg_Elementwise) +{ + Shape lhs_shape{1, 299, 299, 3}; + Shape rhs_shape{5, 3}; + ASSERT_THROW(onert::shape_inference::inferEltwiseShape(lhs_shape, rhs_shape), std::runtime_error); +} + +TEST(ShapeInference, Pool2DNodeSame) +{ + Shape in_shape{10, 6, 12, 20}; + Stride stride{3, 7}; + Padding padding{PaddingType::SAME}; + + operation::Pool2D::Param avg_pool_param{ + operation::Pool2D::PoolType::AVG, 3, 6, stride, padding, Activation::NONE}; + auto infered_out_shape = onert::shape_inference::inferPoolShape(in_shape, avg_pool_param); + + ASSERT_EQ(infered_out_shape.rank(), 4); + ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).N, 10); + ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).H, 2); + ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).W, 2); + ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).C, 20); + + operation::Pool2D::Param max_pool_param{ + operation::Pool2D::PoolType::MAX, 3, 6, stride, padding, Activation::NONE}; + infered_out_shape = onert::shape_inference::inferPoolShape(in_shape, max_pool_param); + + ASSERT_EQ(infered_out_shape.rank(), 4); + ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).N, 10); + ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).H, 2); + ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).W, 2); + ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).C, 20); +} + +TEST(ShapeInference, Pool2DNodeValid) +{ + Shape in_shape{10, 6, 12, 20}; + Stride stride{3, 7}; + Padding 
padding{PaddingType::VALID}; + + operation::Pool2D::Param avg_pool_param{ + operation::Pool2D::PoolType::AVG, 3, 6, stride, padding, Activation::NONE}; + auto infered_out_shape = onert::shape_inference::inferPoolShape(in_shape, avg_pool_param); + + ASSERT_EQ(infered_out_shape.rank(), 4); + ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).N, 10); + ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).H, 2); + ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).W, 1); + ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).C, 20); + + operation::Pool2D::Param max_pool_param{ + operation::Pool2D::PoolType::MAX, 3, 6, stride, padding, Activation::NONE}; + infered_out_shape = onert::shape_inference::inferPoolShape(in_shape, max_pool_param); + + ASSERT_EQ(infered_out_shape.rank(), 4); + ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).N, 10); + ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).H, 2); + ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).W, 1); + ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).C, 20); +} + +TEST(ShapeInference, Pool2DNodeExplicit) +{ + Shape in_shape{10, 3, 5, 20}; + + Stride stride{3, 7}; + Padding padding{4, 3, 2, 1}; + + operation::Pool2D::Param avg_pool_param{ + operation::Pool2D::PoolType::AVG, 3, 6, stride, padding, Activation::NONE}; + auto infered_out_shape = onert::shape_inference::inferPoolShape(in_shape, avg_pool_param); + + ASSERT_EQ(infered_out_shape.rank(), 4); + ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).N, 10); + ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).H, 2); + ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).W, 1); + ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).C, 20); + + operation::Pool2D::Param max_pool_param{ + operation::Pool2D::PoolType::MAX, 3, 6, stride, padding, Activation::NONE}; + infered_out_shape = onert::shape_inference::inferPoolShape(in_shape, max_pool_param); + + ASSERT_EQ(infered_out_shape.rank(), 4); + ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).N, 10); + 
ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).H, 2); + ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).W, 1); + ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).C, 20); +} + +TEST(ShapeInference, neg_Pool2DNode_InvalidStride) +{ + Shape in_shape{10, 6, 12, 20}; + Stride stride{0, 7}; + Padding padding{PaddingType::SAME}; + + operation::Pool2D::Param avg_pool_param{ + operation::Pool2D::PoolType::AVG, 3, 6, stride, padding, Activation::NONE}; + ASSERT_THROW(onert::shape_inference::inferPoolShape(in_shape, avg_pool_param), + std::runtime_error); +} + +TEST(ShapeInference, Conv2D) +{ + Shape in_shape{10, 6, 12, 20}; + Shape ker_shape{30, 3, 6, 20}; + + operation::Conv2D::Param param{Stride{3, 7}, Padding{PaddingType::VALID}, Activation::NONE, + Dilation{1, 1}}; + auto infered_out_shape = onert::shape_inference::inferConv2DShape(in_shape, ker_shape, param); + + ASSERT_EQ(infered_out_shape.rank(), 4); + ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).N, 10); + ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).H, 2); + ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).W, 1); + ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).C, 30); + + param = operation::Conv2D::Param{Stride{3, 7}, Padding{PaddingType::SAME}, Activation::NONE, + Dilation{1, 1}}; + infered_out_shape = onert::shape_inference::inferConv2DShape(in_shape, ker_shape, param); + + ASSERT_EQ(infered_out_shape.rank(), 4); + ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).N, 10); + ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).H, 2); + ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).W, 2); + ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).C, 30); + + param = + operation::Conv2D::Param{Stride{3, 7}, Padding{4, 3, 2, 1}, Activation::NONE, Dilation{1, 1}}; + infered_out_shape = onert::shape_inference::inferConv2DShape(in_shape, ker_shape, param); + + ASSERT_EQ(infered_out_shape.rank(), 4); + ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).N, 10); + 
ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).H, 3); + ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).W, 2); + ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).C, 30); +} + +TEST(ShapeInference, neg_Conv2D_InvalidStride) +{ + Shape in_shape{10, 6, 12, 20}; + Shape ker_shape{30, 3, 6, 20}; + + operation::Conv2D::Param param{Stride{0, 0}, Padding{PaddingType::VALID}, Activation::NONE, + Dilation{1, 1}}; + ASSERT_THROW(onert::shape_inference::inferConv2DShape(in_shape, ker_shape, param), + std::runtime_error); +} + +TEST(ShapeInference, DepthwiseConv2D) +{ + Shape in_shape{10, 6, 12, 20}; + Shape ker_shape{1, 3, 6, 60}; + + operation::DepthwiseConv2D::Param param{Stride{3, 7}, Padding{PaddingType::VALID}, 3, + Activation::NONE, Dilation{1, 1}}; + auto infered_out_shape = + onert::shape_inference::inferDepthwiseConv2DShape(in_shape, ker_shape, param); + + ASSERT_EQ(infered_out_shape.rank(), 4); + ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).N, 10); + ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).H, 2); + ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).W, 1); + ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).C, 60); + + param = operation::DepthwiseConv2D::Param{Stride{3, 7}, Padding{PaddingType::SAME}, 3, + Activation::NONE, Dilation{1, 1}}; + infered_out_shape = onert::shape_inference::inferDepthwiseConv2DShape(in_shape, ker_shape, param); + + ASSERT_EQ(infered_out_shape.rank(), 4); + ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).N, 10); + ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).H, 2); + ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).W, 2); + ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).C, 60); + + param = operation::DepthwiseConv2D::Param{Stride{3, 7}, Padding{4, 3, 2, 1}, 3, Activation::NONE, + Dilation{1, 1}}; + infered_out_shape = onert::shape_inference::inferDepthwiseConv2DShape(in_shape, ker_shape, param); + + ASSERT_EQ(infered_out_shape.rank(), 4); + 
ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).N, 10); + ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).H, 3); + ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).W, 2); + ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).C, 60); +} + +TEST(ShapeInference, neg_DepthwiseConv2D_InvalidSride) +{ + Shape in_shape{10, 6, 12, 20}; + Shape ker_shape{1, 3, 6, 60}; + + operation::DepthwiseConv2D::Param param{Stride{3, 0}, Padding{PaddingType::VALID}, 3, + Activation::NONE, Dilation{1, 1}}; + ASSERT_THROW(onert::shape_inference::inferDepthwiseConv2DShape(in_shape, ker_shape, param), + std::runtime_error); +} + +TEST(ShapeInference, Concat) +{ + { + Shape in1{10, 20, 30, 3, 50}; + Shape in2{10, 20, 30, 2, 50}; + Shape in3{10, 20, 30, 2, 50}; + + operation::Concat::Param param{3}; + auto infered_out_shape = onert::shape_inference::inferConcatShape({in1, in2, in3}, param); + + ASSERT_EQ(infered_out_shape.rank(), 5); + ASSERT_EQ(infered_out_shape.dim(0), 10); + ASSERT_EQ(infered_out_shape.dim(1), 20); + ASSERT_EQ(infered_out_shape.dim(2), 30); + ASSERT_EQ(infered_out_shape.dim(3), 7); + ASSERT_EQ(infered_out_shape.dim(4), 50); + } + { + // case 1. when axis < 0 + Shape in1{10, 20, 2}; + Shape in2{10, 20, 3}; + + operation::Concat::Param param{-1}; + auto infered_out_shape = onert::shape_inference::inferConcatShape({in1, in2}, param); + + ASSERT_EQ(infered_out_shape.rank(), 3); + ASSERT_EQ(infered_out_shape.dim(0), 10); + ASSERT_EQ(infered_out_shape.dim(1), 20); + ASSERT_EQ(infered_out_shape.dim(2), 5); + } + { + // case 2. 
when axis < 0 + Shape in1{2, 20, 2}; + Shape in2{3, 20, 2}; + + operation::Concat::Param param{-3}; + auto infered_out_shape = onert::shape_inference::inferConcatShape({in1, in2}, param); + + ASSERT_EQ(infered_out_shape.rank(), 3); + ASSERT_EQ(infered_out_shape.dim(0), 5); + ASSERT_EQ(infered_out_shape.dim(1), 20); + ASSERT_EQ(infered_out_shape.dim(2), 2); + } +} + +TEST(ShapeInference, neg_Concat) +{ + { + operation::Concat::Param param{2}; + Shape in1{10, 1, 3}; + Shape in2{10, 2, 4}; // dim[1] should be 1 but 2 + + EXPECT_ANY_THROW(onert::shape_inference::inferConcatShape({in1, in2}, param)); + } + { // wrong rank + operation::Concat::Param param{2}; + Shape in1{10, 2, 3, 4}; + Shape in2{10, 2, 4}; // rank should be 4 + + EXPECT_ANY_THROW(onert::shape_inference::inferConcatShape({in1, in2}, param)); + } +} + +TEST(ShapeInference, ExpandDims) +{ + Shape in_shape{30, 40}; + + auto check = [&](int32_t axis, Shape &expected) { + auto actual = onert::shape_inference::inferExpandDimsShape(in_shape, axis); + + ASSERT_EQ(actual.rank(), 3); + for (int32_t dim = 0; dim < expected.rank(); dim++) + ASSERT_EQ(actual.dim(dim), expected.dim(dim)); + }; + + { // boundary + int32_t axis = 0; + Shape expected{1, 30, 40}; + check(axis, expected); + } + { // boundary + int32_t axis = 2; + Shape expected{30, 40, 1}; + check(axis, expected); + } + { // inside + int32_t axis = 1; + Shape expected{30, 1, 40}; + check(axis, expected); + } + { // negative boundary + int32_t axis = -1; + Shape expected{30, 40, 1}; + check(axis, expected); + } + { // negative boundary + int32_t axis = -3; + Shape expected{1, 30, 40}; + check(axis, expected); + } +} + +TEST(ShapeInference, neg_ExpandDims) +{ + Shape in_shape{30, 40}; + + { // over boundary + int32_t axis = 3; + ASSERT_THROW(onert::shape_inference::inferExpandDimsShape(in_shape, axis), std::runtime_error); + } + { // over boundary + int32_t axis = -4; + ASSERT_THROW(onert::shape_inference::inferExpandDimsShape(in_shape, axis), 
std::runtime_error); + } +} + +TEST(ShapeInference, FullyConnected) +{ + Shape in_shape{3, 4, 5, 6}; + Shape ker_shape{3, 10}; + auto infered_out_shape = onert::shape_inference::inferFullyConnectedShape(in_shape, ker_shape); + + ASSERT_EQ(infered_out_shape.rank(), 2); + ASSERT_EQ(infered_out_shape.dim(0), 36); + ASSERT_EQ(infered_out_shape.dim(1), 3); +} + +TEST(ShapeInference, Transpose) +{ + auto check = [&](Shape &in_shape, std::vector<int> perm, Shape &expected) { + // pre-conditions + ASSERT_EQ(in_shape.rank(), perm.size()); + ASSERT_EQ(expected.rank(), perm.size()); + auto inferred_out_shape = + onert::shape_inference::inferTransposeShape(in_shape, perm.data(), perm.size()); + // post-conditions + ASSERT_EQ(inferred_out_shape.rank(), perm.size()); + for (int32_t dim = 0; dim < expected.rank(); dim++) + { + ASSERT_EQ(inferred_out_shape.dim(dim), expected.dim(dim)); + } + }; + // check for 2-D + { + Shape in_shape{2, 3}; + std::vector<int> perm = {1, 0}; + Shape expected{3, 2}; + // int32_t rank = 2; + check(in_shape, perm, expected); + } + // check for 3-D + { + Shape in_shape{1, 2, 3}; + std::vector<int> perm = {2, 0, 1}; + Shape expected{3, 1, 2}; + // int32_t rank = 3; + check(in_shape, perm, expected); + } + // check for 4-D + { + Shape in_shape{1, 2, 3, 4}; + std::vector<int> perm = {1, 3, 0, 2}; + Shape expected{2, 4, 1, 3}; + // int32_t rank = 4; + check(in_shape, perm, expected); + } +} + +TEST(ShapeInference, neg_Transpose) +{ + Shape in_shape{1, 2, 3}; + // Invalid parameter size + { + std::vector<int> perm = {2, 0, 1, 0}; + // int32_t rank = 3; + ASSERT_THROW(onert::shape_inference::inferTransposeShape(in_shape, perm.data(), perm.size()), + std::runtime_error); + } + // Invalid parameter value + { + std::vector<int> perm = {2, 0, 3}; + // int32_t rank = 3; + ASSERT_THROW(onert::shape_inference::inferTransposeShape(in_shape, perm.data(), perm.size()), + std::runtime_error); + } +} + +TEST(ShapeInference, Gather) +{ + auto check = [&](Shape &input, 
Shape &indices, Shape &expected, int32_t axis) { + int rank = input.rank(); + auto actual = onert::shape_inference::inferGatherShape(input, indices, axis, rank); + + ASSERT_EQ(actual.rank(), expected.rank()); + + for (int32_t dim = 0; dim < expected.rank(); dim++) + ASSERT_EQ(actual.dim(dim), expected.dim(dim)); + }; + + // check for 2-D, 3-D, axis 0 + { + Shape input{3, 4}; + Shape indices{1, 1, 2}; + int32_t axis = 0; + Shape expected{1, 1, 2, 4}; + check(input, indices, expected, axis); + } + + // check for 2-D, 3-D, axis 1 + { + Shape input{3, 4}; + Shape indices{1, 2, 1}; + int32_t axis = 1; + Shape expected{3, 1, 2, 1}; + check(input, indices, expected, axis); + } + + // check for 3-D, 2-D, axis 0 + { + Shape input{2, 3, 4}; + Shape indices{1, 2}; + int32_t axis = 0; + Shape expected{1, 2, 3, 4}; + check(input, indices, expected, axis); + } + + // check for 3-D, 2-D, axis 2 + { + Shape input{2, 3, 4}; + Shape indices{2, 1}; + int32_t axis = 2; + Shape expected{2, 3, 2, 1}; + check(input, indices, expected, axis); + } + + // check for 4D, axis 0 + { + Shape input{1, 2, 3, 4}; + Shape indices{2}; + int32_t axis = 0; + Shape expected{2, 2, 3, 4}; + check(input, indices, expected, axis); + } +} + +TEST(ShapeInference, BCQFullyConnected) +{ + auto check = [&](Shape &in_shape, Shape &cluster_shape, std::vector<int> cluster, + Shape &expected) { + auto actual = + onert::shape_inference::inferBCQFullyConnectedShape(in_shape, cluster_shape, cluster.data()); + ASSERT_EQ(actual.rank(), expected.rank()); + + for (int32_t dim = 0; dim < expected.rank(); dim++) + ASSERT_EQ(actual.dim(dim), expected.dim(dim)); + }; + + { + Shape in_shape{10, 1}; + Shape cluster_shape{3, 2}; + std::vector<int> cluster = {1, 10, 2, 10, 3, 10}; + + Shape expected{30, 1}; + check(in_shape, cluster_shape, cluster, expected); + } + + { + Shape in_shape{1, 1}; + Shape cluster_shape{1, 2}; + std::vector<int> cluster = {3, 50}; + + Shape expected{50, 1}; + check(in_shape, cluster_shape, cluster, 
expected); + } +} + +TEST(ShapeInference, BCQGather) +{ + auto check = [&](Shape &indices_shape, Shape &cluster_shape, std::vector<int> cluster, + uint32_t hidden_size, uint32_t axis, int rank, Shape &expected) { + operation::BCQGather::Param param{hidden_size, axis}; + auto actual = onert::shape_inference::inferBCQGatherShape(indices_shape, cluster_shape, + cluster.data(), rank, param); + ASSERT_EQ(actual.rank(), expected.rank()); + + for (int32_t dim = 0; dim < expected.rank(); dim++) + ASSERT_EQ(actual.dim(dim), expected.dim(dim)); + }; + + { + Shape indices_shape{5, 1}; + Shape cluster_shape{3, 2}; + std::vector<int> cluster = {1, 10, 2, 10, 3, 10}; + uint32_t hidden_size = 10; + uint32_t axis = 0; + int rank = 2; + + Shape expected{5, 1, 10}; + check(indices_shape, cluster_shape, cluster, hidden_size, axis, rank, expected); + } + + { + Shape indices_shape{5, 1}; + Shape cluster_shape{3, 2}; + std::vector<int> cluster = {1, 10, 2, 10, 3, 10}; + uint32_t hidden_size = 10; + uint32_t axis = 1; + int rank = 2; + + Shape expected{30, 5, 1}; + check(indices_shape, cluster_shape, cluster, hidden_size, axis, rank, expected); + } +} diff --git a/runtime/onert/core/src/util/TracingCtx.cc b/runtime/onert/core/src/util/TracingCtx.cc new file mode 100644 index 000000000..c05baee60 --- /dev/null +++ b/runtime/onert/core/src/util/TracingCtx.cc @@ -0,0 +1,30 @@ +/* + * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * Copyright 2017 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "util/TracingCtx.h" + +namespace onert +{ +namespace util +{ + +// initializing static member var +std::mutex TracingCtx::_session_id_mutex; +uint32_t TracingCtx::_next_session_id = 0; + +} // namespace util +} // namespace onert |