diff options
Diffstat (limited to 'compiler/luci')
779 files changed, 30293 insertions, 5708 deletions
diff --git a/compiler/luci/CMakeLists.txt b/compiler/luci/CMakeLists.txt index 214a1bbf2..3771176f0 100644 --- a/compiler/luci/CMakeLists.txt +++ b/compiler/luci/CMakeLists.txt @@ -1,8 +1,11 @@ add_subdirectory(env) add_subdirectory(log) add_subdirectory(lang) +add_subdirectory(testhelper) add_subdirectory(service) add_subdirectory(pass) +add_subdirectory(profile) +add_subdirectory(partition) add_subdirectory(logex) add_subdirectory(import) add_subdirectory(export) diff --git a/compiler/luci/env/include/luci/UserSettings.h b/compiler/luci/env/include/luci/UserSettings.h index bcfd16071..b56bd65e2 100644 --- a/compiler/luci/env/include/luci/UserSettings.h +++ b/compiler/luci/env/include/luci/UserSettings.h @@ -32,6 +32,7 @@ struct UserSettings Undefined, MuteWarnings, DisableValidation, + ProfilingDataGen, }; static UserSettings *settings(); diff --git a/compiler/luci/env/src/UserSettings.cpp b/compiler/luci/env/src/UserSettings.cpp index 27dec762d..b4c661190 100644 --- a/compiler/luci/env/src/UserSettings.cpp +++ b/compiler/luci/env/src/UserSettings.cpp @@ -30,6 +30,7 @@ public: private: bool _MuteWarnings{false}; bool _DisableValidation{false}; + bool _ProfilingDataGen{false}; }; void UserSettingsImpl::set(const Key key, bool value) @@ -42,6 +43,9 @@ void UserSettingsImpl::set(const Key key, bool value) case Key::DisableValidation: _DisableValidation = value; break; + case Key::ProfilingDataGen: + _ProfilingDataGen = value; + break; default: throw std::runtime_error("Invalid key in boolean set"); break; @@ -56,6 +60,8 @@ bool UserSettingsImpl::get(const Key key) const return _MuteWarnings; case Key::DisableValidation: return _DisableValidation; + case Key::ProfilingDataGen: + return _ProfilingDataGen; default: throw std::runtime_error("Invalid key in boolean get"); break; diff --git a/compiler/luci/env/src/UserSettings.test.cpp b/compiler/luci/env/src/UserSettings.test.cpp index 8d9d1875b..899c0c2a1 100644 --- a/compiler/luci/env/src/UserSettings.test.cpp +++ b/compiler/luci/env/src/UserSettings.test.cpp @@ -51,6 +51,18 @@ TEST(UserSettings, DisableValidation) ASSERT_TRUE(settings->get(luci::UserSettings::Key::DisableValidation)); } +TEST(UserSettings, ProfilingDataGen) +{ + auto settings = luci::UserSettings::settings(); + ASSERT_NE(nullptr, settings); + + settings->set(luci::UserSettings::Key::ProfilingDataGen, false); + ASSERT_FALSE(settings->get(luci::UserSettings::Key::ProfilingDataGen)); + + settings->set(luci::UserSettings::Key::ProfilingDataGen, true); + ASSERT_TRUE(settings->get(luci::UserSettings::Key::ProfilingDataGen)); +} + TEST(UserSettings, undefined_set_NEG) { auto settings = luci::UserSettings::settings(); diff --git a/compiler/luci/export/CMakeLists.txt b/compiler/luci/export/CMakeLists.txt index fe4382ecd..01f737110 100644 --- a/compiler/luci/export/CMakeLists.txt +++ b/compiler/luci/export/CMakeLists.txt @@ -13,6 +13,7 @@ target_link_libraries(luci_export PRIVATE mio_circle) target_link_libraries(luci_export PRIVATE luci_env) target_link_libraries(luci_export PRIVATE luci_log) target_link_libraries(luci_export PRIVATE luci_logex) +target_link_libraries(luci_export PRIVATE luci_profile) target_link_libraries(luci_export PRIVATE nncc_common) target_link_libraries(luci_export PRIVATE locop) target_link_libraries(luci_export PRIVATE oops) diff --git a/compiler/luci/export/include/luci/CircleFileExpContract.h b/compiler/luci/export/include/luci/CircleFileExpContract.h index eeaf2d9bb..8ef1b5e0c 100644 --- a/compiler/luci/export/include/luci/CircleFileExpContract.h +++ b/compiler/luci/export/include/luci/CircleFileExpContract.h @@ -33,7 +33,7 @@ struct CircleFileExpContract : public luci::CircleExporter::Contract { public: CircleFileExpContract(luci::Module *module, const std::string &filename) - : _module(module), _filepath(filename) + : _module(module), _filepath(filename) { // NOTHING TO DO } diff --git a/compiler/luci/export/src/CircleExportMetadata.cpp b/compiler/luci/export/src/CircleExportMetadata.cpp new file mode 100644 index 000000000..ef905a882 --- /dev/null +++ b/compiler/luci/export/src/CircleExportMetadata.cpp @@ -0,0 +1,121 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleExportMetadata.h" + +#include <luci/UserSettings.h> + +namespace +{ + +void write_u32(std::vector<uint8_t> &to, uint32_t value) +{ + to.emplace_back(0xFF & (value >> 0 * 8)); + to.emplace_back(0xFF & (value >> 1 * 8)); + to.emplace_back(0xFF & (value >> 2 * 8)); + to.emplace_back(0xFF & (value >> 3 * 8)); +} + +flatbuffers::Offset<circle::Metadata> metadata_offset(flatbuffers::FlatBufferBuilder &builder, + luci::SerializedModelData &md, + const std::vector<uint8_t> &data, + const std::string &metadata_name) +{ + auto buffer_id = static_cast<uint32_t>(md._buffers.size()); + md._buffers.push_back(circle::CreateBufferDirect(builder, &data)); + return circle::CreateMetadataDirect(builder, metadata_name.c_str(), buffer_id); +} + +} // namespace + +namespace luci +{ + +// 'source_table' is encoded to binary format. +const std::vector<uint8_t> CircleExportMetadata::encoded_source_table(void) +{ + std::vector<uint8_t> data; + + write_u32(data, _source_table.size()); + + for (auto &kv : _source_table) + { + const auto id = kv.first; + write_u32(data, id); + + const auto origin_name = kv.second; + const auto length = origin_name.length(); + write_u32(data, length + 1); // name + '\0 + + for (uint32_t i = 0; i < length; ++i) + { + data.emplace_back(origin_name.at(i)); + } + data.emplace_back('\0'); + } + + return data; +} + +// 'op_table' is encoded to binary format. +const std::vector<uint8_t> CircleExportMetadata::encoded_op_table(void) +{ + std::vector<uint8_t> data; + + write_u32(data, _op_table.size()); + + for (auto &kv : _op_table) + { + const auto id = kv.first; + write_u32(data, id); + + const auto origins = kv.second; + const auto node_num = origins.size(); + write_u32(data, node_num); + + for (auto origin : origins) + { + write_u32(data, origin); + } + } + + return data; +} + +} // namespace luci + +namespace luci +{ + +std::vector<flatbuffers::Offset<circle::Metadata>> +createCircleMetadataVector(flatbuffers::FlatBufferBuilder &builder, luci::SerializedModelData &md) +{ + std::vector<flatbuffers::Offset<circle::Metadata>> metadata_vec; + + auto settings = luci::UserSettings::settings(); + if (settings->get(luci::UserSettings::Key::ProfilingDataGen)) + { + metadata_vec.emplace_back( + metadata_offset(builder, md, md._metadata.encoded_source_table(), "ONE_source_table")); + + metadata_vec.emplace_back( + metadata_offset(builder, md, md._metadata.encoded_op_table(), "ONE_op_table")); + } + + return metadata_vec; +} + +} // namespace luci diff --git a/compiler/luci/export/src/CircleExportMetadata.h b/compiler/luci/export/src/CircleExportMetadata.h new file mode 100644 index 000000000..10cda421e --- /dev/null +++ b/compiler/luci/export/src/CircleExportMetadata.h @@ -0,0 +1,36 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_CIRCLE_EXPORT_METADATA_H__ +#define __LUCI_CIRCLE_EXPORT_METADATA_H__ + +#include "SerializedData.h" + +#include <flatbuffers/flatbuffers.h> +#include <mio/circle/schema_generated.h> + +namespace luci +{ + +/** + * @brief Create Metadata corresponding to model metadata + */ +std::vector<flatbuffers::Offset<circle::Metadata>> +createCircleMetadataVector(flatbuffers::FlatBufferBuilder &builder, SerializedModelData &md); + +} // namespace luci + +#endif // __LUCI_CIRCLE_EXPORT_METADATA_H__ diff --git a/compiler/luci/export/src/CircleExporterImpl.cpp b/compiler/luci/export/src/CircleExporterImpl.cpp index df7542797..7e218191c 100644 --- a/compiler/luci/export/src/CircleExporterImpl.cpp +++ b/compiler/luci/export/src/CircleExporterImpl.cpp @@ -16,10 +16,13 @@ #include "CircleExporterImpl.h" #include "Optimize.h" +#include "CircleExportMetadata.h" #include "CircleTensorExporter.h" #include "CircleOperationExporter.h" #include "CircleExporterUtils.h" +#include <luci/IR/CircleNodes.h> + #include <oops/InternalExn.h> #include <mio/circle/schema_generated.h> #include <flatbuffers/flatbuffers.h> @@ -27,46 +30,16 @@ #include <cassert> #include <unordered_map> #include <string> -#include <stdexcept> +#include <vector> namespace { -luci::CircleInput *input_node(loco::Graph *g, const loco::GraphInputIndex &index) -{ - for (uint32_t n = 0; n < g->nodes()->size(); ++n) - { - if (auto input = dynamic_cast<luci::CircleInput *>(g->nodes()->at(n))) - { - if (input->indexed() && input->index() == index) - { - return input; - } - } - } - return nullptr; -} - -luci::CircleOutput *output_node(loco::Graph *g, const loco::GraphOutputIndex &index) -{ - for (uint32_t n = 0; n < g->nodes()->size(); ++n) - { - if (auto output = dynamic_cast<luci::CircleOutput *>(g->nodes()->at(n))) - { - if (output->indexed() && output->index() == index) - { - return output; - } - } - } - return nullptr; -} - void registerGraphInputTensors(loco::Graph *graph, luci::SubGraphContext &ctx) { for (uint32_t n = 0; n < graph->inputs()->size(); ++n) { - auto node = input_node(graph, n); + auto node = luci::input_node(graph, n); assert(node != nullptr); ctx._inputs.push_back(luci::get_tensor_index(node)); } @@ -76,7 +49,7 @@ void registerGraphOutputTensors(loco::Graph *graph, luci::SubGraphContext &ctx) { for (uint32_t n = 0; n < graph->outputs()->size(); ++n) { - auto push = output_node(graph, n); + auto push = luci::output_node(graph, n); assert(push != nullptr); auto node = push->from(); assert(node != nullptr); @@ -113,7 +86,7 @@ encodeOperatorCodes(FlatBufferBuilder &builder, std::unordered_map<luci::OpCode, else { operator_codes_vec[idx] = - CreateOperatorCode(builder, it.first.opcode, builder.CreateString(it.first.custom_code)); + CreateOperatorCode(builder, it.first.opcode, builder.CreateString(it.first.custom_code)); } } @@ -186,16 +159,16 @@ void CircleExporterImpl::exportGraph(loco::Graph *graph) std::string description_str = "nnpackage"; auto description = _builder.CreateString(description_str); + // Metadata + auto metadata_vec = createCircleMetadataVector(_builder, md); + auto metadata = _builder.CreateVector(std::vector<Offset<Metadata>>(metadata_vec)); + // create array of buffers auto buffers = _builder.CreateVector(md._buffers); - // empty metadata - std::vector<int> metadata_buffer_vec; - auto metadata_buffer = _builder.CreateVector(metadata_buffer_vec); - // Model auto model_offset = CreateModel(_builder, version, operator_codes, subgraphs, description, - buffers, metadata_buffer); + buffers, 0 /* metadata_buffer */, metadata); FinishModelBuffer(_builder, model_offset); } @@ -250,19 +223,19 @@ void CircleExporterImpl::exportModule(Module *module) std::string description_str = "nnpackage"; auto description = _builder.CreateString(description_str); + // Metadata + auto metadata_vec = createCircleMetadataVector(_builder, md); + auto metadata = _builder.CreateVector(std::vector<Offset<Metadata>>(metadata_vec)); + // create array of buffers auto buffers = _builder.CreateVector(md._buffers); - // empty metadata - std::vector<int> metadata_buffer_vec; - auto metadata_buffer = _builder.CreateVector(metadata_buffer_vec); - // This version is taken from comment in fbs constexpr uint32_t version = 0; // Model auto model_offset = CreateModel(_builder, version, operator_codes, subgraphs, description, - buffers, metadata_buffer); + buffers, 0 /* metadata_buffer */, metadata); FinishModelBuffer(_builder, model_offset); } diff --git a/compiler/luci/export/src/CircleExporterImpl.h b/compiler/luci/export/src/CircleExporterImpl.h index e5d5b5a00..069f62afd 100644 --- a/compiler/luci/export/src/CircleExporterImpl.h +++ b/compiler/luci/export/src/CircleExporterImpl.h @@ -22,8 +22,6 @@ #include "SerializedData.h" -#include "SerializedData.h" - #include <mio/circle/schema_generated.h> #include <loco.h> diff --git a/compiler/luci/export/src/CircleExporterUtils.cpp b/compiler/luci/export/src/CircleExporterUtils.cpp index 3715513e0..1b21fdd86 100644 --- a/compiler/luci/export/src/CircleExporterUtils.cpp +++ b/compiler/luci/export/src/CircleExporterUtils.cpp @@ -208,13 +208,13 @@ circle::Padding getOpPadding(const loco::Padding2D *pad, const loco::Stride<2> * // // NOTE input and output 'feature' map are shape of NHWC bool same_padding_criterion_1 = - (static_cast<uint32_t>(ofm._dims[1]) == (ifm._dims[1] - 1) / stride->vertical() + 1) && - (static_cast<uint32_t>(ofm._dims[2]) == (ifm._dims[2] - 1) / stride->horizontal() + 1); + (static_cast<uint32_t>(ofm._dims[1]) == (ifm._dims[1] - 1) / stride->vertical() + 1) && + (static_cast<uint32_t>(ofm._dims[2]) == (ifm._dims[2] - 1) / stride->horizontal() + 1); // For same padding, rear padding is same or bigger than front padding by at most 1 bool same_padding_criterion_2 = - (pad->top() <= pad->bottom()) && (pad->bottom() <= pad->top() + 1) && - (pad->left() <= pad->right()) && (pad->right() <= pad->left() + 1); + (pad->top() <= pad->bottom()) && (pad->bottom() <= pad->top() + 1) && + (pad->left() <= pad->right()) && (pad->right() <= pad->left() + 1); if (same_padding_criterion_1 && same_padding_criterion_2) return circle::Padding_SAME; diff --git a/compiler/luci/export/src/CircleOperationExporter.cpp b/compiler/luci/export/src/CircleOperationExporter.cpp index 4343cf3c9..4bf674b9b 100644 --- a/compiler/luci/export/src/CircleOperationExporter.cpp +++ b/compiler/luci/export/src/CircleOperationExporter.cpp @@ -21,6 +21,7 @@ #include <luci/IR/CircleNode.h> #include <luci/IR/CircleNodes.h> #include <luci/IR/CircleNodeVisitor.h> +#include <luci/Profile/CircleNodeOrigin.h> #include <luci/UserSettings.h> #include <luci/Log.h> @@ -53,8 +54,8 @@ template <class CirclePool2D> void export_pool_2d(ExportContext &ctx, CirclePool2D *node, circle::BuiltinOperator builtin_op) { LUCI_ASSERT(builtin_op == circle::BuiltinOperator_MAX_POOL_2D || - builtin_op == circle::BuiltinOperator_L2_POOL_2D || - builtin_op == circle::BuiltinOperator_AVERAGE_POOL_2D, + builtin_op == circle::BuiltinOperator_L2_POOL_2D || + builtin_op == circle::BuiltinOperator_AVERAGE_POOL_2D, "Should be L2Pool, MaxPool or AvgPool"); LUCI_ASSERT(node->padding() != luci::Padding::UNDEFINED, "Padding is not set"); @@ -81,7 +82,7 @@ void export_node(ExportContext &ctx, loco::Node *node, circle::BuiltinOperator b circle::BuiltinOptions bot, flatbuffers::Offset<void> options_offset) { uint32_t op_idx = - ctx.md.registerBuiltinOpcode(bop, loco::must_cast<luci::CircleNode *>(node)->op_version()); + ctx.md.registerBuiltinOpcode(bop, loco::must_cast<luci::CircleNode *>(node)->op_version()); std::vector<int32_t> inputs_vec; std::vector<int32_t> outputs_vec{get_tensor_index(node)}; for (uint32_t i = 0; i < node->arity(); ++i) @@ -98,7 +99,7 @@ void export_node(ExportContext &ctx, loco::Node *node, circle::BuiltinOperator b void export_node(ExportContext &ctx, loco::Node *node, circle::BuiltinOperator bop) { uint32_t op_idx = - ctx.md.registerBuiltinOpcode(bop, loco::must_cast<luci::CircleNode *>(node)->op_version()); + ctx.md.registerBuiltinOpcode(bop, loco::must_cast<luci::CircleNode *>(node)->op_version()); std::vector<int32_t> inputs_vec; std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))}; for (uint32_t i = 0; i < node->arity(); ++i) @@ -152,7 +153,7 @@ void export_node(ExportContext &ctx, luci::CircleCast *node) void export_node(ExportContext &ctx, luci::CircleConcatenation *node) { uint32_t op_idx = - ctx.md.registerBuiltinOpcode(circle::BuiltinOperator_CONCATENATION, node->op_version()); + ctx.md.registerBuiltinOpcode(circle::BuiltinOperator_CONCATENATION, node->op_version()); std::vector<int32_t> inputs_vec; std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))}; @@ -171,6 +172,7 @@ void export_node(ExportContext &ctx, luci::CircleConcatenation *node) void export_node(ExportContext &ctx, luci::CircleCustom *node) { auto custom_outputs = loco::succs(node); + assert(custom_outputs.size() == node->numOutputs()); uint32_t op_idx = ctx.md.registerCustomOpcode(node->custom_code()); std::vector<int32_t> inputs_vec; @@ -260,9 +262,9 @@ void export_node(ExportContext &ctx, luci::CircleNonMaxSuppressionV4 *node) uint32_t op_idx = ctx.md.registerBuiltinOpcode(circle::BuiltinOperator_NON_MAX_SUPPRESSION_V4, node->op_version()); std::vector<int32_t> inputs_vec{ - get_tensor_index(node->boxes()), get_tensor_index(node->scores()), - get_tensor_index(node->max_output_size()), get_tensor_index(node->iou_threshold()), - get_tensor_index(node->score_threshold()), + get_tensor_index(node->boxes()), get_tensor_index(node->scores()), + get_tensor_index(node->max_output_size()), get_tensor_index(node->iou_threshold()), + get_tensor_index(node->score_threshold()), }; std::vector<int32_t> outputs_vec; @@ -290,8 +292,8 @@ void export_node(ExportContext &ctx, luci::CircleNonMaxSuppressionV4 *node) auto outputs = ctx.builder.CreateVector(outputs_vec); auto options = CreateNonMaxSuppressionV4Options(ctx.builder); auto op_offset = - CreateOperator(ctx.builder, op_idx, inputs, outputs, - circle::BuiltinOptions_NonMaxSuppressionV4Options, options.Union()); + CreateOperator(ctx.builder, op_idx, inputs, outputs, + circle::BuiltinOptions_NonMaxSuppressionV4Options, options.Union()); ctx.gd._operators.push_back(op_offset); } @@ -303,9 +305,9 @@ void export_node(ExportContext &ctx, luci::CircleNonMaxSuppressionV5 *node) uint32_t op_idx = ctx.md.registerBuiltinOpcode(circle::BuiltinOperator_NON_MAX_SUPPRESSION_V5, node->op_version()); std::vector<int32_t> inputs_vec{ - get_tensor_index(node->boxes()), get_tensor_index(node->scores()), - get_tensor_index(node->max_output_size()), get_tensor_index(node->iou_threshold()), - get_tensor_index(node->score_threshold()), get_tensor_index(node->soft_nms_sigma()), + get_tensor_index(node->boxes()), get_tensor_index(node->scores()), + get_tensor_index(node->max_output_size()), get_tensor_index(node->iou_threshold()), + get_tensor_index(node->score_threshold()), get_tensor_index(node->soft_nms_sigma()), }; std::vector<int32_t> outputs_vec; @@ -333,15 +335,15 @@ void export_node(ExportContext &ctx, luci::CircleNonMaxSuppressionV5 *node) auto outputs = ctx.builder.CreateVector(outputs_vec); auto options = CreateNonMaxSuppressionV5Options(ctx.builder); auto op_offset = - CreateOperator(ctx.builder, op_idx, inputs, outputs, - circle::BuiltinOptions_NonMaxSuppressionV5Options, options.Union()); + CreateOperator(ctx.builder, op_idx, inputs, outputs, + circle::BuiltinOptions_NonMaxSuppressionV5Options, options.Union()); ctx.gd._operators.push_back(op_offset); } void export_node(ExportContext &ctx, luci::CircleReverseV2 *node) { uint32_t op_idx = - ctx.md.registerBuiltinOpcode(circle::BuiltinOperator_REVERSE_V2, node->op_version()); + ctx.md.registerBuiltinOpcode(circle::BuiltinOperator_REVERSE_V2, node->op_version()); std::vector<int32_t> inputs_vec{get_tensor_index(node->tensor()), get_tensor_index(node->axis())}; std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))}; auto inputs = ctx.builder.CreateVector(inputs_vec); @@ -397,7 +399,7 @@ void export_node(ExportContext &ctx, luci::CircleSplitV *node) assert(int32_t(split_outs.size()) == node->num_split()); uint32_t op_idx = - ctx.md.registerBuiltinOpcode(circle::BuiltinOperator_SPLIT_V, node->op_version()); + ctx.md.registerBuiltinOpcode(circle::BuiltinOperator_SPLIT_V, node->op_version()); std::vector<int32_t> inputs_vec{get_tensor_index(node->input()), get_tensor_index(node->size_splits()), get_tensor_index(node->split_dim())}; @@ -438,7 +440,7 @@ void export_node(ExportContext &ctx, luci::CircleTopKV2 *node) assert(outs_count == 2); uint32_t op_idx = - ctx.md.registerBuiltinOpcode(circle::BuiltinOperator_TOPK_V2, node->op_version()); + ctx.md.registerBuiltinOpcode(circle::BuiltinOperator_TOPK_V2, node->op_version()); std::vector<int32_t> inputs_vec{get_tensor_index(node->input()), get_tensor_index(node->k())}; std::vector<int32_t> outputs_vec; @@ -475,7 +477,7 @@ void export_node(ExportContext &ctx, luci::CircleUnique *node) auto unique_outs = loco::succs(node); assert(int32_t(unique_outs.size()) == 2); uint32_t op_idx = - ctx.md.registerBuiltinOpcode(circle::BuiltinOperator_UNIQUE, node->op_version()); + ctx.md.registerBuiltinOpcode(circle::BuiltinOperator_UNIQUE, node->op_version()); std::vector<int32_t> inputs_vec{get_tensor_index(node->input())}; std::vector<int32_t> outputs_vec; @@ -526,7 +528,7 @@ void export_node(ExportContext &ctx, luci::CircleUnpack *node) } uint32_t op_idx = - ctx.md.registerBuiltinOpcode(circle::BuiltinOperator_UNPACK, node->op_version()); + ctx.md.registerBuiltinOpcode(circle::BuiltinOperator_UNPACK, node->op_version()); std::vector<int32_t> inputs_vec{get_tensor_index(node->value())}; std::vector<int32_t> outputs_vec; @@ -622,6 +624,7 @@ public: void visit(luci::CircleAveragePool2D *) final; void visit(luci::CircleBatchMatMul *) final; void visit(luci::CircleBatchToSpaceND *) final; + void visit(luci::CircleBidirectionalSequenceLSTM *) final; void visit(luci::CircleCast *) final; void visit(luci::CircleCeil *) final; void visit(luci::CircleConcatenation *) final; @@ -637,6 +640,7 @@ public: void visit(luci::CircleEqual *) final; void visit(luci::CircleExp *) final; void visit(luci::CircleExpandDims *) final; + void visit(luci::CircleFakeQuant *) final; void visit(luci::CircleFill *) final; void visit(luci::CircleFloor *) final; void visit(luci::CircleFloorDiv *) final; @@ -734,6 +738,7 @@ public: void visit(luci::CircleOutputDummy *) final {} void visit(luci::CircleOutputExclude *) final {} // Virtual for multiple-outputs + void visit(luci::CircleBidirectionalSequenceLSTMOut *) final {} void visit(luci::CircleCustomOut *) final {} void visit(luci::CircleIfOut *) final {} void visit(luci::CircleNonMaxSuppressionV4Out *) final {} @@ -782,8 +787,8 @@ void OperationExporter::visit(luci::CircleAbs *node) void OperationExporter::visit(luci::CircleAdd *node) { export_simple( - node, circle::BuiltinOperator_ADD, circle::BuiltinOptions_AddOptions, - CreateAddOptions(_ctx.builder, to_circle_actfunc(node->fusedActivationFunction())).Union()); + node, circle::BuiltinOperator_ADD, circle::BuiltinOptions_AddOptions, + CreateAddOptions(_ctx.builder, to_circle_actfunc(node->fusedActivationFunction())).Union()); } void OperationExporter::visit(luci::CircleAddN *node) { export_node(_ctx, node); } @@ -791,15 +796,15 @@ void OperationExporter::visit(luci::CircleAddN *node) { export_node(_ctx, node); void OperationExporter::visit(luci::CircleArgMax *node) { export_simple( - node, circle::BuiltinOperator_ARG_MAX, circle::BuiltinOptions_ArgMaxOptions, - CreateArgMaxOptions(_ctx.builder, to_circle_tensortype(node->output_type())).Union()); + node, circle::BuiltinOperator_ARG_MAX, circle::BuiltinOptions_ArgMaxOptions, + CreateArgMaxOptions(_ctx.builder, to_circle_tensortype(node->output_type())).Union()); } void OperationExporter::visit(luci::CircleArgMin *node) { export_simple( - node, circle::BuiltinOperator_ARG_MIN, circle::BuiltinOptions_ArgMinOptions, - CreateArgMinOptions(_ctx.builder, to_circle_tensortype(node->output_type())).Union()); + node, circle::BuiltinOperator_ARG_MIN, circle::BuiltinOptions_ArgMinOptions, + CreateArgMinOptions(_ctx.builder, to_circle_tensortype(node->output_type())).Union()); } void OperationExporter::visit(luci::CircleAveragePool2D *node) @@ -814,6 +819,48 @@ void OperationExporter::visit(luci::CircleBatchMatMul *node) CreateBatchMatMulOptions(_ctx.builder, node->adj_x(), node->adj_y()).Union()); } +void OperationExporter::visit(luci::CircleBidirectionalSequenceLSTM *node) +{ + auto bidi_lstm_outs = loco::succs(node); + assert((bidi_lstm_outs.size() == 1) || (bidi_lstm_outs.size() == 2)); + uint32_t op_idx = _ctx.md.registerBuiltinOpcode( + circle::BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM, node->op_version()); + + std::vector<int32_t> inputs_vec{get_tensor_index(node->input())}; + std::vector<int32_t> outputs_vec; + + for (int32_t index = 0; index < 2; index++) + { + // store in order of index + bool found = false; + for (auto out : bidi_lstm_outs) + { + auto bidi_lstm_out = loco::must_cast<luci::CircleBidirectionalSequenceLSTMOut *>(out); + if (bidi_lstm_out->index() == index) + { + outputs_vec.push_back(get_tensor_index(bidi_lstm_out)); + found = true; + break; + } + } + if (!found) + { + INTERNAL_EXN("Invalid BidirectionalSequenceLSTM output"); + } + } + + auto inputs = _ctx.builder.CreateVector(inputs_vec); + auto outputs = _ctx.builder.CreateVector(outputs_vec); + auto options = CreateBidirectionalSequenceLSTMOptions( + _ctx.builder, to_circle_actfunc(node->fusedActivationFunction()), node->cell_clip(), + node->proj_clip(), node->merge_outputs(), node->time_major(), + node->asymmetric_quantize_inputs()); + auto op_offset = + CreateOperator(_ctx.builder, op_idx, inputs, outputs, + circle::BuiltinOptions_BidirectionalSequenceLSTMOptions, options.Union()); + _ctx.gd._operators.push_back(op_offset); +} + void OperationExporter::visit(luci::CircleCast *node) { export_node(_ctx, node); } void OperationExporter::visit(luci::CircleCeil *node) @@ -837,7 +884,7 @@ void OperationExporter::visit(luci::CircleConv2D *node) node->stride()->w(), node->stride()->h(), to_circle_actfunc(node->fusedActivationFunction()), node->dilation()->w(), node->dilation()->h()) - .Union()); + .Union()); } void OperationExporter::visit(luci::CircleCos *node) @@ -857,14 +904,13 @@ void OperationExporter::visit(luci::CircleDepthToSpace *node) void OperationExporter::visit(luci::CircleDepthwiseConv2D *node) { - export_simple(node, circle::BuiltinOperator_DEPTHWISE_CONV_2D, - circle::BuiltinOptions_DepthwiseConv2DOptions, - CreateDepthwiseConv2DOptions(_ctx.builder, getOpPadding(node->padding()), - node->stride()->w(), node->stride()->h(), - node->depthMultiplier(), - to_circle_actfunc(node->fusedActivationFunction()), - node->dilation()->w(), node->dilation()->h()) - .Union()); + export_simple( + node, circle::BuiltinOperator_DEPTHWISE_CONV_2D, circle::BuiltinOptions_DepthwiseConv2DOptions, + CreateDepthwiseConv2DOptions(_ctx.builder, getOpPadding(node->padding()), node->stride()->w(), + node->stride()->h(), node->depthMultiplier(), + to_circle_actfunc(node->fusedActivationFunction()), + node->dilation()->w(), node->dilation()->h()) + .Union()); } void OperationExporter::visit(luci::CircleDequantize *node) @@ -875,8 +921,8 @@ void OperationExporter::visit(luci::CircleDequantize *node) void OperationExporter::visit(luci::CircleDiv *node) { export_simple( - node, circle::BuiltinOperator_DIV, circle::BuiltinOptions_DivOptions, - CreateDivOptions(_ctx.builder, to_circle_actfunc(node->fusedActivationFunction())).Union()); + node, circle::BuiltinOperator_DIV, circle::BuiltinOptions_DivOptions, + CreateDivOptions(_ctx.builder, to_circle_actfunc(node->fusedActivationFunction())).Union()); } void OperationExporter::visit(luci::CircleElu *node) @@ -902,6 +948,14 @@ void OperationExporter::visit(luci::CircleExpandDims *node) CreateExpandDimsOptions(_ctx.builder).Union()); } +void OperationExporter::visit(luci::CircleFakeQuant *node) +{ + export_simple(node, circle::BuiltinOperator_FAKE_QUANT, circle::BuiltinOptions_FakeQuantOptions, + CreateFakeQuantOptions(_ctx.builder, node->min(), node->max(), node->num_bits(), + node->narrow_range()) + .Union()); +} + void OperationExporter::visit(luci::CircleFill *node) { export_simple(node, circle::BuiltinOperator_FILL, circle::BuiltinOptions_FillOptions, @@ -928,10 +982,10 @@ void OperationExporter::visit(luci::CircleFloorMod *node) void OperationExporter::visit(luci::CircleFullyConnected *node) { export_simple( - node, circle::BuiltinOperator_FULLY_CONNECTED, circle::BuiltinOptions_FullyConnectedOptions, - CreateFullyConnectedOptions(_ctx.builder, to_circle_actfunc(node->fusedActivationFunction()), - to_circle_weightsformat(node->weights_format())) - .Union()); + node, circle::BuiltinOperator_FULLY_CONNECTED, circle::BuiltinOptions_FullyConnectedOptions, + CreateFullyConnectedOptions(_ctx.builder, to_circle_actfunc(node->fusedActivationFunction()), + to_circle_weightsformat(node->weights_format())) + .Union()); } void OperationExporter::visit(luci::CircleGather *node) @@ -964,9 +1018,8 @@ void OperationExporter::visit(luci::CircleIf *node) { export_node(_ctx, node); } void OperationExporter::visit(luci::CircleL2Normalize *node) { export_simple( - node, circle::BuiltinOperator_L2_NORMALIZATION, circle::BuiltinOptions_L2NormOptions, - CreateL2NormOptions(_ctx.builder, to_circle_actfunc(node->fusedActivationFunction())) - .Union()); + node, circle::BuiltinOperator_L2_NORMALIZATION, circle::BuiltinOptions_L2NormOptions, + CreateL2NormOptions(_ctx.builder, to_circle_actfunc(node->fusedActivationFunction())).Union()); } void OperationExporter::visit(luci::CircleL2Pool2D *node) @@ -998,7 +1051,7 @@ void OperationExporter::visit(luci::CircleLocalResponseNormalization *node) circle::BuiltinOptions_LocalResponseNormalizationOptions, CreateLocalResponseNormalizationOptions(_ctx.builder, node->radius(), node->bias(), node->alpha(), node->beta()) - .Union()); + .Union()); } void OperationExporter::visit(luci::CircleLog *node) @@ -1074,15 +1127,15 @@ void OperationExporter::visit(luci::CircleMinimum *node) void OperationExporter::visit(luci::CircleMirrorPad *node) { export_simple( - node, circle::BuiltinOperator_MIRROR_PAD, circle::BuiltinOptions_MirrorPadOptions, - CreateMirrorPadOptions(_ctx.builder, to_circle_mirrorpadmode(node->mode())).Union()); + node, circle::BuiltinOperator_MIRROR_PAD, circle::BuiltinOptions_MirrorPadOptions, + CreateMirrorPadOptions(_ctx.builder, to_circle_mirrorpadmode(node->mode())).Union()); } void OperationExporter::visit(luci::CircleMul *node) { export_simple( - node, circle::BuiltinOperator_MUL, circle::BuiltinOptions_MulOptions, - CreateMulOptions(_ctx.builder, to_circle_actfunc(node->fusedActivationFunction())).Union()); + node, circle::BuiltinOperator_MUL, circle::BuiltinOptions_MulOptions, + CreateMulOptions(_ctx.builder, to_circle_actfunc(node->fusedActivationFunction())).Union()); } void OperationExporter::visit(luci::CircleNeg *node) @@ -1190,7 +1243,7 @@ void OperationExporter::visit(luci::CircleReluN1To1 *node) void OperationExporter::visit(luci::CircleReshape *node) { auto new_shape = _ctx.builder.CreateVector<int32_t>( - node->newShape()->rank(), [node](size_t i) { return node->newShape()->dim(i); }); + node->newShape()->rank(), [node](size_t i) { return node->newShape()->dim(i); }); export_simple(node, circle::BuiltinOperator_RESHAPE, circle::BuiltinOptions_ReshapeOptions, CreateReshapeOptions(_ctx.builder, new_shape).Union()); @@ -1199,9 +1252,9 @@ void OperationExporter::visit(luci::CircleReshape *node) void OperationExporter::visit(luci::CircleResizeBilinear *node) { export_simple( - node, circle::BuiltinOperator_RESIZE_BILINEAR, circle::BuiltinOptions_ResizeBilinearOptions, - CreateResizeBilinearOptions(_ctx.builder, node->align_corners(), node->half_pixel_centers()) - .Union()); + node, circle::BuiltinOperator_RESIZE_BILINEAR, circle::BuiltinOptions_ResizeBilinearOptions, + CreateResizeBilinearOptions(_ctx.builder, node->align_corners(), node->half_pixel_centers()) + .Union()); } void OperationExporter::visit(luci::CircleResizeNearestNeighbor *node) @@ -1214,8 +1267,8 @@ void OperationExporter::visit(luci::CircleResizeNearestNeighbor *node) void OperationExporter::visit(luci::CircleReverseSequence *node) { export_simple( - node, circle::BuiltinOperator_REVERSE_SEQUENCE, circle::BuiltinOptions_ReverseSequenceOptions, - CreateReverseSequenceOptions(_ctx.builder, node->seq_axis(), node->batch_axis()).Union()); + node, circle::BuiltinOperator_REVERSE_SEQUENCE, circle::BuiltinOptions_ReverseSequenceOptions, + CreateReverseSequenceOptions(_ctx.builder, node->seq_axis(), node->batch_axis()).Union()); } void OperationExporter::visit(luci::CircleReverseV2 *node) { export_node(_ctx, node); } @@ -1334,14 +1387,14 @@ void OperationExporter::visit(luci::CircleStridedSlice *node) CreateStridedSliceOptions(_ctx.builder, node->begin_mask(), node->end_mask(), node->ellipsis_mask(), node->new_axis_mask(), node->shrink_axis_mask()) - .Union()); + .Union()); } void OperationExporter::visit(luci::CircleSub *node) { export_simple( - node, circle::BuiltinOperator_SUB, circle::BuiltinOptions_SubOptions, - CreateSubOptions(_ctx.builder, to_circle_actfunc(node->fusedActivationFunction())).Union()); + node, circle::BuiltinOperator_SUB, circle::BuiltinOptions_SubOptions, + CreateSubOptions(_ctx.builder, to_circle_actfunc(node->fusedActivationFunction())).Union()); } void OperationExporter::visit(luci::CircleSum *node) @@ -1375,7 +1428,7 @@ void OperationExporter::visit(luci::CircleTransposeConv *node) circle::BuiltinOptions_TransposeConvOptions, CreateTransposeConvOptions(_ctx.builder, getOpPadding(node->padding()), node->stride()->w(), node->stride()->h()) - .Union()); + .Union()); } void OperationExporter::visit(luci::CircleUnidirectionalSequenceLSTM *node) @@ -1383,10 +1436,10 @@ void OperationExporter::visit(luci::CircleUnidirectionalSequenceLSTM *node) export_simple(node, circle::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM, circle::BuiltinOptions_UnidirectionalSequenceLSTMOptions, CreateUnidirectionalSequenceLSTMOptions( - _ctx.builder, to_circle_actfunc(node->fusedActivationFunction()), - node->cell_clip(), node->proj_clip(), node->time_major(), - node->asymmetric_quantize_inputs()) - .Union()); + _ctx.builder, to_circle_actfunc(node->fusedActivationFunction()), + node->cell_clip(), node->proj_clip(), node->time_major(), + node->asymmetric_quantize_inputs()) + .Union()); } void OperationExporter::visit(luci::CircleUnique *node) { export_node(_ctx, node); } @@ -1413,14 +1466,14 @@ void OperationExporter::visit(luci::CircleBCQFullyConnected *node) circle::BuiltinOptions_BCQFullyConnectedOptions, CreateBCQFullyConnectedOptions(_ctx.builder, node->weights_hidden_size(), to_circle_actfunc(node->fusedActivationFunction())) - .Union()); + .Union()); } void OperationExporter::visit(luci::CircleBCQGather *node) { export_simple( - node, circle::BuiltinOperator_BCQ_GATHER, circle::BuiltinOptions_BCQGatherOptions, - CreateBCQGatherOptions(_ctx.builder, node->input_hidden_size(), node->axis()).Union()); + node, circle::BuiltinOperator_BCQ_GATHER, circle::BuiltinOptions_BCQGatherOptions, + CreateBCQGatherOptions(_ctx.builder, node->input_hidden_size(), node->axis()).Union()); } void OperationExporter::visit(luci::CircleInstanceNorm *node) @@ -1429,7 +1482,7 @@ void OperationExporter::visit(luci::CircleInstanceNorm *node) circle::BuiltinOptions_InstanceNormOptions, CreateInstanceNormOptions(_ctx.builder, node->epsilon(), to_circle_actfunc(node->fusedActivationFunction())) - .Union()); + .Union()); } void exportNode(loco::Node *node, flatbuffers::FlatBufferBuilder &builder, SerializedModelData &md, @@ -1439,7 +1492,19 @@ void exportNode(loco::Node *node, flatbuffers::FlatBufferBuilder &builder, Seria { ExportContext ctx{builder, md, gd}; OperationExporter exporter{ctx}; + + const auto ops_size = gd._operators.size(); + circle_node->accept(&exporter); + if (has_origin(circle_node) && ops_size != gd._operators.size()) + { + const auto node_id = gd._operators.size() - 1; + for (auto source : get_origin(circle_node)->sources()) + { + md._metadata.add_source_table(source->id(), source->name()); + md._metadata.add_op_table(node_id, source->id()); + } + } } else { diff --git a/compiler/luci/export/src/CircleTensorExporter.cpp b/compiler/luci/export/src/CircleTensorExporter.cpp index 9bdfa0079..fefdf4e73 100644 --- a/compiler/luci/export/src/CircleTensorExporter.cpp +++ b/compiler/luci/export/src/CircleTensorExporter.cpp @@ -15,11 +15,9 @@ */ #include "CircleTensorExporter.h" -#include "TypeBridge.h" #include <luci/IR/CircleNodes.h> #include <luci/IR/CircleNodeVisitor.h> -#include <luci/IR/CircleShapeSignature.h> #include <luci/Service/CircleTypeInference.h> #include <luci/Service/CircleShapeInference.h> #include <luci/Log.h> @@ -38,10 +36,10 @@ namespace using namespace luci; -class CircleTensoInfo +class CircleTensorInfo { public: - CircleTensoInfo() = default; + CircleTensorInfo() = default; public: void name(const std::string &name) { _name = name; } @@ -54,9 +52,6 @@ public: const ShapeDescription &shape(void) const { return _shape; } void shape(const ShapeDescription &shape) { _shape = shape; } - const ShapeSignature &shape_signature(void) const { return _shape_signature; } - void shape_signature(const ShapeSignature &ss) { _shape_signature = ss; } - luci::ShapeStatus shape_status(void) const { return _shape_status; } void shape_status(luci::ShapeStatus ss) { _shape_status = ss; } @@ -75,7 +70,6 @@ private: circle::TensorType _dtype{circle::TensorType_FLOAT32}; ShapeDescription _shape{}; - ShapeSignature _shape_signature; luci::ShapeStatus _shape_status{luci::ShapeStatus::UNDEFINED}; luci::CircleConst *_content = nullptr; @@ -83,7 +77,29 @@ private: luci::SparsityParam *_sparsityparam = nullptr; }; -using CircleTensorContext = std::vector<CircleTensoInfo>; +class CircleTensorContext +{ +public: + CircleTensorContext() = default; + +public: + void emplace_back(CircleTensorInfo &ti) + { + assert(_names.find(ti.name()) == _names.end()); + _tis.emplace_back(ti); + _names.insert(ti.name()); + } + size_t size(void) const { return _tis.size(); } + std::vector<CircleTensorInfo>::iterator begin(void) { return _tis.begin(); } + std::vector<CircleTensorInfo>::iterator end(void) { return _tis.end(); } + +public: + bool exist(const std::string &name) const { return _names.find(name) != _names.end(); } + +private: + std::vector<CircleTensorInfo> _tis; + std::set<std::string> _names; +}; struct NoOpDetector final : public luci::CircleNodeMutableVisitor<bool> { @@ -102,17 +118,23 @@ void allocateCircleTensorInfo(CircleNode *node, CircleTensorContext &ctx) auto tensor_index = static_cast<CircleTensorIndex>(ctx.size()); // TODO Use Graph-level metadata for Input & Output - // auto tensor_name = "t_" + std::to_string(tensor_index); std::string tensor_name = node->name(); - if (tensor_name.empty()) - tensor_name = "t_" + std::to_string(tensor_index); + // NOTE tensor_name maybe empty. this assertion will alert when this happens. + // currently we require tensor should have a name. + // TODO if this breaks, fix the cause or permit empty tensor_name. + assert(!tensor_name.empty()); + if (ctx.exist(tensor_name)) + { + // NOTE this should assign unique name for a Tensor. + tensor_name = tensor_name + "_" + std::to_string(tensor_index); + assert(!ctx.exist(tensor_name)); + } INFO(l) << "[luci] Tensor for " << tensor_name << ": " << tensor_index << std::endl; - CircleTensoInfo tensor_info; + CircleTensorInfo tensor_info; tensor_info.name(tensor_name); tensor_info.dtype(to_circle_tensortype(node->dtype())); - tensor_info.shape_signature(node->shape_signature()); if (node->shape_status() == ShapeStatus::VALID) tensor_info.shape(to_shape_description(node)); tensor_info.shape_status(node->shape_status()); @@ -146,19 +168,55 @@ private: } public: + bool visit(luci::CircleBidirectionalSequenceLSTMOut *) final { return true; } + bool visit(luci::CircleCustomOut *) final { return true; } bool visit(luci::CircleIfOut *) final { return true; } + bool visit(luci::CircleNonMaxSuppressionV4Out *) final { return true; } + bool visit(luci::CircleNonMaxSuppressionV5Out *) final { return true; } bool visit(luci::CircleSplitOut *) final { return true; } bool visit(luci::CircleSplitVOut *) final { return true; } bool visit(luci::CircleTopKV2Out *) final { return true; } bool visit(luci::CircleUnpackOut *) final { return true; } + bool visit(luci::CircleUniqueOut *) final { return true; } bool visit(luci::CircleWhileOut *) final { return true; } + bool visit(luci::CircleBidirectionalSequenceLSTM *node) final + { + if (node->merge_outputs()) + { + store_outputs(node, 1); + } + else + { + store_outputs(node, 2); + } + return true; + } + + bool visit(luci::CircleCustom *node) final + { + store_outputs(node, node->numOutputs()); + return true; + } + bool visit(luci::CircleIf *node) final { store_outputs(node, node->output_count()); return true; } + bool visit(luci::CircleNonMaxSuppressionV4 *node) final + { + store_outputs(node, 2); + return true; + } + + bool visit(luci::CircleNonMaxSuppressionV5 *node) final + { + store_outputs(node, 3); + return true; + } + bool visit(luci::CircleSplit *node) final { store_outputs(node, uint32_t(node->num_split())); @@ -183,6 +241,12 @@ public: return true; } + bool visit(luci::CircleUnique *node) final + { + store_outputs(node, 2); + return true; + } + bool visit(luci::CircleWhile *node) final { store_outputs(node, node->output_count()); @@ -237,16 +301,26 @@ flatbuffers::Offset<Vector<int32_t>> encodeShape(FlatBufferBuilder &builder, const ShapeDescription &shape) { assert(shape._rank_known && "unknown number of dimensions is not supported"); - return builder.CreateVector(shape._dims); + + std::vector<int32_t> encoded_shape; + encoded_shape.resize(shape._dims.size()); + for (uint32_t i = 0; i < shape._dims.size(); ++i) + encoded_shape.at(i) = shape._dims.at(i) == -1 ? 1 : shape._dims.at(i); + + return builder.CreateVector(encoded_shape); } flatbuffers::Offset<Vector<int32_t>> encodeShapeSignature(FlatBufferBuilder &builder, - const ShapeSignature &shape_signature) + const ShapeDescription &shape) { - if (shape_signature.rank() == 0) - return 0; + assert(shape._rank_known && "unknown number of dimensions is not supported"); + + // shape_signature is set if and only if at least one of dimensions are unknown. + for (uint32_t i = 0; i < shape._dims.size(); ++i) + if (shape._dims.at(i) == -1) + return builder.CreateVector(shape._dims); - return builder.CreateVector(shape_signature.as_vector()); + return flatbuffers::Offset<Vector<int32_t>>(); } flatbuffers::Offset<circle::Buffer> encodeOpBuffer(FlatBufferBuilder &builder) @@ -343,14 +417,14 @@ encodeSparsityParameters(FlatBufferBuilder &builder, luci::SparsityParam *sparsi // array_segments auto circle_array_segments = to_circle_sparse_index_vector(builder, it.array_segments()); auto circle_array_segments_type = - to_circle_sparse_index_vector_type(it.array_segments().type()); + to_circle_sparse_index_vector_type(it.array_segments().type()); // array_indices auto circle_array_indices = to_circle_sparse_index_vector(builder, it.array_indices()); auto circle_array_indices_type = to_circle_sparse_index_vector_type(it.array_indices().type()); auto dim_metadata = circle::CreateDimensionMetadata( - builder, to_circle_dimensiontype(it.format()), it.dense_size(), circle_array_segments_type, - circle_array_segments, circle_array_indices_type, circle_array_indices); + builder, to_circle_dimensiontype(it.format()), it.dense_size(), circle_array_segments_type, + circle_array_segments, circle_array_indices_type, circle_array_indices); dim_metadata_vec.emplace_back(dim_metadata); } @@ -358,6 +432,18 @@ encodeSparsityParameters(FlatBufferBuilder &builder, luci::SparsityParam *sparsi &sparsityparam->block_map, &dim_metadata_vec); } +template <loco::DataType DT> bool has_same_elements(luci::CircleConst *lhs, luci::CircleConst *rhs) +{ + assert(lhs->dtype() == DT); + assert(rhs->dtype() == DT); + assert(lhs->size<DT>() == rhs->size<DT>()); + + for (uint32_t i = 0; i < lhs->size<DT>(); ++i) + if (lhs->at<DT>(i) != rhs->at<DT>(i)) + return false; + return true; +} + bool has_same_values(luci::CircleConst *lhs, luci::CircleConst *rhs) { if (lhs->dtype() != rhs->dtype()) @@ -373,34 +459,31 @@ bool has_same_values(luci::CircleConst *lhs, luci::CircleConst *rhs) switch (lhs->dtype()) { case loco::DataType::FLOAT32: - for (uint32_t i = 0; i < lhs->size<loco::DataType::FLOAT32>(); ++i) - if (lhs->at<loco::DataType::FLOAT32>(i) != rhs->at<loco::DataType::FLOAT32>(i)) - return false; - break; + return has_same_elements<loco::DataType::FLOAT32>(lhs, rhs); + + case loco::DataType::S8: + return has_same_elements<loco::DataType::S8>(lhs, rhs); + + case loco::DataType::S16: + return has_same_elements<loco::DataType::S16>(lhs, rhs); case loco::DataType::S32: - for (uint32_t i = 0; i < lhs->size<loco::DataType::S32>(); ++i) - if (lhs->at<loco::DataType::S32>(i) != rhs->at<loco::DataType::S32>(i)) - return false; - break; + return has_same_elements<loco::DataType::S32>(lhs, rhs); case loco::DataType::S64: - for (uint32_t i = 0; i < lhs->size<loco::DataType::S64>(); ++i) - if (lhs->at<loco::DataType::S64>(i) != rhs->at<loco::DataType::S64>(i)) - return false; - break; + return has_same_elements<loco::DataType::S64>(lhs, rhs); + + case loco::DataType::U8: + return has_same_elements<loco::DataType::U8>(lhs, rhs); case loco::DataType::BOOL: - for (uint32_t i = 0; i < lhs->size<loco::DataType::BOOL>(); ++i) - if (lhs->at<loco::DataType::BOOL>(i) != rhs->at<loco::DataType::BOOL>(i)) - return false; - break; + return has_same_elements<loco::DataType::BOOL>(lhs, rhs); default: - return false; + break; } - return true; + return false; } uint32_t get_buffer_id(FlatBufferBuilder &builder, SerializedModelData &md, luci::CircleConst *node) @@ -433,26 +516,28 @@ uint32_t get_buffer_id(FlatBufferBuilder &builder, SerializedModelData &md, luci } } -void exportOpDefinedTensor(const CircleTensoInfo &info, FlatBufferBuilder &builder, +void exportOpDefinedTensor(const CircleTensorInfo &info, FlatBufferBuilder &builder, SerializedModelData &md, SerializedGraphData &gd) { // Create and register output tensor shape flatbuffers::Offset<Vector<int32_t>> shape_offset; + flatbuffers::Offset<Vector<int32_t>> shape_signature_offset; if (info.shape_status() == ShapeStatus::VALID) + { shape_offset = encodeShape(builder, info.shape()); + shape_signature_offset = encodeShapeSignature(builder, info.shape()); + } auto quantparam = encodeQuantizationParameters(builder, info.quantparam()); auto sparsityparam = encodeSparsityParameters(builder, info.sparsityparam()); - auto shape_signature_offset = encodeShapeSignature(builder, info.shape_signature()); - auto buffer_id = get_buffer_id(builder, md, info.content()); auto name_offset = builder.CreateString(info.name()); auto tensor_offset = - CreateTensor(builder, shape_offset, info.dtype(), buffer_id, name_offset, quantparam, - /*is_variable*/ false, sparsityparam, shape_signature_offset); + CreateTensor(builder, shape_offset, info.dtype(), buffer_id, name_offset, quantparam, + /*is_variable*/ false, sparsityparam, shape_signature_offset); gd._tensors.push_back(tensor_offset); } diff --git a/compiler/luci/export/src/Optimize.cpp b/compiler/luci/export/src/Optimize.cpp index 036a4a2f9..e59f15204 100644 --- a/compiler/luci/export/src/Optimize.cpp +++ b/compiler/luci/export/src/Optimize.cpp @@ -17,9 +17,8 @@ #include "Optimize.h" #include "ProgressReporter.h" -#include <luci/Pass/ShapeInferencePass.h> -#include <luci/Pass/ShapeSignatureInferencePass.h> -#include <luci/Pass/TypeInferencePass.h> +#include <luci/Pass/CircleShapeInferencePass.h> +#include <luci/Pass/CircleTypeInferencePass.h> #include <logo/Phase.h> @@ -33,9 +32,8 @@ void optimize(loco::Graph *g) logo::Phase phase; { // prepare type and shape before optimization - phase.emplace_back(std::make_unique<TypeInferencePass>()); - phase.emplace_back(std::make_unique<ShapeInferencePass>()); - phase.emplace_back(std::make_unique<ShapeSignatureInferencePass>()); + phase.emplace_back(std::make_unique<luci::CircleShapeInferencePass>()); + phase.emplace_back(std::make_unique<luci::CircleTypeInferencePass>()); // TODO add more optimization passes (with a knob) } diff --git a/compiler/luci/export/src/ProgressReporter.h b/compiler/luci/export/src/ProgressReporter.h index e91f42592..5d55bcd07 100644 --- a/compiler/luci/export/src/ProgressReporter.h +++ b/compiler/luci/export/src/ProgressReporter.h @@ -28,7 +28,7 @@ class ProgressReporter : public logo::PhaseEventListener { public: ProgressReporter(loco::Graph *graph, logo::PhaseStrategy strategy) - : _graph{graph}, _strategy{strategy} + : _graph{graph}, _strategy{strategy} { // DO NOTHING } diff --git a/compiler/luci/export/src/SerializedData.h b/compiler/luci/export/src/SerializedData.h index c41f50edd..df71e5c21 100644 --- a/compiler/luci/export/src/SerializedData.h +++ b/compiler/luci/export/src/SerializedData.h @@ -48,6 +48,37 @@ struct OpCode } }; +class CircleExportMetadata +{ +public: + void add_source_table(uint32_t source_id, std::string origin_name) + { + // Model with multiple subgraph may have different origin_name + // even if source_id is same. However, as we do not consider about + // multiple subgraph in profiling for now, just do not care those cases + // and support them correctly in the future. + _source_table.emplace(source_id, origin_name); + } + + void add_op_table(uint32_t node_id, uint32_t source_id) + { + // Model with multiple subgraph may have duplicated node id. + // For now, as we do not consider about multiple subgraph in profiling, + // just ignore those cases and support them in the future. + if (_op_table.find(node_id) == _op_table.end()) + _op_table.emplace(node_id, std::set<uint32_t>()); + _op_table.at(node_id).emplace(source_id); + } + +public: + const std::vector<uint8_t> encoded_source_table(void); + const std::vector<uint8_t> encoded_op_table(void); + +private: + std::map<uint32_t, std::string> _source_table; + std::map<uint32_t, std::set<uint32_t>> _op_table; +}; + } // namespace luci namespace std @@ -86,6 +117,7 @@ struct SerializedModelData final std::unordered_map<OpCode, uint32_t> _operator_codes; std::vector<flatbuffers::Offset<circle::Buffer>> _buffers; + CircleExportMetadata _metadata; // This is used for removing buffers with same values std::map<luci::CircleConst *, uint32_t> _cached_buffer_id; diff --git a/compiler/luci/export/src/TypeBridge.cpp b/compiler/luci/export/src/TypeBridge.cpp deleted file mode 100644 index 9ccd52376..000000000 --- a/compiler/luci/export/src/TypeBridge.cpp +++ /dev/null @@ -1,105 +0,0 @@ -/* - * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "TypeBridge.h" - -#include "CircleExporterUtils.h" - -#include <luci/IR/CircleNodes.h> -#include <luci/IR/CircleNodeVisitor.h> -#include <luci/Service/CircleTypeInference.h> -#include <luci/Service/CircleShapeInference.h> - -#include <loco/Service/TypeInference.h> -#include <loco/Service/ShapeInference.h> - -namespace -{ - -/** - * @brief CopySelector will return condition of copy shape/type inference to node - */ -struct CopySelector final : public luci::CircleNodeVisitor<bool> -{ - // return false(don't copy) for nodes that provides shape/type from nature - bool visit(const luci::CircleInput *) final { return false; } - bool visit(const luci::CircleConst *) final { return false; } - - // default is copy attributes - bool visit(const luci::CircleNode *) { return true; } -}; - -} // namespace - -namespace luci -{ - -loco::TensorShape node_shape(CircleNode *node) -{ - loco::TensorShape shape; - - shape.rank(node->rank()); - for (uint32_t r = 0; r < node->rank(); ++r) - { - shape.dim(r) = loco::Dimension(node->dim(r).value()); - } - return shape; -} - -loco::DataType node_dtype(CircleNode *node) { return node->dtype(); } - -void copy_shape_dtype(loco::Graph *graph) -{ - /** - * @note We will iterate all the nodes in the graph to include dangle nodes - */ - auto nodes = graph->nodes(); - for (uint32_t n = 0; n < nodes->size(); ++n) - { - auto node = loco::must_cast<luci::CircleNode *>(nodes->at(n)); - - CopySelector cs; - if (node->accept(&cs)) - { - // NOTE not all nodes have infered shape/dtype: multiple outs may not be - // visited when outputs are not used - // TODO fix shape inference traversal - // NOTE when loco supports multiple outputs in nature this issue should be - // resolved also - - if (loco::dtype_known(node)) - { - node->dtype(loco::dtype_get(node)); - } - - if (loco::shape_known(node)) - { - auto shape = loco::shape_get(node).as<loco::TensorShape>(); - node->rank(shape.rank()); - for (uint32_t r = 0; r < shape.rank(); ++r) - { - node->dim(r) = loco::Dimension(shape.dim(r).value()); - } - - // ShapeStatus should be update only when the status was UNDEFINED - if (node->shape_status() == ShapeStatus::UNDEFINED) - node->shape_status(ShapeStatus::VALID); - } - } - } -} - -} // namespace luci diff --git a/compiler/luci/import/CMakeLists.txt b/compiler/luci/import/CMakeLists.txt index 2ae00b837..642751ca6 100644 --- a/compiler/luci/import/CMakeLists.txt +++ b/compiler/luci/import/CMakeLists.txt @@ -6,6 +6,7 @@ add_library(luci_import SHARED ${SOURCES}) target_include_directories(luci_import PRIVATE src) target_include_directories(luci_import PUBLIC include) target_link_libraries(luci_import PUBLIC luci_lang) +target_link_libraries(luci_import PUBLIC luci_profile) target_link_libraries(luci_import PUBLIC mio_circle) target_link_libraries(luci_import PRIVATE luci_env) target_link_libraries(luci_import PRIVATE luci_log) diff --git a/compiler/luci/import/include/luci/Import/CircleReader.h b/compiler/luci/import/include/luci/Import/CircleReader.h index 8e210dd77..b9697fb86 100644 --- a/compiler/luci/import/include/luci/Import/CircleReader.h +++ b/compiler/luci/import/include/luci/Import/CircleReader.h @@ -23,7 +23,6 @@ #include <luci/IR/AttrPadding.h> #include <luci/IR/CircleNode.h> #include <luci/IR/CircleQuantParam.h> -#include <luci/IR/CircleShapeSignature.h> #include <luci/IR/SparsityParam.h> #include <loco.h> @@ -64,6 +63,7 @@ private: using CircleTensors_t = std::vector<std::unique_ptr<circle::TensorT>>; using CircleOperators_t = std::vector<std::unique_ptr<circle::OperatorT>>; using CircleOperatorCodes_t = std::vector<std::unique_ptr<circle::OperatorCodeT>>; + using CircleMetadata_t = std::vector<std::unique_ptr<circle::MetadataT>>; using CircleSubGraphsPtr_t = flatbuffers::Vector<flatbuffers::Offset<circle::SubGraph>>; using CircleTensorsPtr_t = flatbuffers::Vector<flatbuffers::Offset<circle::Tensor>>; @@ -79,6 +79,8 @@ public: const std::vector<int32_t> &inputs() const { return _current_subgraph->inputs; } const std::vector<int32_t> &outputs() const { return _current_subgraph->outputs; } const std::string &name() const { return _current_subgraph->name; } + const circle::DataFormat &data_format() const { return _current_subgraph->data_format; } + const CircleMetadata_t &metadata() const { return _model->metadata; } const CircleTensorsPtr_t *tensors_ptr() const { return _tensors_ptr; } diff --git a/compiler/luci/import/include/luci/Import/GraphBuilder.h b/compiler/luci/import/include/luci/Import/GraphBuilder.h index 548264dac..0db612652 100644 --- a/compiler/luci/import/include/luci/Import/GraphBuilder.h +++ b/compiler/luci/import/include/luci/Import/GraphBuilder.h @@ -33,7 +33,13 @@ class GraphBuilder : public GraphBuilderBase public: virtual ~GraphBuilder() = default; - void build(const circle::OperatorT &op, GraphBuilderContext *context) const final; + // common validate method to check number of inputs and single output + bool validate(const ValidateArgs &args, size_t input_cnt) const + { + return (args.op.inputs.size() == input_cnt && args.op.outputs.size() == 1); + } + + CircleNode *build(const circle::OperatorT &op, GraphBuilderContext *context) const final; private: virtual CircleNode *build_node(const circle::OperatorT &op, diff --git a/compiler/luci/import/include/luci/Import/GraphBuilderBase.h b/compiler/luci/import/include/luci/Import/GraphBuilderBase.h index a0cd008e0..ddd4445cd 100644 --- a/compiler/luci/import/include/luci/Import/GraphBuilderBase.h +++ b/compiler/luci/import/include/luci/Import/GraphBuilderBase.h @@ -19,6 +19,8 @@ #include "GraphBuilderContext.h" +#include <luci/IR/CircleNode.h> + #include <mio/circle/schema_generated.h> namespace luci @@ -38,7 +40,7 @@ struct GraphBuilderBase }; virtual bool validate(const ValidateArgs &) const = 0; - virtual void build(const circle::OperatorT &op, GraphBuilderContext *context) const = 0; + virtual CircleNode *build(const circle::OperatorT &op, GraphBuilderContext *context) const = 0; virtual ~GraphBuilderBase() = default; }; diff --git a/compiler/luci/import/include/luci/Import/GraphBuilderContext.h b/compiler/luci/import/include/luci/Import/GraphBuilderContext.h index 72e237abc..1673df43d 100644 --- a/compiler/luci/import/include/luci/Import/GraphBuilderContext.h +++ b/compiler/luci/import/include/luci/Import/GraphBuilderContext.h @@ -71,7 +71,7 @@ class GraphBuilderContext public: GraphBuilderContext(loco::Graph *g, CircleReader *reader, IndexNodeFinder *nodefinder, IndexTensorOutputs *tensoroutputs) - : _g(g), _reader(reader), _indexnodefinder(nodefinder), _indextensoroutputs(tensoroutputs) + : _g(g), _reader(reader), _indexnodefinder(nodefinder), _indextensoroutputs(tensoroutputs) { // DO NOTHING } diff --git a/compiler/luci/import/include/luci/Import/GraphBuilderMultiOutput.h b/compiler/luci/import/include/luci/Import/GraphBuilderMultiOutput.h new file mode 100644 index 000000000..6e8791b62 --- /dev/null +++ b/compiler/luci/import/include/luci/Import/GraphBuilderMultiOutput.h @@ -0,0 +1,67 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_IMPORT_GRAPH_BUILDER_MULTI_OUTPUT_H__ +#define __LUCI_IMPORT_GRAPH_BUILDER_MULTI_OUTPUT_H__ + +#include "GraphBuilderContext.h" +#include "GraphBuilderBase.h" + +#include <mio/circle/schema_generated.h> + +namespace luci +{ + +/** + * @brief Base of general multiple outputs graph builder(e.g., CircleIfGraphBuilder) + */ +class GraphBuilderMultiOutput : public GraphBuilderBase +{ +public: + virtual ~GraphBuilderMultiOutput() = default; + + CircleNode *build(const circle::OperatorT &op, GraphBuilderContext *context) const final; + +protected: + struct BuildNodeArgs + { + BuildNodeArgs(const circle::OperatorT &o, GraphBuilderContext *c, + const std::vector<CircleNode *> &i) + : op(o), context(c), input_nodes(i) + { + } + + const circle::OperatorT &op; + GraphBuilderContext *context; + const std::vector<CircleNode *> &input_nodes; + }; + + struct BuildOutArgs + { + BuildOutArgs(CircleNode *nd, uint32_t n) : node(nd), index(n) {} + + CircleNode *node; + uint32_t index; + }; + +private: + virtual CircleNode *build_node(const BuildNodeArgs &) const = 0; + virtual CircleNode *build_out(const BuildOutArgs &) const = 0; +}; + +} // namespace luci + +#endif // __LUCI_IMPORT_GRAPH_BUILDER_MULTI_OUTPUT_H__ diff --git a/compiler/luci/import/include/luci/Import/Nodes.h b/compiler/luci/import/include/luci/Import/Nodes.h index 28741064e..b084c7dbc 100644 --- a/compiler/luci/import/include/luci/Import/Nodes.h +++ b/compiler/luci/import/include/luci/Import/Nodes.h @@ -27,6 +27,7 @@ #include "Nodes/CircleBatchToSpaceND.h" #include "Nodes/CircleBCQFullyConnected.h" #include "Nodes/CircleBCQGather.h" +#include "Nodes/CircleBidirectionalSequenceLSTM.h" #include "Nodes/CircleCast.h" #include "Nodes/CircleCeil.h" #include "Nodes/CircleConcatenation.h" @@ -42,6 +43,7 @@ #include "Nodes/CircleEqual.h" #include "Nodes/CircleExp.h" #include "Nodes/CircleExpandDims.h" +#include "Nodes/CircleFakeQuant.h" #include "Nodes/CircleFill.h" #include "Nodes/CircleFloor.h" #include "Nodes/CircleFloorDiv.h" diff --git a/compiler/luci/import/include/luci/Import/Nodes/CircleBidirectionalSequenceLSTM.h b/compiler/luci/import/include/luci/Import/Nodes/CircleBidirectionalSequenceLSTM.h new file mode 100644 index 000000000..491517268 --- /dev/null +++ b/compiler/luci/import/include/luci/Import/Nodes/CircleBidirectionalSequenceLSTM.h @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_IMPORT_OP_CIRCLE_BIDIRECTIONALSEQUENCE_LSTM_H__ +#define __LUCI_IMPORT_OP_CIRCLE_BIDIRECTIONALSEQUENCE_LSTM_H__ + +#include "luci/Import/GraphBuilderMultiOutput.h" + +namespace luci +{ + +class CircleBidirectionalSequenceLSTMGraphBuilder : public GraphBuilderMultiOutput +{ +public: + bool validate(const ValidateArgs &args) const final; + +private: + CircleNode *build_node(const BuildNodeArgs &) const final; + CircleNode *build_out(const BuildOutArgs &) const final; +}; + +} // namespace luci + +#endif // __LUCI_IMPORT_OP_CIRCLE_BIDIRECTIONALSEQUENCE_LSTM_H__ diff --git a/compiler/luci/import/include/luci/Import/Nodes/CircleCustom.h b/compiler/luci/import/include/luci/Import/Nodes/CircleCustom.h index 65745be4b..f0d7e303d 100644 --- a/compiler/luci/import/include/luci/Import/Nodes/CircleCustom.h +++ b/compiler/luci/import/include/luci/Import/Nodes/CircleCustom.h @@ -17,17 +17,19 @@ #ifndef __LUCI_IMPORT_OP_CIRCLE_CUSTOM_H__ #define __LUCI_IMPORT_OP_CIRCLE_CUSTOM_H__ -#include "luci/Import/GraphBuilder.h" +#include "luci/Import/GraphBuilderMultiOutput.h" namespace luci { -class CircleCustomGraphBuilder : public GraphBuilderBase +class CircleCustomGraphBuilder : public GraphBuilderMultiOutput { public: bool validate(const ValidateArgs &args) const final; - void build(const circle::OperatorT &op, GraphBuilderContext *context) const final; +private: + CircleNode *build_node(const BuildNodeArgs &) const final; + CircleNode *build_out(const BuildOutArgs &) const final; }; } // namespace luci diff --git a/compiler/luci/import/include/luci/Import/Nodes/CircleFakeQuant.h b/compiler/luci/import/include/luci/Import/Nodes/CircleFakeQuant.h new file mode 100644 index 000000000..9d9f7b07b --- /dev/null +++ b/compiler/luci/import/include/luci/Import/Nodes/CircleFakeQuant.h @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_IMPORT_OP_CIRCLE_FAKE_QUANT_H__ +#define __LUCI_IMPORT_OP_CIRCLE_FAKE_QUANT_H__ + +#include "luci/Import/GraphBuilder.h" + +namespace luci +{ + +class CircleFakeQuantGraphBuilder : public GraphBuilder +{ +public: + bool validate(const ValidateArgs &args) const final; + +private: + CircleNode *build_node(const circle::OperatorT &op, const std::vector<CircleNode *> &inputs, + loco::Graph *graph) const final; +}; + +} // namespace luci + +#endif // __LUCI_IMPORT_OP_CIRCLE_FAKE_QUANT_H__ diff --git a/compiler/luci/import/include/luci/Import/Nodes/CircleIf.h b/compiler/luci/import/include/luci/Import/Nodes/CircleIf.h index 8faf09cae..94052f5be 100644 --- a/compiler/luci/import/include/luci/Import/Nodes/CircleIf.h +++ b/compiler/luci/import/include/luci/Import/Nodes/CircleIf.h @@ -17,17 +17,19 @@ #ifndef __LUCI_IMPORT_OP_CIRCLE_IF_H__ #define __LUCI_IMPORT_OP_CIRCLE_IF_H__ -#include "luci/Import/GraphBuilderBase.h" +#include "luci/Import/GraphBuilderMultiOutput.h" namespace luci { -class CircleIfGraphBuilder : public GraphBuilderBase +class CircleIfGraphBuilder : public GraphBuilderMultiOutput { public: bool validate(const ValidateArgs &args) const final; - void build(const circle::OperatorT &op, GraphBuilderContext *context) const final; +private: + CircleNode *build_node(const BuildNodeArgs &) const final; + CircleNode *build_out(const BuildOutArgs &) const final; }; } // namespace luci diff --git a/compiler/luci/import/include/luci/Import/Nodes/CircleNonMaxSuppressionV4.h b/compiler/luci/import/include/luci/Import/Nodes/CircleNonMaxSuppressionV4.h index f193aae35..4e8388b3e 100644 --- a/compiler/luci/import/include/luci/Import/Nodes/CircleNonMaxSuppressionV4.h +++ b/compiler/luci/import/include/luci/Import/Nodes/CircleNonMaxSuppressionV4.h @@ -17,17 +17,19 @@ #ifndef __LUCI_IMPORT_OP_CIRCLE_NON_MAX_SUPPRESSION_V4_H__ #define __LUCI_IMPORT_OP_CIRCLE_NON_MAX_SUPPRESSION_V4_H__ -#include "luci/Import/GraphBuilderBase.h" +#include "luci/Import/GraphBuilderMultiOutput.h" namespace luci { -class CircleNonMaxSuppressionV4GraphBuilder : public GraphBuilderBase +class CircleNonMaxSuppressionV4GraphBuilder : public GraphBuilderMultiOutput { public: bool validate(const ValidateArgs &args) const final; - void build(const circle::OperatorT &op, GraphBuilderContext *context) const final; +private: + CircleNode *build_node(const BuildNodeArgs &) const final; + CircleNode *build_out(const BuildOutArgs &) const final; }; } // namespace luci diff --git a/compiler/luci/import/include/luci/Import/Nodes/CircleNonMaxSuppressionV5.h b/compiler/luci/import/include/luci/Import/Nodes/CircleNonMaxSuppressionV5.h index 62be0758e..4120a30eb 100644 --- a/compiler/luci/import/include/luci/Import/Nodes/CircleNonMaxSuppressionV5.h +++ b/compiler/luci/import/include/luci/Import/Nodes/CircleNonMaxSuppressionV5.h @@ -17,17 +17,19 @@ #ifndef __LUCI_IMPORT_OP_CIRCLE_NON_MAX_SUPPRESSION_V5_H__ #define __LUCI_IMPORT_OP_CIRCLE_NON_MAX_SUPPRESSION_V5_H__ -#include "luci/Import/GraphBuilderBase.h" +#include "luci/Import/GraphBuilderMultiOutput.h" namespace luci { -class CircleNonMaxSuppressionV5GraphBuilder : public GraphBuilderBase +class CircleNonMaxSuppressionV5GraphBuilder : public GraphBuilderMultiOutput { public: bool validate(const ValidateArgs &args) const final; - void build(const circle::OperatorT &op, GraphBuilderContext *context) const final; +private: + CircleNode *build_node(const BuildNodeArgs &) const final; + CircleNode *build_out(const BuildOutArgs &) const final; }; } // namespace luci diff --git a/compiler/luci/import/include/luci/Import/Nodes/CircleSplit.h b/compiler/luci/import/include/luci/Import/Nodes/CircleSplit.h index 3395e40fd..5b45c9a9e 100644 --- a/compiler/luci/import/include/luci/Import/Nodes/CircleSplit.h +++ b/compiler/luci/import/include/luci/Import/Nodes/CircleSplit.h @@ -17,17 +17,19 @@ #ifndef __LUCI_IMPORT_OP_CIRCLE_SPLIT_H__ #define __LUCI_IMPORT_OP_CIRCLE_SPLIT_H__ -#include "luci/Import/GraphBuilderBase.h" +#include "luci/Import/GraphBuilderMultiOutput.h" namespace luci { -class CircleSplitGraphBuilder : public GraphBuilderBase +class CircleSplitGraphBuilder : public GraphBuilderMultiOutput { public: bool validate(const ValidateArgs &args) const final; - void build(const circle::OperatorT &op, GraphBuilderContext *context) const final; +private: + CircleNode *build_node(const BuildNodeArgs &) const final; + CircleNode *build_out(const BuildOutArgs &) const final; }; } // namespace luci diff --git a/compiler/luci/import/include/luci/Import/Nodes/CircleSplitV.h b/compiler/luci/import/include/luci/Import/Nodes/CircleSplitV.h index 3e53df362..de712f90c 100644 --- a/compiler/luci/import/include/luci/Import/Nodes/CircleSplitV.h +++ b/compiler/luci/import/include/luci/Import/Nodes/CircleSplitV.h @@ -17,17 +17,19 @@ #ifndef __LUCI_IMPORT_OP_CIRCLE_SPLIT_V_H__ #define __LUCI_IMPORT_OP_CIRCLE_SPLIT_V_H__ -#include "luci/Import/GraphBuilderBase.h" +#include "luci/Import/GraphBuilderMultiOutput.h" namespace luci { -class CircleSplitVGraphBuilder : public GraphBuilderBase +class CircleSplitVGraphBuilder : public GraphBuilderMultiOutput { public: bool validate(const ValidateArgs &args) const final; - void build(const circle::OperatorT &op, GraphBuilderContext *context) const final; +private: + CircleNode *build_node(const BuildNodeArgs &) const final; + CircleNode *build_out(const BuildOutArgs &) const final; }; } // namespace luci diff --git a/compiler/luci/import/include/luci/Import/Nodes/CircleTopKV2.h b/compiler/luci/import/include/luci/Import/Nodes/CircleTopKV2.h index 8ec3f3311..b4ad97130 100644 --- a/compiler/luci/import/include/luci/Import/Nodes/CircleTopKV2.h +++ b/compiler/luci/import/include/luci/Import/Nodes/CircleTopKV2.h @@ -17,17 +17,19 @@ #ifndef __LUCI_IMPORT_OP_CIRCLE_TOPK_V2_H__ #define __LUCI_IMPORT_OP_CIRCLE_TOPK_V2_H__ -#include "luci/Import/GraphBuilderBase.h" +#include "luci/Import/GraphBuilderMultiOutput.h" namespace luci { -class CircleTopKV2GraphBuilder : public GraphBuilderBase +class CircleTopKV2GraphBuilder : public GraphBuilderMultiOutput { public: bool validate(const ValidateArgs &args) const final; - void build(const circle::OperatorT &op, GraphBuilderContext *context) const final; +private: + CircleNode *build_node(const BuildNodeArgs &) const final; + CircleNode *build_out(const BuildOutArgs &) const final; }; } // namespace luci diff --git a/compiler/luci/import/include/luci/Import/Nodes/CircleUnique.h b/compiler/luci/import/include/luci/Import/Nodes/CircleUnique.h index ed5b5035d..40e75ec73 100644 --- a/compiler/luci/import/include/luci/Import/Nodes/CircleUnique.h +++ b/compiler/luci/import/include/luci/Import/Nodes/CircleUnique.h @@ -17,17 +17,19 @@ #ifndef __LUCI_IMPORT_OP_CIRCLE_UNIQUE_H__ #define __LUCI_IMPORT_OP_CIRCLE_UNIQUE_H__ -#include "luci/Import/GraphBuilderBase.h" +#include "luci/Import/GraphBuilderMultiOutput.h" namespace luci { -class CircleUniqueGraphBuilder : public GraphBuilderBase +class CircleUniqueGraphBuilder : public GraphBuilderMultiOutput { public: bool validate(const ValidateArgs &args) const final; - void build(const circle::OperatorT &op, GraphBuilderContext *context) const final; +private: + CircleNode *build_node(const BuildNodeArgs &) const final; + CircleNode *build_out(const BuildOutArgs &) const final; }; } // namespace luci diff --git a/compiler/luci/import/include/luci/Import/Nodes/CircleUnpack.h b/compiler/luci/import/include/luci/Import/Nodes/CircleUnpack.h index f1a21de22..0b623655f 100644 --- a/compiler/luci/import/include/luci/Import/Nodes/CircleUnpack.h +++ b/compiler/luci/import/include/luci/Import/Nodes/CircleUnpack.h @@ -17,17 +17,19 @@ #ifndef __LUCI_IMPORT_OP_CIRCLE_UNPACK_H__ #define __LUCI_IMPORT_OP_CIRCLE_UNPACK_H__ -#include "luci/Import/GraphBuilderBase.h" +#include "luci/Import/GraphBuilderMultiOutput.h" namespace luci { -class CircleUnpackGraphBuilder : public GraphBuilderBase +class CircleUnpackGraphBuilder : public GraphBuilderMultiOutput { public: bool validate(const ValidateArgs &args) const final; - void build(const circle::OperatorT &op, GraphBuilderContext *context) const final; +private: + CircleNode *build_node(const BuildNodeArgs &) const final; + CircleNode *build_out(const BuildOutArgs &) const final; }; } // namespace luci diff --git a/compiler/luci/import/include/luci/Import/Nodes/CircleWhile.h b/compiler/luci/import/include/luci/Import/Nodes/CircleWhile.h index 68c56b3c6..69d23f823 100644 --- a/compiler/luci/import/include/luci/Import/Nodes/CircleWhile.h +++ b/compiler/luci/import/include/luci/Import/Nodes/CircleWhile.h @@ -27,7 +27,7 @@ class CircleWhileGraphBuilder : public GraphBuilderBase public: bool validate(const ValidateArgs &args) const final; - void build(const circle::OperatorT &op, GraphBuilderContext *context) const final; + CircleNode *build(const circle::OperatorT &op, GraphBuilderContext *context) const final; }; } // namespace luci diff --git a/compiler/luci/import/src/CircleImportMetadata.cpp b/compiler/luci/import/src/CircleImportMetadata.cpp new file mode 100644 index 000000000..f68f3301a --- /dev/null +++ b/compiler/luci/import/src/CircleImportMetadata.cpp @@ -0,0 +1,185 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleImportMetadata.h" + +#include <vector> + +namespace +{ + +uint32_t read_u32(const std::vector<uint8_t> &buffer, uint32_t idx) +{ + uint32_t val = 0; + val += (buffer.at(idx + 0) << 0 * 8); + val += (buffer.at(idx + 1) << 1 * 8); + val += (buffer.at(idx + 2) << 2 * 8); + val += (buffer.at(idx + 3) << 3 * 8); + return val; +} + +} // namespace + +namespace +{ + +// 'source_table' is decoded to std::map<uint32_t, std::string> format. +const std::map<uint32_t, std::string> +decoded_source_table(const std::vector<uint8_t> &source_table_data) +{ + std::map<uint32_t, std::string> source_id_name_map; + uint32_t idx = 0; + + if (source_table_data.size() < 4) + throw std::runtime_error("Source table decode error : invalid entry number"); + + uint32_t entry_number = read_u32(source_table_data, idx); + idx += sizeof(uint32_t); + + while (idx < source_table_data.size()) + { + if (idx + 2 * sizeof(uint32_t) > source_table_data.size()) + throw std::runtime_error("Source table decode error : invalid entry item"); + + uint32_t id = read_u32(source_table_data, idx); + idx += sizeof(uint32_t); + + uint32_t length = read_u32(source_table_data, idx); + idx += sizeof(uint32_t); + + if (idx + sizeof(char) * length > source_table_data.size()) + throw std::runtime_error("Source table decode error : invalid entry data"); + + // The last character of name is '\0'. + // However, as std::string do not use '\0' for finding the end of string, + // we ignore the character and do not include it in the string. + std::string origin_name; + for (uint32_t j = 0; j < length - 1; ++j) + origin_name += source_table_data.at(idx + j); + assert(source_table_data.at(idx + length - 1) == '\0'); + idx += sizeof(char) * length; + + if (source_id_name_map.insert({id, origin_name}).second == false) + throw std::runtime_error("Source table decode error : duplicated origin ID"); + } + + if (idx != source_table_data.size()) + throw std::runtime_error("Source table decode error : data size invalid"); + + if (source_id_name_map.size() != entry_number) + throw std::runtime_error("Source table decode error : result size mismatch"); + + return source_id_name_map; +} + +// 'op_table' is decoded to std::map<uint32_t, std::set<uint32_t>> format. +const std::map<uint32_t, std::set<uint32_t>> +decoded_op_table(const std::vector<uint8_t> &op_table_data) +{ + std::map<uint32_t, std::set<uint32_t>> node_source_ids_map; + uint32_t idx = 0; + + if (op_table_data.size() < 4) + throw std::runtime_error("Op table decode error : invalid entry number"); + + uint32_t entry_number = read_u32(op_table_data, idx); + idx += sizeof(uint32_t); + + while (idx < op_table_data.size()) + { + if (idx + 2 * sizeof(uint32_t) > op_table_data.size()) + throw std::runtime_error("Op table decode error : invalid entry item"); + + uint32_t id = read_u32(op_table_data, idx); + idx += sizeof(uint32_t); + + uint32_t node_num = read_u32(op_table_data, idx); + idx += sizeof(uint32_t); + + if (idx + sizeof(uint32_t) * node_num > op_table_data.size()) + throw std::runtime_error("Source table decode error : invalid entry data"); + + std::set<uint32_t> source_ids; + for (uint32_t j = 0; j < node_num; ++j) + { + uint32_t origin = read_u32(op_table_data, idx); + idx += sizeof(uint32_t); + + source_ids.insert(origin); + } + + if (node_source_ids_map.insert({id, source_ids}).second == false) + throw std::runtime_error("Op table decode error : duplicated origin ID"); + } + + if (idx != op_table_data.size()) + throw std::runtime_error("Op table decode error : data size invalid"); + + if (node_source_ids_map.size() != entry_number) + throw std::runtime_error("Op table decode error : entry number invalid"); + + return node_source_ids_map; +} + +} // namespace + +namespace luci +{ + +CircleImportMetadata::CircleImportMetadata(const luci::CircleReader &reader) +{ + const auto &metadata = reader.metadata(); + for (uint32_t i = 0; i < metadata.size(); ++i) + { + const circle::MetadataT &meta = *metadata[i]; + + assert(meta.buffer < reader.buffers().size()); + const std::vector<uint8_t> &buffer = reader.buffers()[meta.buffer]->data; + + if (meta.name.compare("ONE_op_table") == 0) + _op_table = decoded_op_table(buffer); + else if (meta.name.compare("ONE_source_table") == 0) + _source_table = decoded_source_table(buffer); + } +} + +const OriginTable CircleImportMetadata::origin_table(void) +{ + OriginTable origin_table; + + if (_op_table.size() > 0 && _source_table.size() > 0) + { + for (auto &kv : _op_table) + { + const auto node_id = kv.first; + const auto &source_ids = kv.second; + + std::vector<std::shared_ptr<CircleNodeOrigin>> origins; + for (auto source_id : source_ids) + { + const auto source_name = _source_table.at(source_id); + origins.push_back(single_origin(source_id, source_name)); + } + + auto origin = composite_origin(origins); + origin_table.emplace(node_id, origin); + } + } + + return origin_table; +} + +} // namespace luci diff --git a/compiler/luci/import/src/CircleImportMetadata.h b/compiler/luci/import/src/CircleImportMetadata.h new file mode 100644 index 000000000..80176db94 --- /dev/null +++ b/compiler/luci/import/src/CircleImportMetadata.h @@ -0,0 +1,56 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_CIRCLE_IMPORT_METADATA_H__ +#define __LUCI_CIRCLE_IMPORT_METADATA_H__ + +#include "luci/Import/CircleReader.h" + +#include <luci/Profile/CircleNodeOrigin.h> + +#include <map> +#include <set> +#include <string> + +namespace luci +{ + +using OriginTable = std::map<uint32_t, std::shared_ptr<CircleNodeOrigin>>; + +class CircleImportMetadata +{ +public: + CircleImportMetadata() = delete; + + CircleImportMetadata(const luci::CircleReader &reader); + +public: + /** + * @brief Create origin table using _source_table and _op_table in CircleImportMetadata + * @note For creating origin table, both _op_table and _source_table should exist. + * If one of them does not exist, empty table is returned. + */ + const OriginTable origin_table(void); + +private: + // Decoded metadata is stored + std::map<uint32_t, std::string> _source_table; + std::map<uint32_t, std::set<uint32_t>> _op_table; +}; + +} // namespace luci + +#endif // __LUCI_CIRCLE_IMPORT_METADATA_H__ diff --git a/compiler/luci/import/src/CircleReader.cpp b/compiler/luci/import/src/CircleReader.cpp index b33c920b1..861c1bbe3 100644 --- a/compiler/luci/import/src/CircleReader.cpp +++ b/compiler/luci/import/src/CircleReader.cpp @@ -190,19 +190,19 @@ luci_sparse_index_vector(const circle::SparseIndexVectorUnion &sparse_index_vect case circle::SparseIndexVector_Int32Vector: { const auto const_vec_ptr = - static_cast<const void *>(&(sparse_index_vector.AsInt32Vector()->values)); + static_cast<const void *>(&(sparse_index_vector.AsInt32Vector()->values)); return SparseIndexVector{SparseIndexVectorType::I32, const_vec_ptr}; } case circle::SparseIndexVector_Uint16Vector: { const auto const_vec_ptr = - static_cast<const void *>(&(sparse_index_vector.AsUint16Vector()->values)); + static_cast<const void *>(&(sparse_index_vector.AsUint16Vector()->values)); return SparseIndexVector{SparseIndexVectorType::U16, const_vec_ptr}; } case circle::SparseIndexVector_Uint8Vector: { const auto const_vec_ptr = - static_cast<const void *>(&(sparse_index_vector.AsUint8Vector()->values)); + static_cast<const void *>(&(sparse_index_vector.AsUint8Vector()->values)); return SparseIndexVector{SparseIndexVectorType::U8, const_vec_ptr}; } default: @@ -262,15 +262,19 @@ void copy_tensor_attributes(const circle::TensorT &tensor, CircleNode *node) node->name(tensor_name(tensor)); node->dtype(luci_datatype(tensor.type)); + assert(tensor.shape_signature.size() == 0 || + tensor.shape_signature.size() == tensor.shape.size()); + std::vector<int32_t> dims = tensor.shape; // in NHWC node->rank(dims.size()); for (uint32_t r = 0; r < dims.size(); ++r) { - node->dim(r) = loco::Dimension(dims[r]); + if (tensor.shape_signature.size() > 0 && tensor.shape_signature.at(r) == -1) + node->dim(r).unset(); + else + node->dim(r).set(dims[r]); } - node->shape_signature(tensor.shape_signature); - const auto *quantization = tensor.quantization.get(); if (quantization != nullptr) { diff --git a/compiler/luci/import/src/GraphBuilder.cpp b/compiler/luci/import/src/GraphBuilder.cpp index 80a9f986a..356501c2f 100644 --- a/compiler/luci/import/src/GraphBuilder.cpp +++ b/compiler/luci/import/src/GraphBuilder.cpp @@ -21,7 +21,7 @@ namespace luci { -void GraphBuilder::build(const circle::OperatorT &op, GraphBuilderContext *context) const +CircleNode *GraphBuilder::build(const circle::OperatorT &op, GraphBuilderContext *context) const { LOGGER(l); @@ -47,7 +47,11 @@ void GraphBuilder::build(const circle::OperatorT &op, GraphBuilderContext *conte else { // If there is no tensor, insert CircleOutputExclude. - input_nodes.push_back(context->graph()->nodes()->create<luci::CircleOutputExclude>()); + auto *node = context->graph()->nodes()->create<luci::CircleOutputExclude>(); + // CircleOutputExclude doesn't need a type, but since all nodes must have a type, + // a dummy type is inserted. + node->dtype(loco::DataType::FLOAT32); + input_nodes.push_back(node); } } @@ -73,6 +77,8 @@ void GraphBuilder::build(const circle::OperatorT &op, GraphBuilderContext *conte { context->nodefinder()->enroll(outputs[0], node); } + + return node; } } // namespace luci diff --git a/compiler/luci/import/src/GraphBuilderMultiOutput.cpp b/compiler/luci/import/src/GraphBuilderMultiOutput.cpp new file mode 100644 index 000000000..9b42e997e --- /dev/null +++ b/compiler/luci/import/src/GraphBuilderMultiOutput.cpp @@ -0,0 +1,97 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Import/GraphBuilderMultiOutput.h" + +#include <luci/Log.h> + +namespace luci +{ + +CircleNode *GraphBuilderMultiOutput::build(const circle::OperatorT &op, + GraphBuilderContext *context) const +{ + LOGGER(l); + + assert(context != nullptr); + + const std::vector<int32_t> &inputs = op.inputs; + const std::vector<int32_t> &outputs = op.outputs; + const auto &tensors = context->reader()->tensors(); + const auto &opcodes = context->reader()->opcodes(); + auto tensors_ptr = context->reader()->tensors_ptr(); + assert(tensors_ptr != nullptr); + + std::vector<CircleNode *> input_nodes; + for (const int32_t input_tensor_index : inputs) + { + if (input_tensor_index >= 0) + { + auto input = context->nodefinder()->node(input_tensor_index); + if (input == nullptr) + INFO(l) << "[luci] Warning: input node is null " << input_tensor_index << std::endl; + input_nodes.push_back(input); + } + else + { + // If there is no tensor, insert CircleOutputExclude. + auto *node = context->graph()->nodes()->create<luci::CircleOutputExclude>(); + // CircleOutputExclude doesn't need a type, but since all nodes must have a type, + // a dummy type is inserted. + node->dtype(loco::DataType::FLOAT32); + input_nodes.push_back(node); + } + } + + BuildNodeArgs bna(op, context, input_nodes); + auto *node = build_node(bna); + + uint32_t output_count = outputs.size(); + assert(output_count > 0); + { + // Let's use attributes from output 0 for this node + const circle::TensorT &output_tensor = *tensors[outputs[0]]; + node->name(tensor_name(output_tensor)); + node->dtype(luci_datatype(output_tensor.type)); + + // mark operator version + node->op_version(opcodes[op.opcode_index].get()->version); + + // NOTE We don't set quantization for multiple output nodes but to virtual outputs + } + + // Create virtual outputs of Virtual Output node(s) + for (uint32_t n = 0; n < output_count; ++n) + { + const circle::TensorT &output_tensor = *tensors[outputs[n]]; + + BuildOutArgs boa(node, n); + auto *nodeout = build_out(boa); + + copy_tensor_attributes(output_tensor, nodeout); + // mark shape_status + if (tensors_ptr->Get(outputs[n])->shape() == nullptr) + nodeout->shape_status(ShapeStatus::NOSHAPE); + else + nodeout->shape_status(ShapeStatus::VALID); + + context->nodefinder()->enroll(outputs[n], nodeout); + } + + return node; +} + +} // namespace luci diff --git a/compiler/luci/import/src/GraphBuilderRegistry.cpp b/compiler/luci/import/src/GraphBuilderRegistry.cpp index d598d30f4..7f98aab78 100644 --- a/compiler/luci/import/src/GraphBuilderRegistry.cpp +++ b/compiler/luci/import/src/GraphBuilderRegistry.cpp @@ -37,6 +37,7 @@ GraphBuilderRegistry::GraphBuilderRegistry() CIRCLE_NODE(BATCH_TO_SPACE_ND, CircleBatchToSpaceNDGraphBuilder); // 37 CIRCLE_NODE(BCQ_FULLY_CONNECTED, CircleBCQFullyConnectedGraphBuilder); // 253 CIRCLE_NODE(BCQ_GATHER, CircleBCQGatherGraphBuilder); // 252 + CIRCLE_NODE(BIDIRECTIONAL_SEQUENCE_LSTM, CircleBidirectionalSequenceLSTMGraphBuilder); // 52 CIRCLE_NODE(CAST, CircleCastGraphBuilder); // 53 CIRCLE_NODE(CEIL, CircleCeilGraphBuilder); // 104 CIRCLE_NODE(CUSTOM, CircleCustomGraphBuilder); // 32 @@ -51,6 +52,7 @@ GraphBuilderRegistry::GraphBuilderRegistry() CIRCLE_NODE(EQUAL, CircleEqualGraphBuilder); // 71 CIRCLE_NODE(EXP, CircleExpGraphBuilder); // 47 CIRCLE_NODE(EXPAND_DIMS, CircleExpandDimsGraphBuilder); // 70 + CIRCLE_NODE(FAKE_QUANT, CircleFakeQuantGraphBuilder); // 80 CIRCLE_NODE(FILL, CircleFillGraphBuilder); // 94 CIRCLE_NODE(FLOOR, CircleFloorGraphBuilder); // 8 CIRCLE_NODE(FLOOR_DIV, CircleFloorDivGraphBuilder); // 90 @@ -155,9 +157,7 @@ GraphBuilderRegistry::GraphBuilderRegistry() // BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN = 35, // BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN = 46, // BuiltinOperator_DELEGATE = 51, - // BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM = 52, // BuiltinOperator_ARG_MAX = 56, - // BuiltinOperator_FAKE_QUANT = 80, // BuiltinOperator_QUANTIZE = 114, // BuiltinOperator_HARD_SWISH = 117, // BuiltinOperator_DENSIFY = 124, diff --git a/compiler/luci/import/src/Importer.cpp b/compiler/luci/import/src/Importer.cpp index ab89f3587..193afffcb 100644 --- a/compiler/luci/import/src/Importer.cpp +++ b/compiler/luci/import/src/Importer.cpp @@ -15,6 +15,7 @@ */ #include "luci/Importer.h" +#include "CircleImportMetadata.h" #include "PostImport.h" #include "luci/Import/GraphBuilder.h" @@ -25,6 +26,8 @@ #include <luci/IR/Module.h> #include <luci/IR/CircleNodes.h> +#include <luci/Profile/CircleNodeID.h> +#include <luci/Profile/CircleNodeOrigin.h> #include <luci/Log.h> #include <luci/LogHelper.h> @@ -50,6 +53,7 @@ void convert_graph(const luci::GraphBuilderSource &source, luci::CircleReader &r const auto &tensors = reader.tensors(); auto tensors_ptr = reader.tensors_ptr(); assert(tensors_ptr != nullptr); + auto circle_metadata = std::make_unique<luci::CircleImportMetadata>(reader); // build a cache to identify if a tensor is output of an operator // if this is set, we should not create a CircleConst for this tensor @@ -96,12 +100,20 @@ void convert_graph(const luci::GraphBuilderSource &source, luci::CircleReader &r // Data type graph_input->dtype(input_node->dtype()); + assert(tensor.shape_signature.size() == 0 || + tensor.shape_signature.size() == tensor.shape.size()); + // Shape of GraphInput auto input_shape = std::make_unique<loco::TensorShape>(); const std::vector<int32_t> &input_dims = tensor.shape; // in NHWC input_shape->rank(input_dims.size()); for (uint32_t r = 0; r < input_dims.size(); ++r) - input_shape->dim(r) = loco::Dimension(input_dims[r]); + { + if (tensor.shape_signature.size() > 0 && tensor.shape_signature.at(r) == -1) + input_shape->dim(r).unset(); + else + input_shape->dim(r).set(input_dims[r]); + } graph_input->shape(std::move(input_shape)); } @@ -117,6 +129,7 @@ void convert_graph(const luci::GraphBuilderSource &source, luci::CircleReader &r // Note that operators in model are stored in execution order. This means that when importing // an operator, its input operators have already been imported. We exploit this fact to set up // node's inputs right after creating the node. + auto origin_table = circle_metadata->origin_table(); for (uint32_t i = 0; i < operators.size(); ++i) { const circle::OperatorT &op = *operators[i]; @@ -130,7 +143,12 @@ void convert_graph(const luci::GraphBuilderSource &source, luci::CircleReader &r throw oops::UserExn("Invalid operator", reader.opcode_name(op)); } - builder->build(op, &gb_context); + auto built_op = builder->build(op, &gb_context); + set_node_id(built_op, i); + if (origin_table.find(i) != origin_table.end()) + add_origin(built_op, origin_table.at(i)); + else + add_origin(built_op, luci::single_origin(i, built_op->name())); } else { @@ -169,19 +187,28 @@ void convert_graph(const luci::GraphBuilderSource &source, luci::CircleReader &r // set the graph output name and node object auto graph_output = graph->outputs()->create(); std::string tname = luci::tensor_name(tensor); - graph_output->name("output_" + tname); + assert(tname.length() > 0); + graph_output->name(tname); luci::copy_tensor_attributes(tensor, output_node); // Set GraphInputOutputIndex for graph output_node->index(graph_output->index()); + assert(tensor.shape_signature.size() == 0 || + tensor.shape_signature.size() == tensor.shape.size()); + // Shape of Output auto output_shape = std::make_unique<loco::TensorShape>(); const std::vector<int32_t> &output_dims = tensor.shape; // in NHWC output_shape->rank(output_dims.size()); for (uint32_t r = 0; r < output_dims.size(); ++r) - output_shape->dim(r) = loco::Dimension(output_dims[r]); + { + if (tensor.shape_signature.size() > 0 && tensor.shape_signature.at(r) == -1) + output_shape->dim(r).unset(); + else + output_shape->dim(r).set(output_dims[r]); + } graph_output->shape(std::move(output_shape)); // Data type diff --git a/compiler/luci/import/src/Nodes/CircleAbs.cpp b/compiler/luci/import/src/Nodes/CircleAbs.cpp index 3556dc7fa..2a1601a21 100644 --- a/compiler/luci/import/src/Nodes/CircleAbs.cpp +++ b/compiler/luci/import/src/Nodes/CircleAbs.cpp @@ -24,11 +24,8 @@ namespace luci { bool CircleAbsGraphBuilder::validate(const ValidateArgs &args) const { - if (args.op.inputs.size() != 1) - return false; - // TODO Support type check - return true; + return GraphBuilder::validate(args, 1); } CircleNode *CircleAbsGraphBuilder::build_node(const circle::OperatorT &, diff --git a/compiler/luci/import/src/Nodes/CircleAdd.cpp b/compiler/luci/import/src/Nodes/CircleAdd.cpp index b767d4af2..94cbdf081 100644 --- a/compiler/luci/import/src/Nodes/CircleAdd.cpp +++ b/compiler/luci/import/src/Nodes/CircleAdd.cpp @@ -25,10 +25,7 @@ namespace luci bool CircleAddGraphBuilder::validate(const ValidateArgs &args) const { - if (args.op.inputs.size() != 2) - return false; - - return true; + return GraphBuilder::validate(args, 2); } CircleNode *CircleAddGraphBuilder::build_node(const circle::OperatorT &op, diff --git a/compiler/luci/import/src/Nodes/CircleArgMax.cpp b/compiler/luci/import/src/Nodes/CircleArgMax.cpp index 10e8516f4..fd8a84289 100644 --- a/compiler/luci/import/src/Nodes/CircleArgMax.cpp +++ b/compiler/luci/import/src/Nodes/CircleArgMax.cpp @@ -25,10 +25,7 @@ namespace luci bool CircleArgMaxGraphBuilder::validate(const ValidateArgs &args) const { - if (args.op.inputs.size() != 2) - return false; - - return true; + return GraphBuilder::validate(args, 2); } CircleNode *CircleArgMaxGraphBuilder::build_node(const circle::OperatorT &op, diff --git a/compiler/luci/import/src/Nodes/CircleArgMin.cpp b/compiler/luci/import/src/Nodes/CircleArgMin.cpp index 5ff534dbb..63ca8db03 100644 --- a/compiler/luci/import/src/Nodes/CircleArgMin.cpp +++ b/compiler/luci/import/src/Nodes/CircleArgMin.cpp @@ -25,10 +25,7 @@ namespace luci bool CircleArgMinGraphBuilder::validate(const ValidateArgs &args) const { - if (args.op.inputs.size() != 2) - return false; - - return true; + return GraphBuilder::validate(args, 2); } CircleNode *CircleArgMinGraphBuilder::build_node(const circle::OperatorT &op, diff --git a/compiler/luci/import/src/Nodes/CircleAveragePool2D.cpp b/compiler/luci/import/src/Nodes/CircleAveragePool2D.cpp index ad011f71f..a351cf5e7 100644 --- a/compiler/luci/import/src/Nodes/CircleAveragePool2D.cpp +++ b/compiler/luci/import/src/Nodes/CircleAveragePool2D.cpp @@ -23,10 +23,7 @@ namespace luci bool CircleAveragePool2DGraphBuilder::validate(const ValidateArgs &args) const { - if (args.op.inputs.size() != 1) - return false; - - return true; + return GraphBuilder::validate(args, 1); } CircleNode *CircleAveragePool2DGraphBuilder::build_node(const circle::OperatorT &op, diff --git a/compiler/luci/import/src/Nodes/CircleBCQFullyConnected.cpp b/compiler/luci/import/src/Nodes/CircleBCQFullyConnected.cpp index 16ecebd5c..4c86399ce 100644 --- a/compiler/luci/import/src/Nodes/CircleBCQFullyConnected.cpp +++ b/compiler/luci/import/src/Nodes/CircleBCQFullyConnected.cpp @@ -25,10 +25,7 @@ namespace luci bool CircleBCQFullyConnectedGraphBuilder::validate(const ValidateArgs &args) const { - if (args.op.inputs.size() != 5) - return false; - - return true; + return GraphBuilder::validate(args, 5); } CircleNode *CircleBCQFullyConnectedGraphBuilder::build_node(const circle::OperatorT &op, @@ -43,15 +40,6 @@ CircleNode *CircleBCQFullyConnectedGraphBuilder::build_node(const circle::Operat node->bias(inputs.at(3)); node->weights_clusters(inputs.at(4)); - // TODO Find and move to appropriate place for setting optional input - if (auto bias = dynamic_cast<luci::CircleOutputExclude *>(node->bias())) - { - // bias is not used for type inference, but node itself should have a type - bias->dtype(loco::DataType::FLOAT32); - - // bias is not used for shape inference - } - const auto *options = op.builtin_options.AsBCQFullyConnectedOptions(); node->weights_hidden_size(options->weights_hidden_size); node->fusedActivationFunction(luci_actfunc(options->fused_activation_function)); diff --git a/compiler/luci/import/src/Nodes/CircleBCQGather.cpp b/compiler/luci/import/src/Nodes/CircleBCQGather.cpp index 464f1ac18..ee1358197 100644 --- a/compiler/luci/import/src/Nodes/CircleBCQGather.cpp +++ b/compiler/luci/import/src/Nodes/CircleBCQGather.cpp @@ -25,10 +25,7 @@ namespace luci bool CircleBCQGatherGraphBuilder::validate(const ValidateArgs &args) const { - if (args.op.inputs.size() != 4) - return false; - - return true; + return GraphBuilder::validate(args, 4); } CircleNode *CircleBCQGatherGraphBuilder::build_node(const circle::OperatorT &op, diff --git a/compiler/luci/import/src/Nodes/CircleBatchMatMul.cpp b/compiler/luci/import/src/Nodes/CircleBatchMatMul.cpp index 330775691..390719061 100644 --- a/compiler/luci/import/src/Nodes/CircleBatchMatMul.cpp +++ b/compiler/luci/import/src/Nodes/CircleBatchMatMul.cpp @@ -23,10 +23,7 @@ namespace luci bool CircleBatchMatMulGraphBuilder::validate(const ValidateArgs &args) const { - if (args.op.inputs.size() != 2) - return false; - - return true; + return GraphBuilder::validate(args, 2); } CircleNode *CircleBatchMatMulGraphBuilder::build_node(const circle::OperatorT &op, diff --git a/compiler/luci/import/src/Nodes/CircleBidirectionalSequenceLSTM.cpp b/compiler/luci/import/src/Nodes/CircleBidirectionalSequenceLSTM.cpp new file mode 100644 index 000000000..f8bdcff72 --- /dev/null +++ b/compiler/luci/import/src/Nodes/CircleBidirectionalSequenceLSTM.cpp @@ -0,0 +1,112 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Import/Nodes/CircleBidirectionalSequenceLSTM.h" + +#include <luci/IR/Nodes/CircleBidirectionalSequenceLSTM.h> +#include <luci/IR/Nodes/CircleBidirectionalSequenceLSTMOut.h> + +#include <loco.h> + +namespace luci +{ + +bool CircleBidirectionalSequenceLSTMGraphBuilder::validate(const ValidateArgs &args) const +{ + if (args.op.inputs.size() != 48) + return false; + if (args.op.outputs.size() != 2) + return false; + + return true; +} + +CircleNode *CircleBidirectionalSequenceLSTMGraphBuilder::build_node(const BuildNodeArgs &bna) const +{ + auto *node = bna.context->graph()->nodes()->create<CircleBidirectionalSequenceLSTM>(); + auto &inputs = bna.input_nodes; + node->input(inputs.at(0)); + node->fw_input_to_input_weights(inputs.at(1)); // Optional + node->fw_input_to_cell_weights(inputs.at(2)); + node->fw_input_to_forget_weights(inputs.at(3)); + node->fw_input_to_output_weights(inputs.at(4)); + node->fw_recurrent_to_input_weights(inputs.at(5)); // Optional + node->fw_recurrent_to_cell_weights(inputs.at(6)); + node->fw_recurrent_to_forget_weights(inputs.at(7)); + node->fw_recurrent_to_output_weights(inputs.at(8)); + node->fw_cell_to_input_weights(inputs.at(9)); // Optional + node->fw_cell_to_forget_weights(inputs.at(10)); // Optional + node->fw_cell_to_output_weights(inputs.at(11)); // Optional + node->fw_input_gate_bias(inputs.at(12)); // Optional + node->fw_forget_gate_bias(inputs.at(13)); + node->fw_cell_gate_bias(inputs.at(14)); + node->fw_output_gate_bias(inputs.at(15)); + node->fw_projection_weights(inputs.at(16)); // Optional + node->fw_projection_bias(inputs.at(17)); // Optional + node->bw_input_to_input_weights(inputs.at(18)); // Optional + node->bw_input_to_cell_weights(inputs.at(19)); + node->bw_input_to_forget_weights(inputs.at(20)); + node->bw_input_to_output_weights(inputs.at(21)); + node->bw_recurrent_to_input_weights(inputs.at(22)); // Optional + node->bw_recurrent_to_cell_weights(inputs.at(23)); + node->bw_recurrent_to_forget_weights(inputs.at(24)); + node->bw_recurrent_to_output_weights(inputs.at(25)); + node->bw_cell_to_input_weights(inputs.at(26)); // Optional + node->bw_cell_to_forget_weights(inputs.at(27)); // Optional + node->bw_cell_to_output_weights(inputs.at(28)); // Optional + node->bw_input_gate_bias(inputs.at(29)); // Optional + node->bw_forget_gate_bias(inputs.at(30)); + node->bw_cell_gate_bias(inputs.at(31)); + node->bw_output_gate_bias(inputs.at(32)); + node->bw_projection_weights(inputs.at(33)); // Optional + node->bw_projection_bias(inputs.at(34)); // Optional + node->fw_activation_state(inputs.at(35)); + node->fw_cell_state(inputs.at(36)); + node->bw_activation_state(inputs.at(37)); + node->bw_cell_state(inputs.at(38)); + + node->auxillary_input(inputs.at(39)); // Optional + node->fw_auxillary_input_to_input_weights(inputs.at(40)); // Optional + node->fw_auxillary_input_to_forget_weights(inputs.at(41)); // Optional + node->fw_auxillary_input_to_cell_weights(inputs.at(42)); // Optional + node->fw_auxillary_input_to_output_weights(inputs.at(43)); // Optional + node->bw_auxillary_input_to_input_weights(inputs.at(44)); // Optional + node->bw_auxillary_input_to_forget_weights(inputs.at(45)); // Optional + node->bw_auxillary_input_to_cell_weights(inputs.at(46)); // Optional + node->bw_auxillary_input_to_output_weights(inputs.at(47)); // Optional + + const auto *options = bna.op.builtin_options.AsBidirectionalSequenceLSTMOptions(); + node->fusedActivationFunction(luci_actfunc(options->fused_activation_function)); + node->cell_clip(options->cell_clip); + node->proj_clip(options->proj_clip); + node->merge_outputs(options->merge_outputs); + node->time_major(options->time_major); + node->asymmetric_quantize_inputs(options->asymmetric_quantize_inputs); + + return node; +} + +CircleNode *CircleBidirectionalSequenceLSTMGraphBuilder::build_out(const BuildOutArgs &boa) const +{ + auto *nodeout = boa.node->graph()->nodes()->create<CircleBidirectionalSequenceLSTMOut>(); + + nodeout->input(boa.node); + nodeout->index(boa.index); + + return nodeout; +} + +} // namespace luci diff --git a/compiler/luci/import/src/Nodes/CircleCast.cpp b/compiler/luci/import/src/Nodes/CircleCast.cpp index 7bdb63044..3e8c08bfa 100644 --- a/compiler/luci/import/src/Nodes/CircleCast.cpp +++ b/compiler/luci/import/src/Nodes/CircleCast.cpp @@ -30,14 +30,13 @@ bool CircleCastGraphBuilder::validate(const ValidateArgs &args) const { LOGGER(l); + if (!GraphBuilder::validate(args, 1)) + return false; + auto settings = luci::UserSettings::settings(); const auto &inputs = args.op.inputs; const auto &outputs = args.op.outputs; - if (inputs.size() != 1) - return false; - if (outputs.size() != 1) - return false; // NOTE real models do have type mismatch const auto *options = args.op.builtin_options.AsCastOptions(); diff --git a/compiler/luci/import/src/Nodes/CircleCeil.cpp b/compiler/luci/import/src/Nodes/CircleCeil.cpp index 2e1aaa295..d439f41cd 100644 --- a/compiler/luci/import/src/Nodes/CircleCeil.cpp +++ b/compiler/luci/import/src/Nodes/CircleCeil.cpp @@ -25,16 +25,8 @@ namespace luci bool CircleCeilGraphBuilder::validate(const ValidateArgs &args) const { - const auto &inputs = args.op.inputs; - const auto &outputs = args.op.outputs; - if (inputs.size() != 1) - return false; - if (outputs.size() != 1) - return false; - // TODO dtype check - - return true; + return GraphBuilder::validate(args, 1); } CircleNode *CircleCeilGraphBuilder::build_node(const circle::OperatorT &, diff --git a/compiler/luci/import/src/Nodes/CircleConv2D.cpp b/compiler/luci/import/src/Nodes/CircleConv2D.cpp index 9516ef16a..8cbecdc00 100644 --- a/compiler/luci/import/src/Nodes/CircleConv2D.cpp +++ b/compiler/luci/import/src/Nodes/CircleConv2D.cpp @@ -28,10 +28,7 @@ namespace luci bool CircleConv2DGraphBuilder::validate(const ValidateArgs &args) const { // Circle Conv2D may not have a bias but we won't support this - if (args.op.inputs.size() != 3) - return false; - - return true; + return GraphBuilder::validate(args, 3); } CircleNode *CircleConv2DGraphBuilder::build_node(const circle::OperatorT &op, diff --git a/compiler/luci/import/src/Nodes/CircleCos.cpp b/compiler/luci/import/src/Nodes/CircleCos.cpp index 27d60c62c..9705202ee 100644 --- a/compiler/luci/import/src/Nodes/CircleCos.cpp +++ b/compiler/luci/import/src/Nodes/CircleCos.cpp @@ -25,10 +25,7 @@ namespace luci bool CircleCosGraphBuilder::validate(const ValidateArgs &args) const { - if (args.op.inputs.size() != 1) - return false; - - return true; + return GraphBuilder::validate(args, 1); } CircleNode *CircleCosGraphBuilder::build_node(const circle::OperatorT &, diff --git a/compiler/luci/import/src/Nodes/CircleCustom.cpp b/compiler/luci/import/src/Nodes/CircleCustom.cpp index d541ee87b..01ac3e2a0 100644 --- a/compiler/luci/import/src/Nodes/CircleCustom.cpp +++ b/compiler/luci/import/src/Nodes/CircleCustom.cpp @@ -27,62 +27,39 @@ bool CircleCustomGraphBuilder::validate(const ValidateArgs &) const return true; } -void CircleCustomGraphBuilder::build(const circle::OperatorT &op, - GraphBuilderContext *context) const +CircleNode *CircleCustomGraphBuilder::build_node(const BuildNodeArgs &bna) const { - assert(context != nullptr); + uint32_t input_count = bna.op.inputs.size(); + uint32_t output_count = bna.op.outputs.size(); - auto graph = context->graph(); + auto *node = bna.context->graph()->nodes()->create<CircleCustom>(input_count, output_count); - const std::vector<int32_t> &inputs = op.inputs; - const std::vector<int32_t> &outputs = op.outputs; - const auto &tensors = context->reader()->tensors(); - auto tensors_ptr = context->reader()->tensors_ptr(); - assert(tensors_ptr != nullptr); + for (uint32_t idx = 0; idx < input_count; ++idx) + { + node->inputs(idx, bna.input_nodes[idx]); + } - // Create CircleCustom - const auto &opcodes = context->reader()->opcodes(); - const uint32_t opcode_index = op.opcode_index; + const auto &opcodes = bna.context->reader()->opcodes(); + const uint32_t opcode_index = bna.op.opcode_index; const circle::OperatorCodeT &opcode = *opcodes[opcode_index]; - auto *node = graph->nodes()->create<CircleCustom>(inputs.size()); - uint32_t input_idx = 0; - for (const int32_t input_tensor_index : inputs) - { - node->inputs(input_idx++, context->nodefinder()->node(input_tensor_index)); - } - node->custom_options(std::vector<uint8_t>{op.custom_options.begin(), op.custom_options.end()}); + node->custom_options( + std::vector<uint8_t>{bna.op.custom_options.begin(), bna.op.custom_options.end()}); node->custom_code(opcode.custom_code); - // Operator version of custom is always 1, so do nothing - uint32_t output_count = outputs.size(); + // NOTE Operator version of custom is always 1 - assert(output_count > 0); - { - // Let's use attributes from output 0 for this node - const circle::TensorT &output_tensor = *tensors[outputs[0]]; - node->name(tensor_name(output_tensor)); - node->dtype(luci_datatype(output_tensor.type)); - } - - // Create virtual outputs of Custom - for (uint32_t n = 0; n < output_count; ++n) - { - const circle::TensorT &output_tensor = *tensors[outputs[n]]; + return node; +} - auto *nodeout = graph->nodes()->create<CircleCustomOut>(); - copy_tensor_attributes(output_tensor, nodeout); - // mark shape_status - if (tensors_ptr->Get(outputs[n])->shape() == nullptr) - nodeout->shape_status(ShapeStatus::NOSHAPE); - else - nodeout->shape_status(ShapeStatus::VALID); +CircleNode *CircleCustomGraphBuilder::build_out(const BuildOutArgs &boa) const +{ + auto *nodeout = boa.node->graph()->nodes()->create<CircleCustomOut>(); - nodeout->input(node); - nodeout->index(n); + nodeout->input(boa.node); + nodeout->index(boa.index); - context->nodefinder()->enroll(outputs[n], nodeout); - } + return nodeout; } } // namespace luci diff --git a/compiler/luci/import/src/Nodes/CircleDepthToSpace.cpp b/compiler/luci/import/src/Nodes/CircleDepthToSpace.cpp index 49d31bb99..49eb30a83 100644 --- a/compiler/luci/import/src/Nodes/CircleDepthToSpace.cpp +++ b/compiler/luci/import/src/Nodes/CircleDepthToSpace.cpp @@ -27,17 +27,13 @@ namespace luci bool CircleDepthToSpaceGraphBuilder::validate(const ValidateArgs &args) const { + if (!GraphBuilder::validate(args, 1)) + return false; + const auto &inputs = args.op.inputs; const auto &outputs = args.op.outputs; const auto *options = args.op.builtin_options.AsDepthToSpaceOptions(); - - if (inputs.size() != 1) - return false; - - if (outputs.size() != 1) - return false; - const auto &tensors = args.reader.tensors(); if (tensors[outputs[0]]->type != tensors[inputs.at(0)]->type) diff --git a/compiler/luci/import/src/Nodes/CircleDepthwiseConv2D.cpp b/compiler/luci/import/src/Nodes/CircleDepthwiseConv2D.cpp index 53f85f2f5..727487c6a 100644 --- a/compiler/luci/import/src/Nodes/CircleDepthwiseConv2D.cpp +++ b/compiler/luci/import/src/Nodes/CircleDepthwiseConv2D.cpp @@ -32,6 +32,32 @@ bool CircleDepthwiseConv2DGraphBuilder::validate(const ValidateArgs &args) const if (args.op.outputs.size() != 1) return false; + const auto &tensors = args.reader.tensors(); + + // input shape + const auto &input = tensors.at(args.op.inputs.at(0)); + const auto &input_shape = input->shape; + + // input shape must be rank 4 + if (input_shape.size() != 4) + return false; + + // filter shape + const auto &filter = tensors.at(args.op.inputs.at(1)); + const auto &filter_shape = filter->shape; + + // filter shape must be rank 4 + if (filter_shape.size() != 4) + return false; + + // multiplier + const auto *options = args.op.builtin_options.AsDepthwiseConv2DOptions(); + const auto &multiplier = options->depth_multiplier; + + // filter represents as [1, H, W, C*M] where M is multiplier. + if (filter_shape.at(3) != input_shape.at(3) * multiplier) + return false; + return true; } diff --git a/compiler/luci/import/src/Nodes/CircleDequantize.cpp b/compiler/luci/import/src/Nodes/CircleDequantize.cpp index 1936da97c..3db546bd0 100644 --- a/compiler/luci/import/src/Nodes/CircleDequantize.cpp +++ b/compiler/luci/import/src/Nodes/CircleDequantize.cpp @@ -25,10 +25,7 @@ namespace luci bool CircleDequantizeGraphBuilder::validate(const ValidateArgs &args) const { - if (args.op.inputs.size() != 1) - return false; - - return true; + return GraphBuilder::validate(args, 1); } CircleNode *CircleDequantizeGraphBuilder::build_node(const circle::OperatorT &, diff --git a/compiler/luci/import/src/Nodes/CircleDiv.cpp b/compiler/luci/import/src/Nodes/CircleDiv.cpp index 615c224d7..7ea1afd95 100644 --- a/compiler/luci/import/src/Nodes/CircleDiv.cpp +++ b/compiler/luci/import/src/Nodes/CircleDiv.cpp @@ -23,13 +23,7 @@ namespace luci bool CircleDivGraphBuilder::validate(const ValidateArgs &args) const { - if (args.op.inputs.size() != 2) - return false; - - if (args.op.outputs.size() != 1) - return false; - - return true; + return GraphBuilder::validate(args, 2); } CircleNode *CircleDivGraphBuilder::build_node(const circle::OperatorT &op, diff --git a/compiler/luci/import/src/Nodes/CircleElu.cpp b/compiler/luci/import/src/Nodes/CircleElu.cpp index 919e95ee4..461da9517 100644 --- a/compiler/luci/import/src/Nodes/CircleElu.cpp +++ b/compiler/luci/import/src/Nodes/CircleElu.cpp @@ -25,14 +25,11 @@ namespace luci bool CircleEluGraphBuilder::validate(const ValidateArgs &args) const { - const auto &inputs = args.op.inputs; - const auto &outputs = args.op.outputs; - - if (inputs.size() != 1) + if (!GraphBuilder::validate(args, 1)) return false; - if (outputs.size() != 1) - return false; + const auto &inputs = args.op.inputs; + const auto &outputs = args.op.outputs; const auto &tensors = args.reader.tensors(); const auto &tensor = tensors.at(inputs.at(0)); diff --git a/compiler/luci/import/src/Nodes/CircleEqual.cpp b/compiler/luci/import/src/Nodes/CircleEqual.cpp index 1db33b8ac..4909692b4 100644 --- a/compiler/luci/import/src/Nodes/CircleEqual.cpp +++ b/compiler/luci/import/src/Nodes/CircleEqual.cpp @@ -25,13 +25,10 @@ namespace luci bool CircleEqualGraphBuilder::validate(const ValidateArgs &args) const { - const auto &inputs = args.op.inputs; - - if (inputs.size() != 2) - { + if (!GraphBuilder::validate(args, 2)) return false; - } + const auto &inputs = args.op.inputs; const auto &tensors = args.reader.tensors(); return tensors[inputs.at(0)]->type == tensors[inputs.at(1)]->type; diff --git a/compiler/luci/import/src/Nodes/CircleExp.cpp b/compiler/luci/import/src/Nodes/CircleExp.cpp index 2c031d6b3..64f18fbd4 100644 --- a/compiler/luci/import/src/Nodes/CircleExp.cpp +++ b/compiler/luci/import/src/Nodes/CircleExp.cpp @@ -25,10 +25,10 @@ namespace luci bool CircleExpGraphBuilder::validate(const ValidateArgs &args) const { - const auto &inputs = args.op.inputs; - if (inputs.size() != 1) + if (!GraphBuilder::validate(args, 1)) return false; + const auto &inputs = args.op.inputs; // input type check const auto &tensors = args.reader.tensors(); const auto &tensor = tensors.at(inputs.at(0)); diff --git a/compiler/luci/import/src/Nodes/CircleExpandDims.cpp b/compiler/luci/import/src/Nodes/CircleExpandDims.cpp index ab537c710..ee0fbdc7e 100644 --- a/compiler/luci/import/src/Nodes/CircleExpandDims.cpp +++ b/compiler/luci/import/src/Nodes/CircleExpandDims.cpp @@ -25,13 +25,10 @@ namespace luci bool CircleExpandDimsGraphBuilder::validate(const ValidateArgs &args) const { - const auto &inputs = args.op.inputs; - - if (inputs.size() != 2) - { + if (!GraphBuilder::validate(args, 2)) return false; - } + const auto &inputs = args.op.inputs; const auto &tensors = args.reader.tensors(); return tensors[inputs.at(1)]->type == circle::TensorType_INT32; diff --git a/compiler/luci/import/src/Nodes/CircleFakeQuant.cpp b/compiler/luci/import/src/Nodes/CircleFakeQuant.cpp new file mode 100644 index 000000000..7cf40b225 --- /dev/null +++ b/compiler/luci/import/src/Nodes/CircleFakeQuant.cpp @@ -0,0 +1,49 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Import/Nodes/CircleFakeQuant.h" + +#include <luci/IR/Nodes/CircleFullyConnected.h> +#include <luci/IR/Nodes/CircleOutput.h> + +#include <loco.h> +#include <oops/UserExn.h> + +namespace luci +{ + +bool CircleFakeQuantGraphBuilder::validate(const ValidateArgs &args) const +{ + return GraphBuilder::validate(args, 1); +} + +CircleNode *CircleFakeQuantGraphBuilder::build_node(const circle::OperatorT &op, + const std::vector<CircleNode *> &inputs, + loco::Graph *graph) const +{ + auto *node = graph->nodes()->create<CircleFakeQuant>(); + node->inputs(inputs.at(0)); + + const auto *options = op.builtin_options.AsFakeQuantOptions(); + node->min(options->min); + node->max(options->max); + node->num_bits(options->num_bits); + node->narrow_range(options->narrow_range); + + return node; +} + +} // namespace luci diff --git a/compiler/luci/import/src/Nodes/CircleFill.cpp b/compiler/luci/import/src/Nodes/CircleFill.cpp index 95d5b876b..9aacddcbe 100644 --- a/compiler/luci/import/src/Nodes/CircleFill.cpp +++ b/compiler/luci/import/src/Nodes/CircleFill.cpp @@ -23,13 +23,7 @@ namespace luci bool CircleFillGraphBuilder::validate(const ValidateArgs &args) const { - if (args.op.inputs.size() != 2) - return false; - - if (args.op.outputs.size() != 1) - return false; - - return true; + return GraphBuilder::validate(args, 2); } CircleNode *CircleFillGraphBuilder::build_node(const circle::OperatorT &op, diff --git a/compiler/luci/import/src/Nodes/CircleFloor.cpp b/compiler/luci/import/src/Nodes/CircleFloor.cpp index ce756b3b1..9651259c7 100644 --- a/compiler/luci/import/src/Nodes/CircleFloor.cpp +++ b/compiler/luci/import/src/Nodes/CircleFloor.cpp @@ -25,16 +25,8 @@ namespace luci bool CircleFloorGraphBuilder::validate(const ValidateArgs &args) const { - const auto &inputs = args.op.inputs; - const auto &outputs = args.op.outputs; - if (inputs.size() != 1) - return false; - if (outputs.size() != 1) - return false; - // TODO dtype check - - return true; + return GraphBuilder::validate(args, 1); } CircleNode *CircleFloorGraphBuilder::build_node(const circle::OperatorT &, diff --git a/compiler/luci/import/src/Nodes/CircleFloorDiv.cpp b/compiler/luci/import/src/Nodes/CircleFloorDiv.cpp index 55f385d60..ce329326a 100644 --- a/compiler/luci/import/src/Nodes/CircleFloorDiv.cpp +++ b/compiler/luci/import/src/Nodes/CircleFloorDiv.cpp @@ -25,19 +25,11 @@ namespace luci bool CircleFloorDivGraphBuilder::validate(const ValidateArgs &args) const { - const auto &inputs = args.op.inputs; - const auto &outputs = args.op.outputs; - - if (inputs.size() != 2) - { + if (!GraphBuilder::validate(args, 2)) return false; - } - - if (outputs.size() != 1) - { - return false; - } + const auto &inputs = args.op.inputs; + const auto &outputs = args.op.outputs; const auto &tensors = args.reader.tensors(); const auto &tensor_in_0 = tensors.at(inputs.at(0)); const auto &tensor_in_1 = tensors.at(inputs.at(1)); diff --git a/compiler/luci/import/src/Nodes/CircleFloorMod.cpp b/compiler/luci/import/src/Nodes/CircleFloorMod.cpp index 2101e417e..d8420a43c 100644 --- a/compiler/luci/import/src/Nodes/CircleFloorMod.cpp +++ b/compiler/luci/import/src/Nodes/CircleFloorMod.cpp @@ -25,13 +25,10 @@ namespace luci bool CircleFloorModGraphBuilder::validate(const ValidateArgs &args) const { - const auto &inputs = args.op.inputs; - const auto &outputs = args.op.outputs; - if (inputs.size() != 2) - return false; - if (outputs.size() != 1) + if (!GraphBuilder::validate(args, 2)) return false; + const auto &inputs = args.op.inputs; const auto &tensors = args.reader.tensors(); const auto &tensor_in_0 = tensors.at(inputs.at(0)); const auto &tensor_in_1 = tensors.at(inputs.at(1)); diff --git a/compiler/luci/import/src/Nodes/CircleFullyConnected.cpp b/compiler/luci/import/src/Nodes/CircleFullyConnected.cpp index 17293ad7a..58750d79a 100644 --- a/compiler/luci/import/src/Nodes/CircleFullyConnected.cpp +++ b/compiler/luci/import/src/Nodes/CircleFullyConnected.cpp @@ -27,10 +27,7 @@ namespace luci bool CircleFullyConnectedGraphBuilder::validate(const ValidateArgs &args) const { - if (args.op.inputs.size() != 3) - return false; - - return true; + return GraphBuilder::validate(args, 3); } CircleNode *CircleFullyConnectedGraphBuilder::build_node(const circle::OperatorT &op, @@ -42,15 +39,6 @@ CircleNode *CircleFullyConnectedGraphBuilder::build_node(const circle::OperatorT node->weights(inputs.at(1)); node->bias(inputs.at(2)); // bias is optional - // TODO Find and move to appropriate place for setting optional input - if (auto bias = dynamic_cast<luci::CircleOutputExclude *>(node->bias())) - { - // bias is not used for type inference, but node itself should have a type - bias->dtype(loco::DataType::FLOAT32); - - // bias is not used for shape inference - } - const auto *options = op.builtin_options.AsFullyConnectedOptions(); node->fusedActivationFunction(luci_actfunc(options->fused_activation_function)); node->weights_format(luci_weights_format(options->weights_format)); diff --git a/compiler/luci/import/src/Nodes/CircleGather.cpp b/compiler/luci/import/src/Nodes/CircleGather.cpp index 75447a38a..8317a3340 100644 --- a/compiler/luci/import/src/Nodes/CircleGather.cpp +++ b/compiler/luci/import/src/Nodes/CircleGather.cpp @@ -26,18 +26,14 @@ namespace luci bool CircleGatherGraphBuilder::validate(const ValidateArgs &args) const { + if (!GraphBuilder::validate(args, 2)) + return false; + const auto &inputs = args.op.inputs; - const auto &outputs = args.op.outputs; const auto *options = args.op.builtin_options.AsGatherOptions(); int32_t axis = options->axis; - if (inputs.size() != 2) - return false; - - if (outputs.size() != 1) - return false; - if (axis < 0) axis += inputs.size(); diff --git a/compiler/luci/import/src/Nodes/CircleGatherNd.cpp b/compiler/luci/import/src/Nodes/CircleGatherNd.cpp index 981adbf63..a4bb26a10 100644 --- a/compiler/luci/import/src/Nodes/CircleGatherNd.cpp +++ b/compiler/luci/import/src/Nodes/CircleGatherNd.cpp @@ -27,15 +27,10 @@ namespace luci bool CircleGatherNdGraphBuilder::validate(const ValidateArgs &args) const { - const auto &inputs = args.op.inputs; - const auto &outputs = args.op.outputs; - - if (inputs.size() != 2) - return false; - - if (outputs.size() != 1) + if (!GraphBuilder::validate(args, 2)) return false; + const auto &inputs = args.op.inputs; auto &indices_tensor = args.reader.tensors()[inputs.at(1)]; if (!(indices_tensor->type == circle::TensorType::TensorType_INT32 || diff --git a/compiler/luci/import/src/Nodes/CircleGreater.cpp b/compiler/luci/import/src/Nodes/CircleGreater.cpp index 1ad0467e4..f9c00346c 100644 --- a/compiler/luci/import/src/Nodes/CircleGreater.cpp +++ b/compiler/luci/import/src/Nodes/CircleGreater.cpp @@ -30,17 +30,13 @@ bool CircleGreaterGraphBuilder::validate(const ValidateArgs &args) const { LOGGER(l); + if (!GraphBuilder::validate(args, 2)) + return false; + auto settings = luci::UserSettings::settings(); const auto &inputs = args.op.inputs; const auto &outputs = args.op.outputs; - - if (inputs.size() != 2) - return false; - - if (outputs.size() != 1) - return false; - const auto &tensors = args.reader.tensors(); if (tensors[inputs.at(0)]->type != tensors[inputs.at(1)]->type) diff --git a/compiler/luci/import/src/Nodes/CircleGreaterEqual.cpp b/compiler/luci/import/src/Nodes/CircleGreaterEqual.cpp index 0ac63b017..e20038fd9 100644 --- a/compiler/luci/import/src/Nodes/CircleGreaterEqual.cpp +++ b/compiler/luci/import/src/Nodes/CircleGreaterEqual.cpp @@ -25,19 +25,11 @@ namespace luci bool CircleGreaterEqualGraphBuilder::validate(const ValidateArgs &args) const { - const auto &inputs = args.op.inputs; - const auto &outputs = args.op.outputs; - - if (inputs.size() != 2) - { + if (!GraphBuilder::validate(args, 2)) return false; - } - - if (outputs.size() != 1) - { - return false; - } + const auto &inputs = args.op.inputs; + const auto &outputs = args.op.outputs; const auto &tensors = args.reader.tensors(); if (tensors[inputs.at(0)]->type != tensors[inputs.at(1)]->type) diff --git a/compiler/luci/import/src/Nodes/CircleIf.cpp b/compiler/luci/import/src/Nodes/CircleIf.cpp index db9ffe1cd..ffdbf0b79 100644 --- a/compiler/luci/import/src/Nodes/CircleIf.cpp +++ b/compiler/luci/import/src/Nodes/CircleIf.cpp @@ -70,69 +70,34 @@ bool CircleIfGraphBuilder::validate(const ValidateArgs &args) const * \- CircleIfOut --- Node --- */ -void CircleIfGraphBuilder::build(const circle::OperatorT &op, GraphBuilderContext *context) const +CircleNode *CircleIfGraphBuilder::build_node(const BuildNodeArgs &bna) const { - assert(context != nullptr); + uint32_t input_count = bna.op.inputs.size() - 1; + uint32_t output_count = bna.op.outputs.size(); - auto graph = context->graph(); + auto *node = bna.context->graph()->nodes()->create<CircleIf>(input_count, output_count); - const std::vector<int32_t> &inputs = op.inputs; - const std::vector<int32_t> &outputs = op.outputs; - const auto &tensors = context->reader()->tensors(); - const auto &opcodes = context->reader()->opcodes(); - auto tensors_ptr = context->reader()->tensors_ptr(); - assert(tensors_ptr != nullptr); - - std::vector<CircleNode *> input_nodes; - for (const int32_t input_tensor_index : inputs) - { - input_nodes.push_back(context->nodefinder()->node(input_tensor_index)); - } - - uint32_t input_count = inputs.size() - 1; - uint32_t output_count = outputs.size(); - - // Create CircleIf - CircleIf *node = graph->nodes()->create<CircleIf>(input_count, output_count); - - node->cond(input_nodes[0]); + node->cond(bna.input_nodes[0]); for (uint32_t idx = 0; idx < input_count; ++idx) { - node->input(idx, input_nodes[idx + 1]); + node->input(idx, bna.input_nodes[idx + 1]); } - const auto *options = op.builtin_options.AsIfOptions(); + const auto *options = bna.op.builtin_options.AsIfOptions(); node->then_branch(options->then_subgraph_index); node->else_branch(options->else_subgraph_index); - assert(outputs.size() > 0); - { - // Lets use name of output 0 as If name - const circle::TensorT &output_tensor = *tensors[outputs[0]]; - node->name(tensor_name(output_tensor)); - node->op_version(opcodes[op.opcode_index].get()->version); - - // NOTE We don't set quantization for If itself but to virtual outputs - } - - // Create virtual outputs of If - for (uint32_t n = 0; n < output_count; ++n) - { - const circle::TensorT &output_tensor = *tensors[outputs[n]]; + return node; +} - auto *nodeout = graph->nodes()->create<CircleIfOut>(); - copy_tensor_attributes(output_tensor, nodeout); - // mark shape_status - if (tensors_ptr->Get(outputs[n])->shape() == nullptr) - nodeout->shape_status(ShapeStatus::NOSHAPE); - else - nodeout->shape_status(ShapeStatus::VALID); +CircleNode *CircleIfGraphBuilder::build_out(const BuildOutArgs &boa) const +{ + auto *nodeout = boa.node->graph()->nodes()->create<CircleIfOut>(); - nodeout->input(node); - nodeout->index(n); + nodeout->input(boa.node); + nodeout->index(boa.index); - context->nodefinder()->enroll(outputs[n], nodeout); - } + return nodeout; } } // namespace luci diff --git a/compiler/luci/import/src/Nodes/CircleInstanceNorm.cpp b/compiler/luci/import/src/Nodes/CircleInstanceNorm.cpp index 6349fd3b7..977b53406 100644 --- a/compiler/luci/import/src/Nodes/CircleInstanceNorm.cpp +++ b/compiler/luci/import/src/Nodes/CircleInstanceNorm.cpp @@ -25,12 +25,8 @@ namespace luci bool CircleInstanceNormGraphBuilder::validate(const ValidateArgs &args) const { - if (args.op.inputs.size() != 3) - return false; - // TODO check dtypes - - return true; + return GraphBuilder::validate(args, 3); } CircleNode *CircleInstanceNormGraphBuilder::build_node(const circle::OperatorT &op, diff --git a/compiler/luci/import/src/Nodes/CircleL2Normalize.cpp b/compiler/luci/import/src/Nodes/CircleL2Normalize.cpp index e4fdc200c..7e1faedfb 100644 --- a/compiler/luci/import/src/Nodes/CircleL2Normalize.cpp +++ b/compiler/luci/import/src/Nodes/CircleL2Normalize.cpp @@ -25,20 +25,7 @@ namespace luci bool CircleL2NormalizeGraphBuilder::validate(const ValidateArgs &args) const { - const auto &inputs = args.op.inputs; - const auto &outputs = args.op.outputs; - - if (inputs.size() != 1) - { - return false; - } - - if (outputs.size() != 1) - { - return false; - } - - return true; + return GraphBuilder::validate(args, 1); } CircleNode *CircleL2NormalizeGraphBuilder::build_node(const circle::OperatorT &op, diff --git a/compiler/luci/import/src/Nodes/CircleL2Pool2D.cpp b/compiler/luci/import/src/Nodes/CircleL2Pool2D.cpp index 202d9d6fb..849c7c5ed 100644 --- a/compiler/luci/import/src/Nodes/CircleL2Pool2D.cpp +++ b/compiler/luci/import/src/Nodes/CircleL2Pool2D.cpp @@ -25,12 +25,8 @@ namespace luci bool CircleL2Pool2DGraphBuilder::validate(const ValidateArgs &args) const { - if (args.op.inputs.size() != 1) - return false; - // TODO check dtypes - - return true; + return GraphBuilder::validate(args, 1); } CircleNode *CircleL2Pool2DGraphBuilder::build_node(const circle::OperatorT &op, diff --git a/compiler/luci/import/src/Nodes/CircleLeakyRelu.cpp b/compiler/luci/import/src/Nodes/CircleLeakyRelu.cpp index ad4979f39..880fa6428 100644 --- a/compiler/luci/import/src/Nodes/CircleLeakyRelu.cpp +++ b/compiler/luci/import/src/Nodes/CircleLeakyRelu.cpp @@ -25,13 +25,7 @@ namespace luci bool CircleLeakyReluGraphBuilder::validate(const ValidateArgs &args) const { - if (args.op.inputs.size() != 1) - return false; - - if (args.op.outputs.size() != 1) - return false; - - return true; + return GraphBuilder::validate(args, 1); } CircleNode *CircleLeakyReluGraphBuilder::build_node(const circle::OperatorT &op, diff --git a/compiler/luci/import/src/Nodes/CircleLess.cpp b/compiler/luci/import/src/Nodes/CircleLess.cpp index 506036908..f9b99bebe 100644 --- a/compiler/luci/import/src/Nodes/CircleLess.cpp +++ b/compiler/luci/import/src/Nodes/CircleLess.cpp @@ -25,19 +25,11 @@ namespace luci bool CircleLessGraphBuilder::validate(const ValidateArgs &args) const { - const auto &inputs = args.op.inputs; - const auto &outputs = args.op.outputs; - - if (inputs.size() != 2) - { + if (!GraphBuilder::validate(args, 2)) return false; - } - - if (outputs.size() != 1) - { - return false; - } + const auto &inputs = args.op.inputs; + const auto &outputs = args.op.outputs; const auto &tensors = args.reader.tensors(); const auto &tensor = tensors.at(inputs.at(0)); diff --git a/compiler/luci/import/src/Nodes/CircleLessEqual.cpp b/compiler/luci/import/src/Nodes/CircleLessEqual.cpp index 9b4f934a5..bb1712137 100644 --- a/compiler/luci/import/src/Nodes/CircleLessEqual.cpp +++ b/compiler/luci/import/src/Nodes/CircleLessEqual.cpp @@ -25,19 +25,11 @@ namespace luci bool CircleLessEqualGraphBuilder::validate(const ValidateArgs &args) const { - const auto &inputs = args.op.inputs; - const auto &outputs = args.op.outputs; - - if (inputs.size() != 2) - { + if (!GraphBuilder::validate(args, 2)) return false; - } - - if (outputs.size() != 1) - { - return false; - } + const auto &inputs = args.op.inputs; + const auto &outputs = args.op.outputs; const auto &tensors = args.reader.tensors(); if (tensors[inputs.at(0)]->type != tensors[inputs.at(1)]->type) diff --git a/compiler/luci/import/src/Nodes/CircleLocalResponseNormalization.cpp b/compiler/luci/import/src/Nodes/CircleLocalResponseNormalization.cpp index 0e32f62de..d03c47d12 100644 --- a/compiler/luci/import/src/Nodes/CircleLocalResponseNormalization.cpp +++ b/compiler/luci/import/src/Nodes/CircleLocalResponseNormalization.cpp @@ -25,16 +25,12 @@ namespace luci bool CircleLocalResponseNormalizationGraphBuilder::validate(const ValidateArgs &args) const { - if (args.op.inputs.size() != 1) - return false; - // TODO do attribute checks - - return true; + return GraphBuilder::validate(args, 1); } CircleNode *CircleLocalResponseNormalizationGraphBuilder::build_node( - const circle::OperatorT &op, const std::vector<CircleNode *> &inputs, loco::Graph *graph) const + const circle::OperatorT &op, const std::vector<CircleNode *> &inputs, loco::Graph *graph) const { auto *node = graph->nodes()->create<CircleLocalResponseNormalization>(); node->input(inputs.at(0)); diff --git a/compiler/luci/import/src/Nodes/CircleLog.cpp b/compiler/luci/import/src/Nodes/CircleLog.cpp index 346fc43bb..26b575070 100644 --- a/compiler/luci/import/src/Nodes/CircleLog.cpp +++ b/compiler/luci/import/src/Nodes/CircleLog.cpp @@ -25,12 +25,10 @@ namespace luci bool CircleLogGraphBuilder::validate(const ValidateArgs &args) const { - const auto &inputs = args.op.inputs; - if (inputs.size() != 1) - return false; - if (args.op.outputs.size() != 1) + if (!GraphBuilder::validate(args, 1)) return false; + const auto &inputs = args.op.inputs; // input type check // Must be one of bfloat16, half, float32, float64, complex64, complex128. // Currently circle supports half(float16), float32, float64, complex64. diff --git a/compiler/luci/import/src/Nodes/CircleLogSoftmax.cpp b/compiler/luci/import/src/Nodes/CircleLogSoftmax.cpp index ef69e868a..4361db691 100644 --- a/compiler/luci/import/src/Nodes/CircleLogSoftmax.cpp +++ b/compiler/luci/import/src/Nodes/CircleLogSoftmax.cpp @@ -25,12 +25,8 @@ namespace luci bool CircleLogSoftmaxGraphBuilder::validate(const ValidateArgs &args) const { - if (args.op.inputs.size() != 1) - return false; - // TODO do attribute checks - - return true; + return GraphBuilder::validate(args, 1); } CircleNode *CircleLogSoftmaxGraphBuilder::build_node(const circle::OperatorT &, diff --git a/compiler/luci/import/src/Nodes/CircleLogicalAnd.cpp b/compiler/luci/import/src/Nodes/CircleLogicalAnd.cpp index 7844da0f6..b13fc2735 100644 --- a/compiler/luci/import/src/Nodes/CircleLogicalAnd.cpp +++ b/compiler/luci/import/src/Nodes/CircleLogicalAnd.cpp @@ -25,11 +25,11 @@ namespace luci bool CircleLogicalAndGraphBuilder::validate(const ValidateArgs &args) const { - // Only BOOL type is allowed for inputs - const auto &inputs = args.op.inputs; - if (inputs.size() != 2) + if (!GraphBuilder::validate(args, 2)) return false; + // Only BOOL type is allowed for inputs + const auto &inputs = args.op.inputs; const auto &tensors = args.reader.tensors(); for (auto input : inputs) { diff --git a/compiler/luci/import/src/Nodes/CircleLogicalNot.cpp b/compiler/luci/import/src/Nodes/CircleLogicalNot.cpp index 3758642e4..f68218349 100644 --- a/compiler/luci/import/src/Nodes/CircleLogicalNot.cpp +++ b/compiler/luci/import/src/Nodes/CircleLogicalNot.cpp @@ -25,7 +25,7 @@ namespace luci bool CircleLogicalNotGraphBuilder::validate(const ValidateArgs &args) const { - if (args.op.inputs.size() != 1) + if (!GraphBuilder::validate(args, 1)) return false; // Only BOOL type is allowed for the input diff --git a/compiler/luci/import/src/Nodes/CircleLogicalOr.cpp b/compiler/luci/import/src/Nodes/CircleLogicalOr.cpp index 1b87e6f9c..8c9023dd3 100644 --- a/compiler/luci/import/src/Nodes/CircleLogicalOr.cpp +++ b/compiler/luci/import/src/Nodes/CircleLogicalOr.cpp @@ -25,7 +25,7 @@ namespace luci bool CircleLogicalOrGraphBuilder::validate(const ValidateArgs &args) const { - if (args.op.inputs.size() != 2) + if (!GraphBuilder::validate(args, 2)) return false; // Only BOOL type is allowed for inputs diff --git a/compiler/luci/import/src/Nodes/CircleLogistic.cpp b/compiler/luci/import/src/Nodes/CircleLogistic.cpp index 9606e19cd..0f92a9bb4 100644 --- a/compiler/luci/import/src/Nodes/CircleLogistic.cpp +++ b/compiler/luci/import/src/Nodes/CircleLogistic.cpp @@ -25,13 +25,11 @@ namespace luci bool CircleLogisticGraphBuilder::validate(const ValidateArgs &args) const { - const auto &inputs = args.op.inputs; - if (inputs.size() != 1) - return false; - const auto &outputs = args.op.outputs; - if (outputs.size() != 1) + if (!GraphBuilder::validate(args, 1)) return false; + const auto &inputs = args.op.inputs; + const auto &outputs = args.op.outputs; const auto &tensors = args.reader.tensors(); if (tensors.at(inputs.at(0))->type != tensors.at(outputs[0])->type) return false; diff --git a/compiler/luci/import/src/Nodes/CircleMatrixDiag.cpp b/compiler/luci/import/src/Nodes/CircleMatrixDiag.cpp index a4a21a8b7..590a07f2d 100644 --- a/compiler/luci/import/src/Nodes/CircleMatrixDiag.cpp +++ b/compiler/luci/import/src/Nodes/CircleMatrixDiag.cpp @@ -25,15 +25,11 @@ namespace luci bool CircleMatrixDiagGraphBuilder::validate(const ValidateArgs &args) const { - const auto &inputs = args.op.inputs; - const auto &outputs = args.op.outputs; - - if (inputs.size() != 1) - return false; - - if (outputs.size() != 1) + if (!GraphBuilder::validate(args, 1)) return false; + const auto &inputs = args.op.inputs; + const auto &outputs = args.op.outputs; const auto &tensors = args.reader.tensors(); const auto &tensor = tensors.at(inputs.at(0)); diff --git a/compiler/luci/import/src/Nodes/CircleMatrixSetDiag.cpp b/compiler/luci/import/src/Nodes/CircleMatrixSetDiag.cpp index cf0313149..edd7d2ae2 100644 --- a/compiler/luci/import/src/Nodes/CircleMatrixSetDiag.cpp +++ b/compiler/luci/import/src/Nodes/CircleMatrixSetDiag.cpp @@ -25,15 +25,11 @@ namespace luci bool CircleMatrixSetDiagGraphBuilder::validate(const ValidateArgs &args) const { - const auto &inputs = args.op.inputs; - const auto &outputs = args.op.outputs; - - if (inputs.size() != 2) - return false; - - if (outputs.size() != 1) + if (!GraphBuilder::validate(args, 2)) return false; + const auto &inputs = args.op.inputs; + const auto &outputs = args.op.outputs; const auto &tensors = args.reader.tensors(); const auto &tensor = tensors.at(inputs.at(0)); diff --git a/compiler/luci/import/src/Nodes/CircleMaxPool2D.cpp b/compiler/luci/import/src/Nodes/CircleMaxPool2D.cpp index 4bca0f40b..5c03fff18 100644 --- a/compiler/luci/import/src/Nodes/CircleMaxPool2D.cpp +++ b/compiler/luci/import/src/Nodes/CircleMaxPool2D.cpp @@ -25,10 +25,7 @@ namespace luci bool CircleMaxPool2DGraphBuilder::validate(const ValidateArgs &args) const { - if (args.op.inputs.size() != 1) - return false; - - return true; + return GraphBuilder::validate(args, 1); } CircleNode *CircleMaxPool2DGraphBuilder::build_node(const circle::OperatorT &op, diff --git a/compiler/luci/import/src/Nodes/CircleMean.cpp b/compiler/luci/import/src/Nodes/CircleMean.cpp index d8fa9a53d..7882f17fc 100644 --- a/compiler/luci/import/src/Nodes/CircleMean.cpp +++ b/compiler/luci/import/src/Nodes/CircleMean.cpp @@ -23,10 +23,7 @@ namespace luci bool CircleMeanGraphBuilder::validate(const ValidateArgs &args) const { - if (args.op.inputs.size() != 2) - return false; - - return true; + return GraphBuilder::validate(args, 2); } CircleNode *CircleMeanGraphBuilder::build_node(const circle::OperatorT &op, diff --git a/compiler/luci/import/src/Nodes/CircleMirrorPad.cpp b/compiler/luci/import/src/Nodes/CircleMirrorPad.cpp index e0ddd4c11..e40ce2249 100644 --- a/compiler/luci/import/src/Nodes/CircleMirrorPad.cpp +++ b/compiler/luci/import/src/Nodes/CircleMirrorPad.cpp @@ -25,12 +25,8 @@ namespace luci bool CircleMirrorPadGraphBuilder::validate(const ValidateArgs &args) const { - if (args.op.inputs.size() != 2) - return false; - // TODO check others - - return true; + return GraphBuilder::validate(args, 2); } CircleNode *CircleMirrorPadGraphBuilder::build_node(const circle::OperatorT &op, diff --git a/compiler/luci/import/src/Nodes/CircleMul.cpp b/compiler/luci/import/src/Nodes/CircleMul.cpp index e3c4a7ee5..28421f8c4 100644 --- a/compiler/luci/import/src/Nodes/CircleMul.cpp +++ b/compiler/luci/import/src/Nodes/CircleMul.cpp @@ -23,13 +23,7 @@ namespace luci bool CircleMulGraphBuilder::validate(const ValidateArgs &args) const { - if (args.op.inputs.size() != 2) - return false; - - if (args.op.outputs.size() != 1) - return false; - - return true; + return GraphBuilder::validate(args, 2); } CircleNode *CircleMulGraphBuilder::build_node(const circle::OperatorT &op, diff --git a/compiler/luci/import/src/Nodes/CircleNeg.cpp b/compiler/luci/import/src/Nodes/CircleNeg.cpp index a64a69560..9dd1458f4 100644 --- a/compiler/luci/import/src/Nodes/CircleNeg.cpp +++ b/compiler/luci/import/src/Nodes/CircleNeg.cpp @@ -24,11 +24,8 @@ namespace luci { bool CircleNegGraphBuilder::validate(const ValidateArgs &args) const { - if (args.op.inputs.size() != 1) - return false; - // TODO Support type check - return true; + return GraphBuilder::validate(args, 1); } CircleNode *CircleNegGraphBuilder::build_node(const circle::OperatorT &, diff --git a/compiler/luci/import/src/Nodes/CircleNonMaxSuppressionV4.cpp b/compiler/luci/import/src/Nodes/CircleNonMaxSuppressionV4.cpp index a4ad4a53d..d3d69506b 100644 --- a/compiler/luci/import/src/Nodes/CircleNonMaxSuppressionV4.cpp +++ b/compiler/luci/import/src/Nodes/CircleNonMaxSuppressionV4.cpp @@ -61,63 +61,27 @@ bool CircleNonMaxSuppressionV4GraphBuilder::validate(const ValidateArgs &args) c * We will create multiple NonMasSuppressionV4Oout nodes to emulate this */ -void CircleNonMaxSuppressionV4GraphBuilder::build(const circle::OperatorT &op, - GraphBuilderContext *context) const +CircleNode *CircleNonMaxSuppressionV4GraphBuilder::build_node(const BuildNodeArgs &bna) const { - assert(context != nullptr); - - auto graph = context->graph(); - - const std::vector<int32_t> &inputs = op.inputs; - const std::vector<int32_t> &outputs = op.outputs; - const auto &tensors = context->reader()->tensors(); - const auto &opcodes = context->reader()->opcodes(); - auto tensors_ptr = context->reader()->tensors_ptr(); - assert(tensors_ptr != nullptr); - - std::vector<CircleNode *> input_nodes; - for (const int32_t input_tensor_index : inputs) - { - input_nodes.push_back(context->nodefinder()->node(input_tensor_index)); - } - - // Create CircleNonMaxSuppressionV4 - auto node = graph->nodes()->create<CircleNonMaxSuppressionV4>(); - node->boxes(input_nodes[0]); - node->scores(input_nodes[1]); - node->max_output_size(input_nodes[2]); - node->iou_threshold(input_nodes[3]); - node->score_threshold(input_nodes[4]); - - assert(outputs.size() == 2); - { - // Let's use name of output 0 as NonMaxSuppressionV4 name - const circle::TensorT &output_tensor = *tensors[outputs[0]]; - node->name(tensor_name(output_tensor)); - node->op_version(opcodes[op.opcode_index].get()->version); - - // NOTE We don't set quantization for NonMaxSuppressionV4 itself but to virtual outputs - } - - // Create virtual outputs of NonMaxSuppressionV4 - for (size_t n = 0; n < outputs.size(); ++n) - { - const circle::TensorT &output_tensor = *tensors[outputs[n]]; - - auto *nodeout = graph->nodes()->create<CircleNonMaxSuppressionV4Out>(); - copy_tensor_attributes(output_tensor, nodeout); - - // mark shape_status - if (tensors_ptr->Get(outputs[n])->shape() == nullptr) - nodeout->shape_status(ShapeStatus::NOSHAPE); - else - nodeout->shape_status(ShapeStatus::VALID); - - nodeout->input(node); - nodeout->index(n); - - context->nodefinder()->enroll(outputs[n], nodeout); - } + auto node = bna.context->graph()->nodes()->create<CircleNonMaxSuppressionV4>(); + + node->boxes(bna.input_nodes[0]); + node->scores(bna.input_nodes[1]); + node->max_output_size(bna.input_nodes[2]); + node->iou_threshold(bna.input_nodes[3]); + node->score_threshold(bna.input_nodes[4]); + + return node; +} + +CircleNode *CircleNonMaxSuppressionV4GraphBuilder::build_out(const BuildOutArgs &boa) const +{ + auto *nodeout = boa.node->graph()->nodes()->create<CircleNonMaxSuppressionV4Out>(); + + nodeout->input(boa.node); + nodeout->index(boa.index); + + return nodeout; } } // namespace luci diff --git a/compiler/luci/import/src/Nodes/CircleNonMaxSuppressionV5.cpp b/compiler/luci/import/src/Nodes/CircleNonMaxSuppressionV5.cpp index 241dbf5ff..d797d4cb7 100644 --- a/compiler/luci/import/src/Nodes/CircleNonMaxSuppressionV5.cpp +++ b/compiler/luci/import/src/Nodes/CircleNonMaxSuppressionV5.cpp @@ -63,64 +63,28 @@ bool CircleNonMaxSuppressionV5GraphBuilder::validate(const ValidateArgs &args) c * We will create multiple NonMasSuppressionV5Oout nodes to emulate this */ -void CircleNonMaxSuppressionV5GraphBuilder::build(const circle::OperatorT &op, - GraphBuilderContext *context) const +CircleNode *CircleNonMaxSuppressionV5GraphBuilder::build_node(const BuildNodeArgs &bna) const { - assert(context != nullptr); - - auto graph = context->graph(); - - const std::vector<int32_t> &inputs = op.inputs; - const std::vector<int32_t> &outputs = op.outputs; - const auto &tensors = context->reader()->tensors(); - const auto &opcodes = context->reader()->opcodes(); - auto tensors_ptr = context->reader()->tensors_ptr(); - assert(tensors_ptr != nullptr); - - std::vector<CircleNode *> input_nodes; - for (const int32_t input_tensor_index : inputs) - { - input_nodes.push_back(context->nodefinder()->node(input_tensor_index)); - } - - // Create CircleNonMaxSuppressionV5 - auto node = graph->nodes()->create<CircleNonMaxSuppressionV5>(); - node->boxes(input_nodes[0]); - node->scores(input_nodes[1]); - node->max_output_size(input_nodes[2]); - node->iou_threshold(input_nodes[3]); - node->score_threshold(input_nodes[4]); - node->soft_nms_sigma(input_nodes[5]); - - assert(outputs.size() == 3); - { - // Let's use name of output 0 as NonMaxSuppressionV5 name - const circle::TensorT &output_tensor = *tensors[outputs[0]]; - node->name(tensor_name(output_tensor)); - node->op_version(opcodes[op.opcode_index].get()->version); - - // NOTE We don't set quantization for NonMaxSuppressionV5 itself but to virtual outputs - } - - // Create virtual outputs of NonMaxSuppressionV5 - for (size_t n = 0; n < outputs.size(); ++n) - { - const circle::TensorT &output_tensor = *tensors[outputs[n]]; - - auto *nodeout = graph->nodes()->create<CircleNonMaxSuppressionV5Out>(); - copy_tensor_attributes(output_tensor, nodeout); - - // mark shape_status - if (tensors_ptr->Get(outputs[n])->shape() == nullptr) - nodeout->shape_status(ShapeStatus::NOSHAPE); - else - nodeout->shape_status(ShapeStatus::VALID); - - nodeout->input(node); - nodeout->index(n); - - context->nodefinder()->enroll(outputs[n], nodeout); - } + auto node = bna.context->graph()->nodes()->create<CircleNonMaxSuppressionV5>(); + + node->boxes(bna.input_nodes[0]); + node->scores(bna.input_nodes[1]); + node->max_output_size(bna.input_nodes[2]); + node->iou_threshold(bna.input_nodes[3]); + node->score_threshold(bna.input_nodes[4]); + node->soft_nms_sigma(bna.input_nodes[5]); + + return node; +} + +CircleNode *CircleNonMaxSuppressionV5GraphBuilder::build_out(const BuildOutArgs &boa) const +{ + auto *nodeout = boa.node->graph()->nodes()->create<CircleNonMaxSuppressionV5Out>(); + + nodeout->input(boa.node); + nodeout->index(boa.index); + + return nodeout; } } // namespace luci diff --git a/compiler/luci/import/src/Nodes/CircleNotEqual.cpp b/compiler/luci/import/src/Nodes/CircleNotEqual.cpp index 77e986de1..a0b8f9e4f 100644 --- a/compiler/luci/import/src/Nodes/CircleNotEqual.cpp +++ b/compiler/luci/import/src/Nodes/CircleNotEqual.cpp @@ -25,19 +25,11 @@ namespace luci bool CircleNotEqualGraphBuilder::validate(const ValidateArgs &args) const { - const auto &inputs = args.op.inputs; - const auto &outputs = args.op.outputs; - - if (inputs.size() != 2) - { + if (!GraphBuilder::validate(args, 2)) return false; - } - - if (outputs.size() != 1) - { - return false; - } + const auto &inputs = args.op.inputs; + const auto &outputs = args.op.outputs; const auto &tensors = args.reader.tensors(); if (tensors[inputs.at(0)]->type != tensors[inputs.at(1)]->type) diff --git a/compiler/luci/import/src/Nodes/CircleOneHot.cpp b/compiler/luci/import/src/Nodes/CircleOneHot.cpp index 69294e1ed..3952cc21a 100644 --- a/compiler/luci/import/src/Nodes/CircleOneHot.cpp +++ b/compiler/luci/import/src/Nodes/CircleOneHot.cpp @@ -26,17 +26,12 @@ namespace luci bool CircleOneHotGraphBuilder::validate(const ValidateArgs &args) const { - const auto &inputs = args.op.inputs; - const auto &outputs = args.op.outputs; - const auto *options = args.op.builtin_options.AsOneHotOptions(); - // Only 4 Input come refered from - if (inputs.size() != 4) - return false; - - if (outputs.size() != 1) + if (!GraphBuilder::validate(args, 4)) return false; + const auto &inputs = args.op.inputs; + const auto *options = args.op.builtin_options.AsOneHotOptions(); const auto &tensors = args.reader.tensors(); const auto &indices = tensors.at(inputs.at(0)); const auto &depth = tensors.at(inputs.at(1)); diff --git a/compiler/luci/import/src/Nodes/CirclePRelu.cpp b/compiler/luci/import/src/Nodes/CirclePRelu.cpp index c07920f7c..7c81f04bb 100644 --- a/compiler/luci/import/src/Nodes/CirclePRelu.cpp +++ b/compiler/luci/import/src/Nodes/CirclePRelu.cpp @@ -25,13 +25,7 @@ namespace luci bool CirclePReluGraphBuilder::validate(const ValidateArgs &args) const { - if (args.op.inputs.size() != 2) - return false; - - if (args.op.outputs.size() != 1) - return false; - - return true; + return GraphBuilder::validate(args, 2); } CircleNode *CirclePReluGraphBuilder::build_node(const circle::OperatorT &, diff --git a/compiler/luci/import/src/Nodes/CirclePad.cpp b/compiler/luci/import/src/Nodes/CirclePad.cpp index 999173b90..67dce6dee 100644 --- a/compiler/luci/import/src/Nodes/CirclePad.cpp +++ b/compiler/luci/import/src/Nodes/CirclePad.cpp @@ -25,12 +25,8 @@ namespace luci bool CirclePadGraphBuilder::validate(const ValidateArgs &args) const { - if (args.op.inputs.size() != 2) - return false; - // TODO do attribute checks - - return true; + return GraphBuilder::validate(args, 2); } CircleNode *CirclePadGraphBuilder::build_node(const circle::OperatorT &op, diff --git a/compiler/luci/import/src/Nodes/CirclePadV2.cpp b/compiler/luci/import/src/Nodes/CirclePadV2.cpp index 493876e68..84a45722a 100644 --- a/compiler/luci/import/src/Nodes/CirclePadV2.cpp +++ b/compiler/luci/import/src/Nodes/CirclePadV2.cpp @@ -25,13 +25,7 @@ namespace luci bool CirclePadV2GraphBuilder::validate(const ValidateArgs &args) const { - if (args.op.inputs.size() != 3) - return false; - - if (args.op.outputs.size() != 1) - return false; - - return true; + return GraphBuilder::validate(args, 3); } CircleNode *CirclePadV2GraphBuilder::build_node(const circle::OperatorT &op, diff --git a/compiler/luci/import/src/Nodes/CirclePow.cpp b/compiler/luci/import/src/Nodes/CirclePow.cpp index def012614..1d2d41607 100644 --- a/compiler/luci/import/src/Nodes/CirclePow.cpp +++ b/compiler/luci/import/src/Nodes/CirclePow.cpp @@ -25,13 +25,7 @@ namespace luci bool CirclePowGraphBuilder::validate(const ValidateArgs &args) const { - if (args.op.inputs.size() != 2) - return false; - - if (args.op.outputs.size() != 1) - return false; - - return true; + return GraphBuilder::validate(args, 2); } CircleNode *CirclePowGraphBuilder::build_node(const circle::OperatorT &, diff --git a/compiler/luci/import/src/Nodes/CircleRange.cpp b/compiler/luci/import/src/Nodes/CircleRange.cpp index 38dc44ed6..d3b5afc95 100644 --- a/compiler/luci/import/src/Nodes/CircleRange.cpp +++ b/compiler/luci/import/src/Nodes/CircleRange.cpp @@ -24,11 +24,8 @@ namespace luci { bool CircleRangeGraphBuilder::validate(const ValidateArgs &args) const { - if (args.op.inputs.size() != 3) - return false; - // TODO Support type check - return true; + return GraphBuilder::validate(args, 3); } CircleNode *CircleRangeGraphBuilder::build_node(const circle::OperatorT &, diff --git a/compiler/luci/import/src/Nodes/CircleRank.cpp b/compiler/luci/import/src/Nodes/CircleRank.cpp index 12658b192..afebb9509 100644 --- a/compiler/luci/import/src/Nodes/CircleRank.cpp +++ b/compiler/luci/import/src/Nodes/CircleRank.cpp @@ -24,13 +24,7 @@ namespace luci { bool CircleRankGraphBuilder::validate(const ValidateArgs &args) const { - if (args.op.inputs.size() != 1) - return false; - - if (args.op.outputs.size() != 1) - return false; - - return true; + return GraphBuilder::validate(args, 1); } CircleNode *CircleRankGraphBuilder::build_node(const circle::OperatorT &, diff --git a/compiler/luci/import/src/Nodes/CircleReduceAny.cpp b/compiler/luci/import/src/Nodes/CircleReduceAny.cpp index 21a821951..13205dd7a 100644 --- a/compiler/luci/import/src/Nodes/CircleReduceAny.cpp +++ b/compiler/luci/import/src/Nodes/CircleReduceAny.cpp @@ -23,13 +23,11 @@ namespace luci bool CircleReduceAnyGraphBuilder::validate(const ValidateArgs &args) const { - const auto &inputs = args.op.inputs; - const auto &outputs = args.op.outputs; - if (inputs.size() != 2) - return false; - if (outputs.size() != 1) + if (!GraphBuilder::validate(args, 2)) return false; + const auto &inputs = args.op.inputs; + const auto &outputs = args.op.outputs; const auto &tensors = args.reader.tensors(); const auto &tensor_0 = tensors.at(inputs.at(0)); const auto &tensor_1 = tensors.at(inputs.at(1)); diff --git a/compiler/luci/import/src/Nodes/CircleReduceProd.cpp b/compiler/luci/import/src/Nodes/CircleReduceProd.cpp index 5f054586e..3549c1a18 100644 --- a/compiler/luci/import/src/Nodes/CircleReduceProd.cpp +++ b/compiler/luci/import/src/Nodes/CircleReduceProd.cpp @@ -23,12 +23,10 @@ namespace luci bool CircleReduceProdGraphBuilder::validate(const ValidateArgs &args) const { - const auto &inputs = args.op.inputs; - if (inputs.size() != 2) - return false; - if (args.op.outputs.size() != 1) + if (!GraphBuilder::validate(args, 2)) return false; + const auto &inputs = args.op.inputs; const auto &tensors = args.reader.tensors(); const auto &tensor_1 = tensors.at(inputs.at(1)); diff --git a/compiler/luci/import/src/Nodes/CircleRelu.cpp b/compiler/luci/import/src/Nodes/CircleRelu.cpp index 8e1c32a3a..73b8ffee8 100644 --- a/compiler/luci/import/src/Nodes/CircleRelu.cpp +++ b/compiler/luci/import/src/Nodes/CircleRelu.cpp @@ -25,13 +25,7 @@ namespace luci bool CircleReluGraphBuilder::validate(const ValidateArgs &args) const { - if (args.op.inputs.size() != 1) - return false; - - if (args.op.outputs.size() != 1) - return false; - - return true; + return GraphBuilder::validate(args, 1); } CircleNode *CircleReluGraphBuilder::build_node(const circle::OperatorT &, diff --git a/compiler/luci/import/src/Nodes/CircleRelu6.cpp b/compiler/luci/import/src/Nodes/CircleRelu6.cpp index 0283d7350..ab957eda8 100644 --- a/compiler/luci/import/src/Nodes/CircleRelu6.cpp +++ b/compiler/luci/import/src/Nodes/CircleRelu6.cpp @@ -25,13 +25,7 @@ namespace luci bool CircleRelu6GraphBuilder::validate(const ValidateArgs &args) const { - if (args.op.inputs.size() != 1) - return false; - - if (args.op.outputs.size() != 1) - return false; - - return true; + return GraphBuilder::validate(args, 1); } CircleNode *CircleRelu6GraphBuilder::build_node(const circle::OperatorT &, diff --git a/compiler/luci/import/src/Nodes/CircleReluN1To1.cpp b/compiler/luci/import/src/Nodes/CircleReluN1To1.cpp index 7f517bc0d..4987f3be2 100644 --- a/compiler/luci/import/src/Nodes/CircleReluN1To1.cpp +++ b/compiler/luci/import/src/Nodes/CircleReluN1To1.cpp @@ -25,15 +25,8 @@ namespace luci bool CircleReluN1To1GraphBuilder::validate(const ValidateArgs &args) const { - if (args.op.inputs.size() != 1) - return false; - - if (args.op.outputs.size() != 1) - return false; - // TODO check dtypes - - return true; + return GraphBuilder::validate(args, 1); } CircleNode *CircleReluN1To1GraphBuilder::build_node(const circle::OperatorT &, diff --git a/compiler/luci/import/src/Nodes/CircleReshape.cpp b/compiler/luci/import/src/Nodes/CircleReshape.cpp index 996ae9d20..401dff0fc 100644 --- a/compiler/luci/import/src/Nodes/CircleReshape.cpp +++ b/compiler/luci/import/src/Nodes/CircleReshape.cpp @@ -30,6 +30,19 @@ bool CircleReshapeGraphBuilder::validate(const ValidateArgs &args) const if (args.op.outputs.size() != 1) return false; + // for two inputs, check if type is S32 + if (args.op.inputs.size() == 2) + { + const auto &inputs = args.op.inputs; + const auto &tensors = args.reader.tensors(); + const auto &tensor_in = tensors.at(inputs.at(1)); + + // NOTE fix this if there is any other case + // TensorFlow lite and circle only supports S32 + if (tensor_in->type != circle::TensorType::TensorType_INT32) + return false; + } + return true; } @@ -53,6 +66,7 @@ static CircleNode *create_shape_node(const std::vector<int32_t> &shape, loco::Gr { shape_node->at<loco::DataType::S32>(i) = shape[i]; } + shape_node->name("Reshape/shape"); return shape_node; } @@ -73,6 +87,7 @@ CircleNode *CircleReshapeGraphBuilder::build_node(const circle::OperatorT &op, shape_node = graph->nodes()->create<CircleOutputDummy>(); shape_node->dtype(loco::DataType::S32); shape_node->rank(0); + shape_node->name("Reshape/dummy"); } } diff --git a/compiler/luci/import/src/Nodes/CircleResizeBilinear.cpp b/compiler/luci/import/src/Nodes/CircleResizeBilinear.cpp index 0fccb7b44..c751b245c 100644 --- a/compiler/luci/import/src/Nodes/CircleResizeBilinear.cpp +++ b/compiler/luci/import/src/Nodes/CircleResizeBilinear.cpp @@ -16,7 +16,6 @@ #include "luci/Import/Nodes/CircleResizeBilinear.h" -#include <luci/IR/Nodes/CircleConst.h> #include <luci/IR/Nodes/CircleResizeBilinear.h> namespace luci @@ -24,13 +23,7 @@ namespace luci bool CircleResizeBilinearGraphBuilder::validate(const ValidateArgs &args) const { - if (args.op.inputs.size() != 2) - return false; - - if (args.op.outputs.size() != 1) - return false; - - return true; + return GraphBuilder::validate(args, 2); } CircleNode *CircleResizeBilinearGraphBuilder::build_node(const circle::OperatorT &op, diff --git a/compiler/luci/import/src/Nodes/CircleResizeNearestNeighbor.cpp b/compiler/luci/import/src/Nodes/CircleResizeNearestNeighbor.cpp index 324323f59..df7517fe9 100644 --- a/compiler/luci/import/src/Nodes/CircleResizeNearestNeighbor.cpp +++ b/compiler/luci/import/src/Nodes/CircleResizeNearestNeighbor.cpp @@ -16,7 +16,6 @@ #include "luci/Import/Nodes/CircleResizeNearestNeighbor.h" -#include <luci/IR/Nodes/CircleConst.h> #include <luci/IR/Nodes/CircleResizeNearestNeighbor.h> namespace luci @@ -24,17 +23,11 @@ namespace luci bool CircleResizeNearestNeighborGraphBuilder::validate(const ValidateArgs &args) const { - if (args.op.inputs.size() != 2) - return false; - - if (args.op.outputs.size() != 1) - return false; - - return true; + return GraphBuilder::validate(args, 2); } CircleNode *CircleResizeNearestNeighborGraphBuilder::build_node( - const circle::OperatorT &op, const std::vector<CircleNode *> &inputs, loco::Graph *graph) const + const circle::OperatorT &op, const std::vector<CircleNode *> &inputs, loco::Graph *graph) const { auto *node = graph->nodes()->create<CircleResizeNearestNeighbor>(); node->input(inputs.at(0)); diff --git a/compiler/luci/import/src/Nodes/CircleReverseSequence.cpp b/compiler/luci/import/src/Nodes/CircleReverseSequence.cpp index ad11d4c63..2fbb7a87c 100644 --- a/compiler/luci/import/src/Nodes/CircleReverseSequence.cpp +++ b/compiler/luci/import/src/Nodes/CircleReverseSequence.cpp @@ -25,14 +25,11 @@ namespace luci bool CircleReverseSequenceGraphBuilder::validate(const ValidateArgs &args) const { - const auto &inputs = args.op.inputs; - const auto &outputs = args.op.outputs; - - if (inputs.size() != 2) - return false; - if (outputs.size() != 1) + if (!GraphBuilder::validate(args, 2)) return false; + const auto &inputs = args.op.inputs; + const auto &outputs = args.op.outputs; const auto &tensors = args.reader.tensors(); const auto &tensor_in = tensors.at(inputs.at(0)); const auto &tensor_lengths = tensors.at(inputs.at(1)); diff --git a/compiler/luci/import/src/Nodes/CircleReverseV2.cpp b/compiler/luci/import/src/Nodes/CircleReverseV2.cpp index e2e53bb4b..ca7653201 100644 --- a/compiler/luci/import/src/Nodes/CircleReverseV2.cpp +++ b/compiler/luci/import/src/Nodes/CircleReverseV2.cpp @@ -25,14 +25,11 @@ namespace luci bool CircleReverseV2GraphBuilder::validate(const ValidateArgs &args) const { - const auto &inputs = args.op.inputs; - const auto &outputs = args.op.outputs; - - if (inputs.size() != 2) - return false; - if (outputs.size() != 1) + if (!GraphBuilder::validate(args, 2)) return false; + const auto &inputs = args.op.inputs; + const auto &outputs = args.op.outputs; const auto &tensors = args.reader.tensors(); const auto &tensor_in = tensors.at(inputs.at(0)); const auto &tensor_axis = tensors.at(inputs.at(1)); diff --git a/compiler/luci/import/src/Nodes/CircleRound.cpp b/compiler/luci/import/src/Nodes/CircleRound.cpp index ad77f9f03..d13e0fafe 100644 --- a/compiler/luci/import/src/Nodes/CircleRound.cpp +++ b/compiler/luci/import/src/Nodes/CircleRound.cpp @@ -25,14 +25,11 @@ namespace luci bool CircleRoundGraphBuilder::validate(const ValidateArgs &args) const { - const auto &inputs = args.op.inputs; - const auto &outputs = args.op.outputs; - - if (inputs.size() != 1) - return false; - if (outputs.size() != 1) + if (!GraphBuilder::validate(args, 1)) return false; + const auto &inputs = args.op.inputs; + const auto &outputs = args.op.outputs; // Must be one of the following types // bfloat16, half (float16), float32, float64, complex64, complex128 // Currently, circle supports float16, float32, complex64 diff --git a/compiler/luci/import/src/Nodes/CircleRsqrt.cpp b/compiler/luci/import/src/Nodes/CircleRsqrt.cpp index ae05fbbf9..a9ca90832 100644 --- a/compiler/luci/import/src/Nodes/CircleRsqrt.cpp +++ b/compiler/luci/import/src/Nodes/CircleRsqrt.cpp @@ -25,10 +25,10 @@ namespace luci bool CircleRsqrtGraphBuilder::validate(const ValidateArgs &args) const { - const auto &inputs = args.op.inputs; - if (inputs.size() != 1) + if (!GraphBuilder::validate(args, 1)) return false; + const auto &inputs = args.op.inputs; // Must be one of the following types // bfloat16, half (float16), float32, float64, complex64, complex128 // Currently, circle supports float16, float32, complex64 @@ -36,6 +36,8 @@ bool CircleRsqrtGraphBuilder::validate(const ValidateArgs &args) const const auto &tensor = tensors.at(inputs.at(0)); switch (tensor->type) { + case circle::TensorType_UINT8: + case circle::TensorType_INT16: case circle::TensorType_FLOAT16: case circle::TensorType_FLOAT32: case circle::TensorType_COMPLEX64: diff --git a/compiler/luci/import/src/Nodes/CircleScatterNd.cpp b/compiler/luci/import/src/Nodes/CircleScatterNd.cpp index 7f86aeb74..f8c175110 100644 --- a/compiler/luci/import/src/Nodes/CircleScatterNd.cpp +++ b/compiler/luci/import/src/Nodes/CircleScatterNd.cpp @@ -25,10 +25,10 @@ namespace luci bool CircleScatterNdGraphBuilder::validate(const ValidateArgs &args) const { - const auto &inputs = args.op.inputs; - if (inputs.size() != 3) + if (!GraphBuilder::validate(args, 3)) return false; + const auto &inputs = args.op.inputs; // indices must have the same type as shape const auto &tensors = args.reader.tensors(); diff --git a/compiler/luci/import/src/Nodes/CircleSegmentSum.cpp b/compiler/luci/import/src/Nodes/CircleSegmentSum.cpp index fb84e5d52..bfa333e8d 100644 --- a/compiler/luci/import/src/Nodes/CircleSegmentSum.cpp +++ b/compiler/luci/import/src/Nodes/CircleSegmentSum.cpp @@ -25,13 +25,11 @@ namespace luci bool CircleSegmentSumGraphBuilder::validate(const ValidateArgs &args) const { - const auto &inputs = args.op.inputs; - const auto &outputs = args.op.outputs; - if (inputs.size() != 2) - return false; - if (outputs.size() != 1) + if (!GraphBuilder::validate(args, 2)) return false; + const auto &inputs = args.op.inputs; + const auto &outputs = args.op.outputs; const auto &tensors = args.reader.tensors(); const auto &tensor_in = tensors.at(inputs.at(0)); const auto &tensor_out = tensors.at(outputs[0]); diff --git a/compiler/luci/import/src/Nodes/CircleSelect.cpp b/compiler/luci/import/src/Nodes/CircleSelect.cpp index 1e649f1e0..36a5fa8a8 100644 --- a/compiler/luci/import/src/Nodes/CircleSelect.cpp +++ b/compiler/luci/import/src/Nodes/CircleSelect.cpp @@ -25,13 +25,10 @@ namespace luci bool CircleSelectGraphBuilder::validate(const ValidateArgs &args) const { - const auto &inputs = args.op.inputs; - const auto &outputs = args.op.outputs; - if (inputs.size() != 3) - return false; - if (outputs.size() != 1) + if (!GraphBuilder::validate(args, 3)) return false; + const auto &inputs = args.op.inputs; const auto &tensors = args.reader.tensors(); const auto &tensor = tensors.at(inputs.at(0)); if (tensor->type != circle::TensorType_BOOL) diff --git a/compiler/luci/import/src/Nodes/CircleSelectV2.cpp b/compiler/luci/import/src/Nodes/CircleSelectV2.cpp index e6dd04de0..556c8fa33 100644 --- a/compiler/luci/import/src/Nodes/CircleSelectV2.cpp +++ b/compiler/luci/import/src/Nodes/CircleSelectV2.cpp @@ -25,13 +25,10 @@ namespace luci bool CircleSelectV2GraphBuilder::validate(const ValidateArgs &args) const { - const auto &inputs = args.op.inputs; - const auto &outputs = args.op.outputs; - if (inputs.size() != 3) - return false; - if (outputs.size() != 1) + if (!GraphBuilder::validate(args, 3)) return false; + const auto &inputs = args.op.inputs; const auto &tensors = args.reader.tensors(); const auto &condition = tensors.at(inputs.at(0)); if (condition->type != circle::TensorType_BOOL) diff --git a/compiler/luci/import/src/Nodes/CircleShape.cpp b/compiler/luci/import/src/Nodes/CircleShape.cpp index bd7dfc9d9..86c0bf59b 100644 --- a/compiler/luci/import/src/Nodes/CircleShape.cpp +++ b/compiler/luci/import/src/Nodes/CircleShape.cpp @@ -25,16 +25,8 @@ namespace luci bool CircleShapeGraphBuilder::validate(const ValidateArgs &args) const { - const auto &inputs = args.op.inputs; - const auto &outputs = args.op.outputs; - if (inputs.size() != 1) - return false; - if (outputs.size() != 1) - return false; - // TODO check shape, dtype - - return true; + return GraphBuilder::validate(args, 1); } CircleNode *CircleShapeGraphBuilder::build_node(const circle::OperatorT &op, diff --git a/compiler/luci/import/src/Nodes/CircleSin.cpp b/compiler/luci/import/src/Nodes/CircleSin.cpp index 4b245ef6b..22f461123 100644 --- a/compiler/luci/import/src/Nodes/CircleSin.cpp +++ b/compiler/luci/import/src/Nodes/CircleSin.cpp @@ -25,12 +25,10 @@ namespace luci bool CircleSinGraphBuilder::validate(const ValidateArgs &args) const { - const auto &inputs = args.op.inputs; - if (inputs.size() != 1) - return false; - if (args.op.outputs.size() != 1) + if (!GraphBuilder::validate(args, 1)) return false; + const auto &inputs = args.op.inputs; // input type check const auto &tensors = args.reader.tensors(); const auto &tensor = tensors.at(inputs.at(0)); diff --git a/compiler/luci/import/src/Nodes/CircleSlice.cpp b/compiler/luci/import/src/Nodes/CircleSlice.cpp index 8601fbf21..4166040b3 100644 --- a/compiler/luci/import/src/Nodes/CircleSlice.cpp +++ b/compiler/luci/import/src/Nodes/CircleSlice.cpp @@ -27,14 +27,8 @@ namespace luci bool CircleSliceGraphBuilder::validate(const ValidateArgs &args) const { - if (args.op.inputs.size() != 3) - return false; - if (args.op.outputs.size() != 1) - return false; - // TODO check shapes and types - - return true; + return GraphBuilder::validate(args, 3); } CircleNode *CircleSliceGraphBuilder::build_node(const circle::OperatorT &, diff --git a/compiler/luci/import/src/Nodes/CircleSoftmax.cpp b/compiler/luci/import/src/Nodes/CircleSoftmax.cpp index 0ef0b5418..e79914455 100644 --- a/compiler/luci/import/src/Nodes/CircleSoftmax.cpp +++ b/compiler/luci/import/src/Nodes/CircleSoftmax.cpp @@ -25,12 +25,8 @@ namespace luci bool CircleSoftmaxGraphBuilder::validate(const ValidateArgs &args) const { - if (args.op.inputs.size() != 1) - return false; - // TODO do attribute checks - - return true; + return GraphBuilder::validate(args, 1); } CircleNode *CircleSoftmaxGraphBuilder::build_node(const circle::OperatorT &op, diff --git a/compiler/luci/import/src/Nodes/CircleSpaceToDepth.cpp b/compiler/luci/import/src/Nodes/CircleSpaceToDepth.cpp index 8ccd55dc6..2152b65c9 100644 --- a/compiler/luci/import/src/Nodes/CircleSpaceToDepth.cpp +++ b/compiler/luci/import/src/Nodes/CircleSpaceToDepth.cpp @@ -27,13 +27,8 @@ namespace luci bool CircleSpaceToDepthGraphBuilder::validate(const ValidateArgs &args) const { - const auto &inputs = args.op.inputs; - if (inputs.size() != 1) - return false; - // TODO do attribute checks - - return true; + return GraphBuilder::validate(args, 1); } CircleNode *CircleSpaceToDepthGraphBuilder::build_node(const circle::OperatorT &op, diff --git a/compiler/luci/import/src/Nodes/CircleSparseToDense.cpp b/compiler/luci/import/src/Nodes/CircleSparseToDense.cpp index ac756b1f3..ce0688bb9 100644 --- a/compiler/luci/import/src/Nodes/CircleSparseToDense.cpp +++ b/compiler/luci/import/src/Nodes/CircleSparseToDense.cpp @@ -25,10 +25,7 @@ namespace luci bool CircleSparseToDenseGraphBuilder::validate(const ValidateArgs &args) const { - if (args.op.inputs.size() != 4) - return false; - - return true; + return GraphBuilder::validate(args, 4); } CircleNode *CircleSparseToDenseGraphBuilder::build_node(const circle::OperatorT &op, diff --git a/compiler/luci/import/src/Nodes/CircleSplit.cpp b/compiler/luci/import/src/Nodes/CircleSplit.cpp index 07b6cc939..d0a24aae3 100644 --- a/compiler/luci/import/src/Nodes/CircleSplit.cpp +++ b/compiler/luci/import/src/Nodes/CircleSplit.cpp @@ -58,62 +58,27 @@ bool CircleSplitGraphBuilder::validate(const ValidateArgs &args) const * \- CircleSplitOut --- FullyConnected --- */ -void CircleSplitGraphBuilder::build(const circle::OperatorT &op, GraphBuilderContext *context) const +CircleNode *CircleSplitGraphBuilder::build_node(const BuildNodeArgs &bna) const { - assert(context != nullptr); + auto node = bna.context->graph()->nodes()->create<CircleSplit>(); - auto graph = context->graph(); + node->split_dim(bna.input_nodes[0]); + node->input(bna.input_nodes[1]); - const std::vector<int32_t> &inputs = op.inputs; - const std::vector<int32_t> &outputs = op.outputs; - const auto &tensors = context->reader()->tensors(); - const auto &opcodes = context->reader()->opcodes(); - auto tensors_ptr = context->reader()->tensors_ptr(); - assert(tensors_ptr != nullptr); + const auto *options = bna.op.builtin_options.AsSplitOptions(); + node->num_split(options->num_splits); - std::vector<CircleNode *> input_nodes; - for (const int32_t input_tensor_index : inputs) - { - input_nodes.push_back(context->nodefinder()->node(input_tensor_index)); - } + return node; +} - // Create CircleSplit - auto node = graph->nodes()->create<CircleSplit>(); - node->split_dim(input_nodes[0]); - node->input(input_nodes[1]); +CircleNode *CircleSplitGraphBuilder::build_out(const BuildOutArgs &boa) const +{ + auto *nodeout = boa.node->graph()->nodes()->create<CircleSplitOut>(); - const auto *options = op.builtin_options.AsSplitOptions(); - node->num_split(options->num_splits); + nodeout->input(boa.node); + nodeout->index(boa.index); - assert(outputs.size() > 0); - assert(int32_t(outputs.size()) == options->num_splits); - { - // Let's use name of output 0 as Split name - const circle::TensorT &output_tensor = *tensors[outputs[0]]; - node->name(tensor_name(output_tensor)); - node->op_version(opcodes[op.opcode_index].get()->version); - - // NOTE We don't set quantization for Split itself but to virtual outputs - } - - // Create virtual outputs of Split - for (int32_t n = 0; n < options->num_splits; ++n) - { - const circle::TensorT &output_tensor = *tensors[outputs[n]]; - - auto *nodeout = graph->nodes()->create<CircleSplitOut>(); - copy_tensor_attributes(output_tensor, nodeout); - // mark shape_status - if (tensors_ptr->Get(outputs[n])->shape() == nullptr) - nodeout->shape_status(ShapeStatus::NOSHAPE); - else - nodeout->shape_status(ShapeStatus::VALID); - - nodeout->input(node); - nodeout->index(n); - - context->nodefinder()->enroll(outputs[n], nodeout); - } + return nodeout; } } // namespace luci diff --git a/compiler/luci/import/src/Nodes/CircleSplitV.cpp b/compiler/luci/import/src/Nodes/CircleSplitV.cpp index 7c6e83e17..76cbf7046 100644 --- a/compiler/luci/import/src/Nodes/CircleSplitV.cpp +++ b/compiler/luci/import/src/Nodes/CircleSplitV.cpp @@ -58,64 +58,30 @@ bool CircleSplitVGraphBuilder::validate(const ValidateArgs &args) const * \- CircleSplitVOut --- FullyConnected --- */ -void CircleSplitVGraphBuilder::build(const circle::OperatorT &op, - GraphBuilderContext *context) const +CircleNode *CircleSplitVGraphBuilder::build_node(const BuildNodeArgs &bna) const { - assert(context != nullptr); - - auto graph = context->graph(); - - const std::vector<int32_t> &inputs = op.inputs; - const std::vector<int32_t> &outputs = op.outputs; - const auto &tensors = context->reader()->tensors(); - const auto &opcodes = context->reader()->opcodes(); - auto tensors_ptr = context->reader()->tensors_ptr(); - assert(tensors_ptr != nullptr); - - std::vector<CircleNode *> input_nodes; - for (const int32_t input_tensor_index : inputs) - { - input_nodes.push_back(context->nodefinder()->node(input_tensor_index)); - } - - // Create CircleSplitV - auto node = graph->nodes()->create<CircleSplitV>(); - node->input(input_nodes[0]); - node->size_splits(input_nodes[1]); - node->split_dim(input_nodes[2]); - - const auto *options = op.builtin_options.AsSplitVOptions(); + auto node = bna.context->graph()->nodes()->create<CircleSplitV>(); + + node->input(bna.input_nodes[0]); + node->size_splits(bna.input_nodes[1]); + node->split_dim(bna.input_nodes[2]); + + const auto *options = bna.op.builtin_options.AsSplitVOptions(); node->num_split(options->num_splits); - assert(outputs.size() > 0); - assert(int32_t(outputs.size()) == options->num_splits); - { - // Let's use name of output 0 as Split name - const circle::TensorT &output_tensor = *tensors[outputs[0]]; - node->name(tensor_name(output_tensor)); - node->op_version(opcodes[op.opcode_index].get()->version); - - // NOTE We don't set quantization for Split itself but to virtual outputs - } - - // Create virtual outputs of Split - for (int32_t n = 0; n < options->num_splits; ++n) - { - const circle::TensorT &output_tensor = *tensors[outputs[n]]; - - auto *nodeout = graph->nodes()->create<CircleSplitVOut>(); - copy_tensor_attributes(output_tensor, nodeout); - // mark shape_status - if (tensors_ptr->Get(outputs[n])->shape() == nullptr) - nodeout->shape_status(ShapeStatus::NOSHAPE); - else - nodeout->shape_status(ShapeStatus::VALID); - - nodeout->input(node); - nodeout->index(n); - - context->nodefinder()->enroll(outputs[n], nodeout); - } + assert(int32_t(bna.op.outputs.size()) == options->num_splits); + + return node; +} + +CircleNode *CircleSplitVGraphBuilder::build_out(const BuildOutArgs &boa) const +{ + auto *nodeout = boa.node->graph()->nodes()->create<CircleSplitVOut>(); + + nodeout->input(boa.node); + nodeout->index(boa.index); + + return nodeout; } } // namespace luci diff --git a/compiler/luci/import/src/Nodes/CircleSqrt.cpp b/compiler/luci/import/src/Nodes/CircleSqrt.cpp index c8beaee0d..b1fdf7996 100644 --- a/compiler/luci/import/src/Nodes/CircleSqrt.cpp +++ b/compiler/luci/import/src/Nodes/CircleSqrt.cpp @@ -25,10 +25,7 @@ namespace luci bool CircleSqrtGraphBuilder::validate(const ValidateArgs &args) const { - if (args.op.inputs.size() != 1) - return false; - - return true; + return GraphBuilder::validate(args, 1); } CircleNode *CircleSqrtGraphBuilder::build_node(const circle::OperatorT &, diff --git a/compiler/luci/import/src/Nodes/CircleSquare.cpp b/compiler/luci/import/src/Nodes/CircleSquare.cpp index b5ba048d7..7ff2b84e6 100644 --- a/compiler/luci/import/src/Nodes/CircleSquare.cpp +++ b/compiler/luci/import/src/Nodes/CircleSquare.cpp @@ -25,10 +25,10 @@ namespace luci bool CircleSquareGraphBuilder::validate(const ValidateArgs &args) const { - const auto &inputs = args.op.inputs; - if (inputs.size() != 1) + if (!GraphBuilder::validate(args, 1)) return false; + const auto &inputs = args.op.inputs; // Must be one of the following types // bfloat16, half (float16), float32, float64, complex64, complex128 // Currently, circle supports float16, float32, complex64 diff --git a/compiler/luci/import/src/Nodes/CircleSquaredDifference.cpp b/compiler/luci/import/src/Nodes/CircleSquaredDifference.cpp index 6deae94c5..f4e193713 100644 --- a/compiler/luci/import/src/Nodes/CircleSquaredDifference.cpp +++ b/compiler/luci/import/src/Nodes/CircleSquaredDifference.cpp @@ -25,15 +25,11 @@ namespace luci bool CircleSquaredDifferenceGraphBuilder::validate(const ValidateArgs &args) const { - const auto &inputs = args.op.inputs; - const auto &outputs = args.op.outputs; - - if (inputs.size() != 2) - return false; - - if (outputs.size() != 1) + if (!GraphBuilder::validate(args, 2)) return false; + const auto &inputs = args.op.inputs; + const auto &outputs = args.op.outputs; // Inputs must be one of the following types // bfloat16, half(float16), float32, float64, int32, int64, complex64, complex128 const auto &tensors = args.reader.tensors(); diff --git a/compiler/luci/import/src/Nodes/CircleSqueeze.cpp b/compiler/luci/import/src/Nodes/CircleSqueeze.cpp index 32792c266..d24d8166c 100644 --- a/compiler/luci/import/src/Nodes/CircleSqueeze.cpp +++ b/compiler/luci/import/src/Nodes/CircleSqueeze.cpp @@ -16,7 +16,6 @@ #include "luci/Import/Nodes/CircleSqueeze.h" -#include <luci/IR/Nodes/CircleConst.h> #include <luci/IR/Nodes/CircleSqueeze.h> namespace luci @@ -24,13 +23,7 @@ namespace luci bool CircleSqueezeGraphBuilder::validate(const ValidateArgs &args) const { - if (args.op.inputs.size() != 1) - return false; - - if (args.op.outputs.size() != 1) - return false; - - return true; + return GraphBuilder::validate(args, 1); } CircleNode *CircleSqueezeGraphBuilder::build_node(const circle::OperatorT &op, diff --git a/compiler/luci/import/src/Nodes/CircleStridedSlice.cpp b/compiler/luci/import/src/Nodes/CircleStridedSlice.cpp index 8f943a682..ca8259cac 100644 --- a/compiler/luci/import/src/Nodes/CircleStridedSlice.cpp +++ b/compiler/luci/import/src/Nodes/CircleStridedSlice.cpp @@ -27,14 +27,8 @@ namespace luci bool CircleStridedSliceGraphBuilder::validate(const ValidateArgs &args) const { - if (args.op.inputs.size() != 4) - return false; - if (args.op.outputs.size() != 1) - return false; - // TODO check shapes and types - - return true; + return GraphBuilder::validate(args, 4); } CircleNode *CircleStridedSliceGraphBuilder::build_node(const circle::OperatorT &op, diff --git a/compiler/luci/import/src/Nodes/CircleSub.cpp b/compiler/luci/import/src/Nodes/CircleSub.cpp index 9acf83d40..c3978f218 100644 --- a/compiler/luci/import/src/Nodes/CircleSub.cpp +++ b/compiler/luci/import/src/Nodes/CircleSub.cpp @@ -25,13 +25,7 @@ namespace luci bool CircleSubGraphBuilder::validate(const ValidateArgs &args) const { - if (args.op.inputs.size() != 2) - return false; - - if (args.op.outputs.size() != 1) - return false; - - return true; + return GraphBuilder::validate(args, 2); } CircleNode *CircleSubGraphBuilder::build_node(const circle::OperatorT &op, diff --git a/compiler/luci/import/src/Nodes/CircleSum.cpp b/compiler/luci/import/src/Nodes/CircleSum.cpp index bd3cb6239..e348a62d9 100644 --- a/compiler/luci/import/src/Nodes/CircleSum.cpp +++ b/compiler/luci/import/src/Nodes/CircleSum.cpp @@ -23,10 +23,7 @@ namespace luci bool CircleSumGraphBuilder::validate(const ValidateArgs &args) const { - if (args.op.inputs.size() != 2) - return false; - - return true; + return GraphBuilder::validate(args, 2); } CircleNode *CircleSumGraphBuilder::build_node(const circle::OperatorT &op, diff --git a/compiler/luci/import/src/Nodes/CircleTanh.cpp b/compiler/luci/import/src/Nodes/CircleTanh.cpp index 018f5701b..95625a0e4 100644 --- a/compiler/luci/import/src/Nodes/CircleTanh.cpp +++ b/compiler/luci/import/src/Nodes/CircleTanh.cpp @@ -25,13 +25,11 @@ namespace luci bool CircleTanhGraphBuilder::validate(const ValidateArgs &args) const { - const auto &inputs = args.op.inputs; - if (inputs.size() != 1) - return false; - const auto &outputs = args.op.outputs; - if (outputs.size() != 1) + if (!GraphBuilder::validate(args, 1)) return false; + const auto &inputs = args.op.inputs; + const auto &outputs = args.op.outputs; const auto &tensors = args.reader.tensors(); if (tensors.at(inputs.at(0))->type != tensors.at(outputs[0])->type) return false; diff --git a/compiler/luci/import/src/Nodes/CircleTile.cpp b/compiler/luci/import/src/Nodes/CircleTile.cpp index bc6f320ba..6da44130c 100644 --- a/compiler/luci/import/src/Nodes/CircleTile.cpp +++ b/compiler/luci/import/src/Nodes/CircleTile.cpp @@ -25,15 +25,11 @@ namespace luci bool CircleTileGraphBuilder::validate(const ValidateArgs &args) const { - auto inputs = args.op.inputs; - auto outputs = args.op.outputs; - - if (inputs.size() != 2) - return false; - - if (outputs.size() != 1) + if (!GraphBuilder::validate(args, 2)) return false; + auto inputs = args.op.inputs; + auto outputs = args.op.outputs; // Multiples (inputs.at(1)) must be one of the following types // int32, int64 const auto &tensors = args.reader.tensors(); diff --git a/compiler/luci/import/src/Nodes/CircleTopKV2.cpp b/compiler/luci/import/src/Nodes/CircleTopKV2.cpp index f0677de86..49f858798 100644 --- a/compiler/luci/import/src/Nodes/CircleTopKV2.cpp +++ b/compiler/luci/import/src/Nodes/CircleTopKV2.cpp @@ -59,59 +59,24 @@ bool CircleTopKV2GraphBuilder::validate(const ValidateArgs &args) const * \- CircleTopKV2Out --- FullyConnected --- */ -void CircleTopKV2GraphBuilder::build(const circle::OperatorT &op, - GraphBuilderContext *context) const +CircleNode *CircleTopKV2GraphBuilder::build_node(const BuildNodeArgs &bna) const { - assert(context != nullptr); - - auto graph = context->graph(); - - const std::vector<int32_t> &inputs = op.inputs; - const std::vector<int32_t> &outputs = op.outputs; - const auto &tensors = context->reader()->tensors(); - const auto &opcodes = context->reader()->opcodes(); - auto tensors_ptr = context->reader()->tensors_ptr(); - assert(tensors_ptr != nullptr); - - std::vector<CircleNode *> input_nodes; - for (const int32_t input_tensor_index : inputs) - { - input_nodes.push_back(context->nodefinder()->node(input_tensor_index)); - } - - // Create CircleTopKV2 - auto node = graph->nodes()->create<CircleTopKV2>(); - node->input(input_nodes[0]); - node->k(input_nodes[1]); - - assert(outputs.size() == 2); - { - // Let's use name of output 0 as TopKV2 name - const circle::TensorT &output_tensor = *tensors[outputs[0]]; - node->name(tensor_name(output_tensor)); - node->op_version(opcodes[op.opcode_index].get()->version); - - // NOTE We don't set quantization for TopKV2 itself but to virtual outputs - } - - // Create virtual outputs of TopKV2 - for (size_t n = 0; n < outputs.size(); ++n) - { - const circle::TensorT &output_tensor = *tensors[outputs[n]]; - - auto *nodeout = graph->nodes()->create<CircleTopKV2Out>(); - copy_tensor_attributes(output_tensor, nodeout); - // mark shape_status - if (tensors_ptr->Get(outputs[n])->shape() == nullptr) - nodeout->shape_status(ShapeStatus::NOSHAPE); - else - nodeout->shape_status(ShapeStatus::VALID); - - nodeout->input(node); - nodeout->index(n); - - context->nodefinder()->enroll(outputs[n], nodeout); - } + auto node = bna.context->graph()->nodes()->create<CircleTopKV2>(); + + node->input(bna.input_nodes[0]); + node->k(bna.input_nodes[1]); + + return node; +} + +CircleNode *CircleTopKV2GraphBuilder::build_out(const BuildOutArgs &boa) const +{ + auto *nodeout = boa.node->graph()->nodes()->create<CircleTopKV2Out>(); + + nodeout->input(boa.node); + nodeout->index(boa.index); + + return nodeout; } } // namespace luci diff --git a/compiler/luci/import/src/Nodes/CircleTranspose.cpp b/compiler/luci/import/src/Nodes/CircleTranspose.cpp index cc3153085..01095239e 100644 --- a/compiler/luci/import/src/Nodes/CircleTranspose.cpp +++ b/compiler/luci/import/src/Nodes/CircleTranspose.cpp @@ -25,13 +25,7 @@ namespace luci bool CircleTransposeGraphBuilder::validate(const ValidateArgs &args) const { - if (args.op.inputs.size() != 2) - return false; - - if (args.op.outputs.size() != 1) - return false; - - return true; + return GraphBuilder::validate(args, 2); } CircleNode *CircleTransposeGraphBuilder::build_node(const circle::OperatorT &op, diff --git a/compiler/luci/import/src/Nodes/CircleTransposeConv.cpp b/compiler/luci/import/src/Nodes/CircleTransposeConv.cpp index c280faaf5..5a60e2f54 100644 --- a/compiler/luci/import/src/Nodes/CircleTransposeConv.cpp +++ b/compiler/luci/import/src/Nodes/CircleTransposeConv.cpp @@ -61,16 +61,15 @@ CircleNode *CircleTransposeConvGraphBuilder::build_node(const circle::OperatorT node->filter(inputs.at(1)); node->outBackprop(inputs.at(2)); if (inputs.size() == 3) - node->bias(graph->nodes()->create<CircleOutputExclude>()); - else - node->bias(inputs.at(3)); - - if (auto bias = dynamic_cast<luci::CircleOutputExclude *>(node->bias())) { - // CircleOutputExclude doesn't need a type, but since all nodes must have a type, a dummy type - // is inserted. + auto *bias = graph->nodes()->create<CircleOutputExclude>(); + // CircleOutputExclude doesn't need a type, but since all nodes must have a type, + // a dummy type is inserted. bias->dtype(loco::DataType::FLOAT32); + node->bias(bias); } + else + node->bias(inputs.at(3)); const auto *options = op.builtin_options.AsTransposeConvOptions(); node->padding(luci_padding(options->padding)); diff --git a/compiler/luci/import/src/Nodes/CircleUnidirectionalSequenceLSTM.cpp b/compiler/luci/import/src/Nodes/CircleUnidirectionalSequenceLSTM.cpp index c41cf4def..d9cc3f8d0 100644 --- a/compiler/luci/import/src/Nodes/CircleUnidirectionalSequenceLSTM.cpp +++ b/compiler/luci/import/src/Nodes/CircleUnidirectionalSequenceLSTM.cpp @@ -25,14 +25,11 @@ namespace luci bool CircleUnidirectionalSequenceLSTMGraphBuilder::validate(const ValidateArgs &args) const { - if (args.op.inputs.size() != 24) - return false; - - return true; + return GraphBuilder::validate(args, 24); } CircleNode *CircleUnidirectionalSequenceLSTMGraphBuilder::build_node( - const circle::OperatorT &op, const std::vector<CircleNode *> &inputs, loco::Graph *graph) const + const circle::OperatorT &op, const std::vector<CircleNode *> &inputs, loco::Graph *graph) const { auto *node = graph->nodes()->create<CircleUnidirectionalSequenceLSTM>(); node->input(inputs.at(0)); @@ -59,16 +56,6 @@ CircleNode *CircleUnidirectionalSequenceLSTMGraphBuilder::build_node( node->forget_layer_norm_coefficients(inputs.at(21)); // Optional node->cell_layer_norm_coefficients(inputs.at(22)); // Optional node->output_layer_norm_coefficients(inputs.at(23)); // Optional - const std::vector<int32_t> optionals = {1, 5, 9, 10, 11, 12, 16, 17, 20, 21, 22, 23}; - for (auto optional : optionals) - { - if (auto inp = dynamic_cast<luci::CircleOutputExclude *>(node->arg(optional))) - { - // CircleOutputExclude doesn't need a type, but since all nodes must have a type, a dummy type - // is inserted. - inp->dtype(loco::DataType::FLOAT32); - } - } const auto *options = op.builtin_options.AsUnidirectionalSequenceLSTMOptions(); node->fusedActivationFunction(luci_actfunc(options->fused_activation_function)); diff --git a/compiler/luci/import/src/Nodes/CircleUnique.cpp b/compiler/luci/import/src/Nodes/CircleUnique.cpp index 5e79a2920..f6914c24a 100644 --- a/compiler/luci/import/src/Nodes/CircleUnique.cpp +++ b/compiler/luci/import/src/Nodes/CircleUnique.cpp @@ -35,55 +35,26 @@ bool CircleUniqueGraphBuilder::validate(const ValidateArgs &args) const return true; } -void CircleUniqueGraphBuilder::build(const circle::OperatorT &op, - GraphBuilderContext *context) const +CircleNode *CircleUniqueGraphBuilder::build_node(const BuildNodeArgs &bna) const { - assert(context != nullptr); + auto node = bna.context->graph()->nodes()->create<CircleUnique>(); - auto graph = context->graph(); + node->input(bna.input_nodes[0]); - const std::vector<int32_t> &inputs = op.inputs; - const std::vector<int32_t> &outputs = op.outputs; - const auto &tensors = context->reader()->tensors(); - auto tensors_ptr = context->reader()->tensors_ptr(); - assert(tensors_ptr != nullptr); + const auto *options = bna.op.builtin_options.AsUniqueOptions(); + node->idx_out_type(luci_datatype(options->idx_out_type)); - std::vector<CircleNode *> input_nodes; - for (const int32_t input_tensor_index : inputs) - { - input_nodes.push_back(context->nodefinder()->node(input_tensor_index)); - } - - // Create CircleUnique - auto node = graph->nodes()->create<CircleUnique>(); - node->input(input_nodes[0]); - - const auto *options = op.builtin_options.AsUniqueOptions(); - node->output_type(luci_datatype(options->idx_out_type)); - - assert(int32_t(outputs.size()) == 2); - // Let's use name of output 0 as Unique name - const circle::TensorT &output_tensor = *tensors[outputs[0]]; - node->name(tensor_name(output_tensor)); - - // Create virtual outputs of Unique - for (int32_t n = 0; n < 2; ++n) - { - const circle::TensorT &output_tensor = *tensors[outputs[n]]; + return node; +} - auto *nodeout = graph->nodes()->create<CircleUniqueOut>(); - copy_tensor_attributes(output_tensor, nodeout); - // mark shape_status - if (tensors_ptr->Get(outputs[n])->shape() == nullptr) - nodeout->shape_status(ShapeStatus::NOSHAPE); - else - nodeout->shape_status(ShapeStatus::VALID); +CircleNode *CircleUniqueGraphBuilder::build_out(const BuildOutArgs &boa) const +{ + auto *nodeout = boa.node->graph()->nodes()->create<CircleUniqueOut>(); - nodeout->input(node); - nodeout->index(n); + nodeout->input(boa.node); + nodeout->index(boa.index); - context->nodefinder()->enroll(outputs[n], nodeout); - } + return nodeout; } } // namespace luci diff --git a/compiler/luci/import/src/Nodes/CircleUnpack.cpp b/compiler/luci/import/src/Nodes/CircleUnpack.cpp index 9e7f3d3e1..9bfc76b57 100644 --- a/compiler/luci/import/src/Nodes/CircleUnpack.cpp +++ b/compiler/luci/import/src/Nodes/CircleUnpack.cpp @@ -88,64 +88,27 @@ bool CircleUnpackGraphBuilder::validate(const ValidateArgs &args) const * \- CircleUnpackOut --- FullyConnected --- */ -void CircleUnpackGraphBuilder::build(const circle::OperatorT &op, - GraphBuilderContext *context) const +CircleNode *CircleUnpackGraphBuilder::build_node(const BuildNodeArgs &bna) const { - assert(context != nullptr); + auto node = bna.context->graph()->nodes()->create<CircleUnpack>(); - auto graph = context->graph(); + node->value(bna.input_nodes[0]); - const std::vector<int32_t> &inputs = op.inputs; - const std::vector<int32_t> &outputs = op.outputs; - const auto &tensors = context->reader()->tensors(); - const auto &opcodes = context->reader()->opcodes(); - auto tensors_ptr = context->reader()->tensors_ptr(); - assert(tensors_ptr != nullptr); - - // NOTE Unpack has only one input so running a loop is not necessary - // This is provided as a reference for other Ops as a reference - std::vector<CircleNode *> input_nodes; - for (const int32_t input_tensor_index : inputs) - { - input_nodes.push_back(context->nodefinder()->node(input_tensor_index)); - } - - // Create CircleUnpack - CircleUnpack *node = graph->nodes()->create<CircleUnpack>(); - node->value(input_nodes[0]); - - const auto *options = op.builtin_options.AsUnpackOptions(); + const auto *options = bna.op.builtin_options.AsUnpackOptions(); node->num(options->num); node->axis(options->axis); - assert(outputs.size() > 0); - { - // Let's use name of output 0 as Unpack name - const circle::TensorT &output_tensor = *tensors[outputs[0]]; - node->name(tensor_name(output_tensor)); - node->op_version(opcodes[op.opcode_index].get()->version); - - // NOTE We don't set quantization for Unpack itself but to virtual outputs - } - - // Create virtual outputs of Unpack - for (int32_t n = 0; n < options->num; ++n) - { - const circle::TensorT &output_tensor = *tensors[outputs[n]]; + return node; +} - auto *nodeout = graph->nodes()->create<CircleUnpackOut>(); - copy_tensor_attributes(output_tensor, nodeout); - // mark shape_status - if (tensors_ptr->Get(outputs[n])->shape() == nullptr) - nodeout->shape_status(ShapeStatus::NOSHAPE); - else - nodeout->shape_status(ShapeStatus::VALID); +CircleNode *CircleUnpackGraphBuilder::build_out(const BuildOutArgs &boa) const +{ + auto *nodeout = boa.node->graph()->nodes()->create<CircleUnpackOut>(); - nodeout->input(node); - nodeout->index(n); + nodeout->input(boa.node); + nodeout->index(boa.index); - context->nodefinder()->enroll(outputs[n], nodeout); - } + return nodeout; } } // namespace luci diff --git a/compiler/luci/import/src/Nodes/CircleWhere.cpp b/compiler/luci/import/src/Nodes/CircleWhere.cpp index f4c5f0c66..8e4f1a0c4 100644 --- a/compiler/luci/import/src/Nodes/CircleWhere.cpp +++ b/compiler/luci/import/src/Nodes/CircleWhere.cpp @@ -25,15 +25,11 @@ namespace luci bool CircleWhereGraphBuilder::validate(const ValidateArgs &args) const { - const auto &inputs = args.op.inputs; - const auto &outputs = args.op.outputs; - - if (inputs.size() != 1) - return false; - - if (outputs.size() != 1) + if (!GraphBuilder::validate(args, 1)) return false; + const auto &inputs = args.op.inputs; + const auto &outputs = args.op.outputs; const auto &tensors = args.reader.tensors(); const auto &tensor_condition = tensors.at(inputs.at(0)); const auto &tensor_out = tensors.at(outputs[0]); diff --git a/compiler/luci/import/src/Nodes/CircleWhile.cpp b/compiler/luci/import/src/Nodes/CircleWhile.cpp index aead25071..26147562f 100644 --- a/compiler/luci/import/src/Nodes/CircleWhile.cpp +++ b/compiler/luci/import/src/Nodes/CircleWhile.cpp @@ -58,7 +58,8 @@ bool CircleWhileGraphBuilder::validate(const ValidateArgs &args) const * \- CircleWhileOut --- Node --- */ -void CircleWhileGraphBuilder::build(const circle::OperatorT &op, GraphBuilderContext *context) const +CircleNode *CircleWhileGraphBuilder::build(const circle::OperatorT &op, + GraphBuilderContext *context) const { assert(context != nullptr); @@ -118,6 +119,8 @@ void CircleWhileGraphBuilder::build(const circle::OperatorT &op, GraphBuilderCon context->nodefinder()->enroll(outputs[n], nodeout); } + + return node; } } // namespace luci diff --git a/compiler/luci/import/src/Nodes/CircleZerosLike.cpp b/compiler/luci/import/src/Nodes/CircleZerosLike.cpp index e60424def..ddb05e8a4 100644 --- a/compiler/luci/import/src/Nodes/CircleZerosLike.cpp +++ b/compiler/luci/import/src/Nodes/CircleZerosLike.cpp @@ -25,13 +25,7 @@ namespace luci bool CircleZerosLikeGraphBuilder::validate(const ValidateArgs &args) const { - if (args.op.inputs.size() != 1) - return false; - - if (args.op.outputs.size() != 1) - return false; - - return true; + return GraphBuilder::validate(args, 1); } CircleNode *CircleZerosLikeGraphBuilder::build_node(const circle::OperatorT &, diff --git a/compiler/luci/import/src/PostImport.cpp b/compiler/luci/import/src/PostImport.cpp index f436b48e8..63b16bb95 100644 --- a/compiler/luci/import/src/PostImport.cpp +++ b/compiler/luci/import/src/PostImport.cpp @@ -130,7 +130,10 @@ private: namespace { /** - * @brief ValidateNodeProp will validate inter graph connections for each Nodes + * @brief ValidateNodeProp will validate inter graph connections for each Nodes. + * @note In here, only loco::GraphInput and loco::GraphOutput are validated, + * since this class is for checking inter graph connections. + * CircleNodes such as CircleInput and CircleOutput will be validated at later steps. */ class ValidateNodeProp final : public luci::CircleNodeMutableVisitor<void> { @@ -172,9 +175,19 @@ public: auto then_graph_output = then_graph_outputs->at(then_out->index()); auto else_graph_output = else_graph_outputs->at(else_out->index()); - if (!(*then_graph_output->shape() == *else_graph_output->shape())) + if (then_graph_output->shape()->rank() != else_graph_output->shape()->rank()) { - INTERNAL_EXN_V("CircleIf THEN and ELSE Graph Output shape mismatch ", idx); + INTERNAL_EXN_V("CircleIf THEN and ELSE Graph Output rank mismatch ", idx); + } + for (uint32_t i = 0; i < then_graph_output->shape()->rank(); ++i) + { + if (then_graph_output->shape()->dim(i).known() && + else_graph_output->shape()->dim(i).known() && + then_graph_output->shape()->dim(i).value() != + else_graph_output->shape()->dim(i).value()) + { + INTERNAL_EXN_V("CircleIf THEN and ELSE Graph Output dimension mismatch ", idx); + } } if (then_graph_output->dtype() != else_graph_output->dtype()) { @@ -231,18 +244,20 @@ public: auto cond_graph_input = cond_graph_inputs->at(cond_in->index()); auto body_graph_input = body_graph_inputs->at(body_in->index()); - if ((cond_in->rank() != body_in->rank())) + if (cond_graph_input->shape()->rank() != body_graph_input->shape()->rank()) { - INTERNAL_EXN_V("CircleWhile COND input and BODY input shape mismatch ", idx); + INTERNAL_EXN_V("CircleWhile COND input and BODY input rank mismatch ", idx); } - if (cond_in->rank() > 0 && body_in->rank() > 0) + for (uint32_t i = 0; i < cond_graph_input->shape()->rank(); ++i) { - if (!(*cond_graph_input->shape() == *body_graph_input->shape())) + if (cond_graph_input->shape()->dim(i).known() && + body_graph_input->shape()->dim(i).known() && + cond_graph_input->shape()->dim(i).value() != body_graph_input->shape()->dim(i).value()) { - INTERNAL_EXN_V("CircleWhile COND input and BODY input shape mismatch ", idx); + INTERNAL_EXN_V("CircleWhile COND input and BODY input dimension mismatch ", idx); } } - if (cond_in->dtype() != body_in->dtype()) + if (cond_graph_input->dtype() != body_graph_input->dtype()) { INTERNAL_EXN_V("CircleWhile COND input and BODY input type mismatch ", idx); } @@ -257,18 +272,20 @@ public: auto cond_graph_input = cond_graph_inputs->at(cond_in->index()); auto body_graph_output = body_graph_outputs->at(body_out->index()); - if ((cond_in->rank() != body_out->rank())) + if (cond_graph_input->shape()->rank() != body_graph_output->shape()->rank()) { - INTERNAL_EXN_V("CircleWhile COND input and BODY output shape mismatch ", idx); + INTERNAL_EXN_V("CircleWhile COND input and BODY output rank mismatch ", idx); } - if (cond_in->rank() > 0 && body_out->rank() > 0) + for (uint32_t i = 0; i < cond_graph_input->shape()->rank(); ++i) { - if (!(*cond_graph_input->shape() == *body_graph_output->shape())) + if (cond_graph_input->shape()->dim(i).known() && + body_graph_output->shape()->dim(i).known() && + cond_graph_input->shape()->dim(i).value() != body_graph_output->shape()->dim(i).value()) { - INTERNAL_EXN_V("CircleWhile COND input and BODY output shape mismatch ", idx); + INTERNAL_EXN_V("CircleWhile COND input and BODY output dimension mismatch ", idx); } } - if (cond_in->dtype() != body_out->dtype()) + if (cond_graph_input->dtype() != body_graph_output->dtype()) { INTERNAL_EXN_V("CircleWhile COND input and BODY output type mismatch ", idx); } diff --git a/compiler/luci/lang/CMakeLists.txt b/compiler/luci/lang/CMakeLists.txt index 32d0a890d..c618fdd6f 100644 --- a/compiler/luci/lang/CMakeLists.txt +++ b/compiler/luci/lang/CMakeLists.txt @@ -7,6 +7,7 @@ target_include_directories(luci_lang PRIVATE src) target_include_directories(luci_lang PUBLIC include) target_link_libraries(luci_lang PUBLIC loco) target_link_libraries(luci_lang PUBLIC oops) +target_link_libraries(luci_lang PUBLIC nncc_coverage) target_link_libraries(luci_lang PRIVATE logo) target_link_libraries(luci_lang PRIVATE nncc_common) diff --git a/compiler/luci/lang/include/luci/IR/CircleNodeDecl.h b/compiler/luci/lang/include/luci/IR/CircleNodeDecl.h index e6410d154..edec9d18b 100644 --- a/compiler/luci/lang/include/luci/IR/CircleNodeDecl.h +++ b/compiler/luci/lang/include/luci/IR/CircleNodeDecl.h @@ -20,7 +20,6 @@ #include <loco/IR/Dialect.h> #include <loco/IR/Node.h> #include <loco/IR/NodeMixins.h> -#include <luci/IR/CircleShapeSignature.h> #include <luci/IR/PropertyShapeStatus.h> #include "CircleOpcode.h" @@ -62,9 +61,6 @@ struct CircleNode : public loco::Node, _sparsityparam = std::move(sparsityparam); } - const ShapeSignature &shape_signature(void) const { return _shape_signature; } - void shape_signature(const ShapeSignature &ss) { _shape_signature = ss; } - ShapeStatus shape_status(void) const { return _shape_status; } void shape_status(ShapeStatus ss) { _shape_status = ss; } @@ -75,7 +71,6 @@ private: NodeName _name; std::unique_ptr<CircleQuantParam> _quantparam; std::unique_ptr<SparsityParam> _sparsityparam; - ShapeSignature _shape_signature; ShapeStatus _shape_status{ShapeStatus::UNDEFINED}; int32_t _op_version = 1; }; diff --git a/compiler/luci/lang/include/luci/IR/CircleNodeImpl.h b/compiler/luci/lang/include/luci/IR/CircleNodeImpl.h index a6b9488db..4b3178b9b 100644 --- a/compiler/luci/lang/include/luci/IR/CircleNodeImpl.h +++ b/compiler/luci/lang/include/luci/IR/CircleNodeImpl.h @@ -34,8 +34,10 @@ template <typename T> T CircleNode::accept(CircleNodeVisitorBase<T> *v) const \ case CircleOpcode::OPCODE: \ return v->visit(dynamic_cast<const CLASS *>(this)); +#define CIRCLE_VNODE CIRCLE_NODE #include "CircleNodes.lst" +#undef CIRCLE_VNODE #undef CIRCLE_NODE default: @@ -53,8 +55,10 @@ template <typename T> T CircleNode::accept(CircleNodeMutableVisitorBase<T> *v) \ case CircleOpcode::OPCODE: \ return v->visit(dynamic_cast<CLASS *>(this)); +#define CIRCLE_VNODE CIRCLE_NODE #include "CircleNodes.lst" +#undef CIRCLE_VNODE #undef CIRCLE_NODE default: diff --git a/compiler/luci/lang/include/luci/IR/CircleNodeMixins.h b/compiler/luci/lang/include/luci/IR/CircleNodeMixins.h new file mode 100644 index 000000000..3f8ab7d61 --- /dev/null +++ b/compiler/luci/lang/include/luci/IR/CircleNodeMixins.h @@ -0,0 +1,107 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_IR_CIRCLE_NODE_MIXINS_H__ +#define __LUCI_IR_CIRCLE_NODE_MIXINS_H__ + +#include "luci/IR/AttrFusedActFunc.h" + +#include <loco/IR/Node.h> +#include <loco/IR/NodeMixins.h> + +#include <vector> + +namespace luci +{ + +/// @brief enumeration of mixin class +enum class CircleNodeTrait +{ + FusedActFunc, + Bias +}; + +template <CircleNodeTrait T> class CircleNodeMixin; + +template <> class CircleNodeMixin<CircleNodeTrait::FusedActFunc> +{ +public: + CircleNodeMixin() = default; + +public: + FusedActFunc fusedActivationFunction() const { return _fused_act_fun; } + void fusedActivationFunction(FusedActFunc fused_act_fun) { _fused_act_fun = fused_act_fun; } + +private: + FusedActFunc _fused_act_fun = FusedActFunc::UNDEFINED; +}; + +/** + * @brief Mixin class for nodes that has a bias input + */ +template <> class CircleNodeMixin<CircleNodeTrait::Bias> +{ +public: + CircleNodeMixin() = default; + +public: + virtual loco::Node *bias(void) const = 0; /// @brief get the input for bias. + virtual void bias(loco::Node *node) = 0; /// @brief set the input for bias. +}; + +/** + * @brief Nodes with the fixed number of inputs + * + * TODO Deprecated this class, and use loco::FixedArity instead + */ +template <unsigned N, typename Base> class FixedArityNode : public Base +{ +public: + FixedArityNode() + { + _args.resize(N); + for (uint32_t n = 0; n < N; ++n) + { + _args[n] = std::make_unique<loco::Use>(this); + } + } + + virtual ~FixedArityNode() = default; + +public: + unsigned arity(void) const final { return N; } + + loco::Node *arg(uint32_t n) const final { return _args.at(n)->node(); } + + void drop(void) final + { + for (uint32_t n = 0; n < N; ++n) + { + _args.at(n)->node(nullptr); + } + } + +protected: + // This API allows inherited classes to access "_args" field. + loco::Use *at(unsigned n) const { return _args.at(n).get(); } + +private: + std::vector<std::unique_ptr<loco::Use>> _args{}; +}; + +} // namespace luci + +#endif // __LUCI_IR_CIRCLE_NODE_MIXINS_H__ diff --git a/compiler/luci/lang/include/luci/IR/CircleNodeVisitor.h b/compiler/luci/lang/include/luci/IR/CircleNodeVisitor.h index 43339fe84..599e4bcd9 100644 --- a/compiler/luci/lang/include/luci/IR/CircleNodeVisitor.h +++ b/compiler/luci/lang/include/luci/IR/CircleNodeVisitor.h @@ -33,8 +33,10 @@ template <typename T> struct CircleNodeVisitorBase virtual ~CircleNodeVisitorBase() = default; #define CIRCLE_NODE(OPCODE, CIRCLE_CLASS) virtual T visit(const CIRCLE_CLASS *) = 0; +#define CIRCLE_VNODE CIRCLE_NODE #include "CircleNodes.lst" +#undef CIRCLE_VNODE #undef CIRCLE_NODE }; @@ -44,9 +46,11 @@ template <typename T> struct CircleNodeVisitor : public CircleNodeVisitorBase<T> #define CIRCLE_NODE(OPCODE, CIRCLE_CLASS) \ virtual T visit(const CIRCLE_CLASS *node) { return visit(static_cast<const CircleNode *>(node)); } +#define CIRCLE_VNODE CIRCLE_NODE #include "CircleNodes.lst" +#undef CIRCLE_VNODE #undef CIRCLE_NODE /// @brief Default fallback @@ -61,9 +65,11 @@ template <typename T> struct CircleNodeMutableVisitorBase virtual ~CircleNodeMutableVisitorBase() = default; #define CIRCLE_NODE(OPCODE, CIRCLE_CLASS) virtual T visit(CIRCLE_CLASS *) = 0; +#define CIRCLE_VNODE CIRCLE_NODE #include "CircleNodes.lst" +#undef CIRCLE_VNODE #undef CIRCLE_NODE }; @@ -73,9 +79,11 @@ template <typename T> struct CircleNodeMutableVisitor : public CircleNodeMutable #define CIRCLE_NODE(OPCODE, CIRCLE_CLASS) \ virtual T visit(CIRCLE_CLASS *node) { return visit(static_cast<CircleNode *>(node)); } +#define CIRCLE_VNODE CIRCLE_NODE #include "CircleNodes.lst" +#undef CIRCLE_VNODE #undef CIRCLE_NODE /// @brief Default fallback diff --git a/compiler/luci/lang/include/luci/IR/CircleNodes.h b/compiler/luci/lang/include/luci/IR/CircleNodes.h index fde0b612b..69a82a7b9 100644 --- a/compiler/luci/lang/include/luci/IR/CircleNodes.h +++ b/compiler/luci/lang/include/luci/IR/CircleNodes.h @@ -25,6 +25,7 @@ #include "Nodes/CircleAveragePool2D.h" #include "Nodes/CircleBatchMatMul.h" #include "Nodes/CircleBatchToSpaceND.h" +#include "Nodes/CircleBidirectionalSequenceLSTM.h" #include "Nodes/CircleCast.h" #include "Nodes/CircleCeil.h" #include "Nodes/CircleConcatenation.h" @@ -40,6 +41,7 @@ #include "Nodes/CircleEqual.h" #include "Nodes/CircleExp.h" #include "Nodes/CircleExpandDims.h" +#include "Nodes/CircleFakeQuant.h" #include "Nodes/CircleFill.h" #include "Nodes/CircleFloor.h" #include "Nodes/CircleFloorDiv.h" @@ -134,6 +136,7 @@ // Virtual nodes #include "Nodes/CircleInput.h" #include "Nodes/CircleOutput.h" +#include "Nodes/CircleBidirectionalSequenceLSTMOut.h" #include "Nodes/CircleCustomOut.h" #include "Nodes/CircleIfOut.h" #include "Nodes/CircleNonMaxSuppressionV4Out.h" @@ -150,15 +153,6 @@ namespace luci { -/** - * @brief Set both CircleReshape's 2nd input as CircleConst, and newShape attribute - * with same value - * @note Shape inference for TFLReshape forces them to be same - * - * TODO find better place for this helper - */ -void set_new_shape(CircleReshape *node, int32_t *base, uint32_t size); - /// @brief Link GraphOutput with CircleOutput node void link(loco::GraphOutput *, CircleOutput *); diff --git a/compiler/luci/lang/include/luci/IR/CircleNodes.lst b/compiler/luci/lang/include/luci/IR/CircleNodes.lst index b9d545893..b93fdc89d 100644 --- a/compiler/luci/lang/include/luci/IR/CircleNodes.lst +++ b/compiler/luci/lang/include/luci/IR/CircleNodes.lst @@ -2,6 +2,10 @@ #error "Define CIRCLE_NODE" #endif // CIRCLE_NODE +#ifndef CIRCLE_VNODE +#error "Define CIRCLE_VNODE" +#endif // CIRCLE_VNODE + // // PLEASE SORT NODE DECLS IN ALPHABETICAL ORDER // @@ -18,7 +22,8 @@ CIRCLE_NODE(ARG_MAX, luci::CircleArgMax) CIRCLE_NODE(ARG_MIN, luci::CircleArgMin) CIRCLE_NODE(AVERAGE_POOL_2D, luci::CircleAveragePool2D) CIRCLE_NODE(BATCH_TO_SPACE_ND, luci::CircleBatchToSpaceND) -CIRCLE_NODE(BATCHMATMUL, luci::CircleBatchMatMul) +CIRCLE_NODE(BATCH_MATMUL, luci::CircleBatchMatMul) +CIRCLE_NODE(BIDIRECTIONAL_SEQUENCE_LSTM, luci::CircleBidirectionalSequenceLSTM) CIRCLE_NODE(CAST, luci::CircleCast) CIRCLE_NODE(CEIL, luci::CircleCeil) CIRCLE_NODE(CONCATENATION, luci::CircleConcatenation) @@ -33,6 +38,7 @@ CIRCLE_NODE(ELU, luci::CircleElu) CIRCLE_NODE(EQUAL, luci::CircleEqual) CIRCLE_NODE(EXP, luci::CircleExp) CIRCLE_NODE(EXPAND_DIMS, luci::CircleExpandDims) +CIRCLE_NODE(FAKE_QUANT, luci::CircleFakeQuant) CIRCLE_NODE(FILL, luci::CircleFill) CIRCLE_NODE(FLOOR, luci::CircleFloor) CIRCLE_NODE(FLOOR_DIV, luci::CircleFloorDiv) @@ -125,18 +131,19 @@ CIRCLE_NODE(BCQ_FULLY_CONNECTED, luci::CircleBCQFullyConnected) CIRCLE_NODE(BCQ_GATHER, luci::CircleBCQGather) CIRCLE_NODE(INSTANCE_NORM, luci::CircleInstanceNorm) // Virtual node(s) -CIRCLE_NODE(CIRCLECONST, luci::CircleConst) -CIRCLE_NODE(CIRCLEINPUT, luci::CircleInput) -CIRCLE_NODE(CIRCLEOUTPUT, luci::CircleOutput) -CIRCLE_NODE(CIRCLEOUTPUTDUMMY, luci::CircleOutputDummy) -CIRCLE_NODE(CIRCLEOUTPUTEXCLUDE, luci::CircleOutputExclude) -CIRCLE_NODE(CIRCLECUSTOMOUT, luci::CircleCustomOut) -CIRCLE_NODE(CIRCLEIFOUT, luci::CircleIfOut) -CIRCLE_NODE(CIRCLENONMAXSUPPRESSIONV4OUT, luci::CircleNonMaxSuppressionV4Out) -CIRCLE_NODE(CIRCLENONMAXSUPPRESSIONV5OUT, luci::CircleNonMaxSuppressionV5Out) -CIRCLE_NODE(CIRCLESPLITOUT, luci::CircleSplitOut) -CIRCLE_NODE(CIRCLESPLITVOUT, luci::CircleSplitVOut) -CIRCLE_NODE(CIRCLETOPKV2OUT, luci::CircleTopKV2Out) -CIRCLE_NODE(CIRCLEUNIQUEOUT, luci::CircleUniqueOut) -CIRCLE_NODE(CIRCLEUNPACKOUT, luci::CircleUnpackOut) -CIRCLE_NODE(CIRCLEWHILEOUT, luci::CircleWhileOut) +CIRCLE_VNODE(CIRCLEBIDIRECTIONAL_SEQUENCE_LSTM_OUT, luci::CircleBidirectionalSequenceLSTMOut) +CIRCLE_VNODE(CIRCLECONST, luci::CircleConst) +CIRCLE_VNODE(CIRCLEINPUT, luci::CircleInput) +CIRCLE_VNODE(CIRCLEOUTPUT, luci::CircleOutput) +CIRCLE_VNODE(CIRCLEOUTPUTDUMMY, luci::CircleOutputDummy) +CIRCLE_VNODE(CIRCLEOUTPUTEXCLUDE, luci::CircleOutputExclude) +CIRCLE_VNODE(CIRCLECUSTOMOUT, luci::CircleCustomOut) +CIRCLE_VNODE(CIRCLEIFOUT, luci::CircleIfOut) +CIRCLE_VNODE(CIRCLENONMAXSUPPRESSIONV4OUT, luci::CircleNonMaxSuppressionV4Out) +CIRCLE_VNODE(CIRCLENONMAXSUPPRESSIONV5OUT, luci::CircleNonMaxSuppressionV5Out) +CIRCLE_VNODE(CIRCLESPLITOUT, luci::CircleSplitOut) +CIRCLE_VNODE(CIRCLESPLITVOUT, luci::CircleSplitVOut) +CIRCLE_VNODE(CIRCLETOPKV2OUT, luci::CircleTopKV2Out) +CIRCLE_VNODE(CIRCLEUNIQUEOUT, luci::CircleUniqueOut) +CIRCLE_VNODE(CIRCLEUNPACKOUT, luci::CircleUnpackOut) +CIRCLE_VNODE(CIRCLEWHILEOUT, luci::CircleWhileOut) diff --git a/compiler/luci/lang/include/luci/IR/CircleOpcode.h b/compiler/luci/lang/include/luci/IR/CircleOpcode.h index 703b70da2..be3069f94 100644 --- a/compiler/luci/lang/include/luci/IR/CircleOpcode.h +++ b/compiler/luci/lang/include/luci/IR/CircleOpcode.h @@ -23,7 +23,9 @@ namespace luci enum class CircleOpcode { #define CIRCLE_NODE(OPCODE, CLASS) OPCODE, +#define CIRCLE_VNODE CIRCLE_NODE #include "CircleNodes.lst" +#undef CIRCLE_VNODE #undef CIRCLE_NODE }; diff --git a/compiler/luci/lang/include/luci/IR/CircleShapeSignature.h b/compiler/luci/lang/include/luci/IR/CircleShapeSignature.h deleted file mode 100644 index 18a260486..000000000 --- a/compiler/luci/lang/include/luci/IR/CircleShapeSignature.h +++ /dev/null @@ -1,53 +0,0 @@ -/* - * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef __LUCI_IR_SHAPE_SIGNATURE_H__ -#define __LUCI_IR_SHAPE_SIGNATURE_H__ - -#include <stdint.h> -#include <vector> - -namespace luci -{ - -class ShapeSignature -{ -public: - ShapeSignature() = default; - - ShapeSignature(const std::vector<int32_t> &shape_signature) - { - _shape_signature = shape_signature; - } - -public: - const std::vector<int32_t> &as_vector() const { return _shape_signature; } - - int32_t dim(uint32_t d) const { return _shape_signature.at(d); } - int32_t &dim(uint32_t d) { return _shape_signature.at(d); } - - uint32_t rank(void) const { return _shape_signature.size(); } - void rank(uint32_t rank) { _shape_signature.resize(rank); } - -private: - std::vector<int32_t> _shape_signature{}; -}; - -bool operator==(const ShapeSignature &lhs, const ShapeSignature &rhs); - -} // namespace luci - -#endif // __LUCI_IR_SHAPE_SIGNATURE_H__ diff --git a/compiler/luci/lang/src/DeadNodeQueryService.h b/compiler/luci/lang/include/luci/IR/DeadNodeQueryService.h index d10696667..d10696667 100644 --- a/compiler/luci/lang/src/DeadNodeQueryService.h +++ b/compiler/luci/lang/include/luci/IR/DeadNodeQueryService.h diff --git a/compiler/luci/lang/include/luci/IR/LuciNodeMixins.h b/compiler/luci/lang/include/luci/IR/LuciNodeMixins.h index c1bb0db11..2078495c6 100644 --- a/compiler/luci/lang/include/luci/IR/LuciNodeMixins.h +++ b/compiler/luci/lang/include/luci/IR/LuciNodeMixins.h @@ -17,90 +17,16 @@ #ifndef __LUCI_IR_LUCINODEMIXINS_H__ #define __LUCI_IR_LUCINODEMIXINS_H__ -#include "luci/IR/AttrFusedActFunc.h" +// TODO remove this file after LuciNodeTrait and LuciNodeMixin are not used in backend -#include <loco/IR/Node.h> -#include <loco/IR/NodeMixins.h> - -#include <vector> +#include "luci/IR/CircleNodeMixins.h" namespace luci { -/// @brief enumeration of mixin class -enum class LuciNodeTrait -{ - FusedActFunc, - Bias -}; - -template <LuciNodeTrait T> class LuciNodeMixin; - -template <> class LuciNodeMixin<LuciNodeTrait::FusedActFunc> -{ -public: - LuciNodeMixin() = default; - -public: - FusedActFunc fusedActivationFunction() const { return _fused_act_fun; } - void fusedActivationFunction(FusedActFunc fused_act_fun) { _fused_act_fun = fused_act_fun; } - -private: - FusedActFunc _fused_act_fun = FusedActFunc::UNDEFINED; -}; - -/** - * @brief Mixin class for nodes that has a bias input - */ -template <> class LuciNodeMixin<LuciNodeTrait::Bias> -{ -public: - LuciNodeMixin() = default; - -public: - virtual loco::Node *bias(void) const = 0; /// @brief get the input for bias. - virtual void bias(loco::Node *node) = 0; /// @brief set the input for bias. -}; - -/** - * @brief Nodes with the fixed number of inputs - * - * TODO Deprecated this class, and use loco::FixedArity instead - */ -template <unsigned N, typename Base> class FixedArityNode : public Base -{ -public: - FixedArityNode() - { - _args.resize(N); - for (uint32_t n = 0; n < N; ++n) - { - _args[n] = std::make_unique<loco::Use>(this); - } - } - - virtual ~FixedArityNode() = default; - -public: - unsigned arity(void) const final { return N; } - - loco::Node *arg(uint32_t n) const final { return _args.at(n)->node(); } - - void drop(void) final - { - for (uint32_t n = 0; n < N; ++n) - { - _args.at(n)->node(nullptr); - } - } - -protected: - // This API allows inherited classes to access "_args" field. - loco::Use *at(unsigned n) const { return _args.at(n).get(); } +using LuciNodeTrait = CircleNodeTrait; -private: - std::vector<std::unique_ptr<loco::Use>> _args{}; -}; +template <LuciNodeTrait T> using LuciNodeMixin = CircleNodeMixin<T>; } // namespace luci diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleAbs.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleAbs.h index 45dba15bf..7a73f37cd 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleAbs.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleAbs.h @@ -20,7 +20,7 @@ #include "luci/IR/CircleNodeDecl.h" #include "luci/IR/CircleOpcode.h" -#include "luci/IR/LuciNodeMixins.h" +#include "luci/IR/CircleNodeMixins.h" namespace luci { diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleAdd.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleAdd.h index f26eccd1a..92563de4c 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleAdd.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleAdd.h @@ -21,7 +21,7 @@ #include "luci/IR/CircleOpcode.h" #include "luci/IR/AttrFusedActFunc.h" -#include "luci/IR/LuciNodeMixins.h" +#include "luci/IR/CircleNodeMixins.h" namespace luci { @@ -30,7 +30,7 @@ namespace luci * @brief ADD in Circle */ class CircleAdd final : public FixedArityNode<2, CircleNodeImpl<CircleOpcode::ADD>>, - public LuciNodeMixin<LuciNodeTrait::FusedActFunc> + public CircleNodeMixin<CircleNodeTrait::FusedActFunc> { public: loco::Node *x(void) const { return at(0)->node(); } diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleArgMax.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleArgMax.h index dbc4b2b3a..c1e4631e4 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleArgMax.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleArgMax.h @@ -20,7 +20,7 @@ #include "luci/IR/CircleNodeDecl.h" #include "luci/IR/CircleOpcode.h" -#include "luci/IR/LuciNodeMixins.h" +#include "luci/IR/CircleNodeMixins.h" namespace luci { diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleArgMin.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleArgMin.h index 8cb561983..b4d026201 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleArgMin.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleArgMin.h @@ -20,7 +20,7 @@ #include "luci/IR/CircleNodeDecl.h" #include "luci/IR/CircleOpcode.h" -#include "luci/IR/LuciNodeMixins.h" +#include "luci/IR/CircleNodeMixins.h" namespace luci { diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleAveragePool2D.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleAveragePool2D.h index 0b43b40c8..4aa45c2d8 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleAveragePool2D.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleAveragePool2D.h @@ -24,7 +24,7 @@ #include "luci/IR/AttrPadding.h" #include "luci/IR/AttrStride.h" #include "luci/IR/AttrFusedActFunc.h" -#include "luci/IR/LuciNodeMixins.h" +#include "luci/IR/CircleNodeMixins.h" namespace luci { @@ -33,16 +33,14 @@ namespace luci * @brief AVERAGE_POOL_2D in Circle */ class CircleAveragePool2D final - : public FixedArityNode<1, CircleNodeImpl<CircleOpcode::AVERAGE_POOL_2D>>, - public LuciNodeMixin<LuciNodeTrait::FusedActFunc> + : public FixedArityNode<1, CircleNodeImpl<CircleOpcode::AVERAGE_POOL_2D>>, + public CircleNodeMixin<CircleNodeTrait::FusedActFunc> { public: - CircleAveragePool2D() : _padding(Padding::UNDEFINED) { /* empty */} - -public: loco::Node *value(void) const { return at(0)->node(); } void value(loco::Node *node) { at(0)->node(node); } +public: Padding padding() const { return _padding; } void padding(Padding padding) { _padding = padding; } @@ -53,7 +51,7 @@ public: Stride *stride(void) { return &_stride; } private: - Padding _padding; + Padding _padding{Padding::UNDEFINED}; Stride _stride; Filter _filter; }; diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleBCQFullyConnected.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleBCQFullyConnected.h index 7d12d593a..4c164ebca 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleBCQFullyConnected.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleBCQFullyConnected.h @@ -21,7 +21,7 @@ #include "luci/IR/CircleOpcode.h" #include "luci/IR/AttrFusedActFunc.h" -#include "luci/IR/LuciNodeMixins.h" +#include "luci/IR/CircleNodeMixins.h" namespace luci { @@ -30,9 +30,9 @@ namespace luci * @brief BCQ_FULLY_CONNECTED in Circle */ class CircleBCQFullyConnected final - : public FixedArityNode<5, CircleNodeImpl<CircleOpcode::BCQ_FULLY_CONNECTED>>, - public LuciNodeMixin<LuciNodeTrait::FusedActFunc>, - public LuciNodeMixin<LuciNodeTrait::Bias> + : public FixedArityNode<5, CircleNodeImpl<CircleOpcode::BCQ_FULLY_CONNECTED>>, + public CircleNodeMixin<CircleNodeTrait::FusedActFunc>, + public CircleNodeMixin<CircleNodeTrait::Bias> { public: loco::Node *input(void) const { return at(0)->node(); } @@ -58,7 +58,7 @@ public: } private: - int32_t _weights_hidden_size = 0; + int32_t _weights_hidden_size{0}; }; } // namespace luci diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleBCQGather.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleBCQGather.h index f7638261d..1a0bf4f19 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleBCQGather.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleBCQGather.h @@ -20,7 +20,7 @@ #include "luci/IR/CircleNodeDecl.h" #include "luci/IR/CircleOpcode.h" -#include "luci/IR/LuciNodeMixins.h" +#include "luci/IR/CircleNodeMixins.h" namespace luci { @@ -51,8 +51,8 @@ public: void input_hidden_size(int32_t input_hidden_size) { _input_hidden_size = input_hidden_size; } private: - int32_t _axis = 0; - int32_t _input_hidden_size = 0; + int32_t _axis{0}; + int32_t _input_hidden_size{0}; }; } // namespace luci diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleBatchMatMul.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleBatchMatMul.h index 19999924e..864b033ed 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleBatchMatMul.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleBatchMatMul.h @@ -20,15 +20,15 @@ #include "luci/IR/CircleNodeDecl.h" #include "luci/IR/CircleOpcode.h" -#include "luci/IR/LuciNodeMixins.h" +#include "luci/IR/CircleNodeMixins.h" namespace luci { /** - * @brief BATCHMATMUL in Circle + * @brief BATCH_MATMUL in Circle */ -class CircleBatchMatMul final : public FixedArityNode<2, CircleNodeImpl<CircleOpcode::BATCHMATMUL>> +class CircleBatchMatMul final : public FixedArityNode<2, CircleNodeImpl<CircleOpcode::BATCH_MATMUL>> { public: loco::Node *x(void) const { return at(0)->node(); } @@ -45,8 +45,8 @@ public: void adj_y(bool arg) { _adj_y = arg; } private: - bool _adj_x = false; - bool _adj_y = false; + bool _adj_x{false}; + bool _adj_y{false}; }; } // namespace luci diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleBatchToSpaceND.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleBatchToSpaceND.h index 67c0a2102..80fa53b8e 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleBatchToSpaceND.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleBatchToSpaceND.h @@ -20,7 +20,7 @@ #include "luci/IR/CircleNodeDecl.h" #include "luci/IR/CircleOpcode.h" -#include "luci/IR/LuciNodeMixins.h" +#include "luci/IR/CircleNodeMixins.h" namespace luci { @@ -29,7 +29,7 @@ namespace luci * @brief BATCH_TO_SPACE_ND in Circle */ class CircleBatchToSpaceND final - : public FixedArityNode<3, CircleNodeImpl<CircleOpcode::BATCH_TO_SPACE_ND>> + : public FixedArityNode<3, CircleNodeImpl<CircleOpcode::BATCH_TO_SPACE_ND>> { public: loco::Node *input(void) const { return at(0)->node(); } diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleBidirectionalSequenceLSTM.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleBidirectionalSequenceLSTM.h new file mode 100644 index 000000000..d16281b69 --- /dev/null +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleBidirectionalSequenceLSTM.h @@ -0,0 +1,172 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_IR_CIRCLEBIDIRECTIONALSEQUENCE_LSTM_H__ +#define __LUCI_IR_CIRCLEBIDIRECTIONALSEQUENCE_LSTM_H__ + +#include "luci/IR/CircleNodeDecl.h" +#include "luci/IR/CircleOpcode.h" + +#include "luci/IR/AttrFusedActFunc.h" +#include "luci/IR/CircleNodeMixins.h" + +namespace luci +{ + +/** + * @brief BIDIRECTIONAL_SEQUENCE_LSTM in Circle + */ +class CircleBidirectionalSequenceLSTM final + : public FixedArityNode<48, CircleNodeImpl<CircleOpcode::BIDIRECTIONAL_SEQUENCE_LSTM>>, + public CircleNodeMixin<CircleNodeTrait::FusedActFunc> +{ +public: + loco::Node *input(void) const { return at(0)->node(); } + void input(loco::Node *node) { at(0)->node(node); } + + loco::Node *fw_input_to_input_weights(void) const { return at(1)->node(); } + void fw_input_to_input_weights(loco::Node *node) { at(1)->node(node); } + loco::Node *fw_input_to_forget_weights(void) const { return at(2)->node(); } + void fw_input_to_forget_weights(loco::Node *node) { at(2)->node(node); } + loco::Node *fw_input_to_cell_weights(void) const { return at(3)->node(); } + void fw_input_to_cell_weights(loco::Node *node) { at(3)->node(node); } + loco::Node *fw_input_to_output_weights(void) const { return at(4)->node(); } + void fw_input_to_output_weights(loco::Node *node) { at(4)->node(node); } + + loco::Node *fw_recurrent_to_input_weights(void) const { return at(5)->node(); } + void fw_recurrent_to_input_weights(loco::Node *node) { at(5)->node(node); } + loco::Node *fw_recurrent_to_forget_weights(void) const { return at(6)->node(); } + void fw_recurrent_to_forget_weights(loco::Node *node) { at(6)->node(node); } + loco::Node *fw_recurrent_to_cell_weights(void) const { return at(7)->node(); } + void fw_recurrent_to_cell_weights(loco::Node *node) { at(7)->node(node); } + loco::Node *fw_recurrent_to_output_weights(void) const { return at(8)->node(); } + void fw_recurrent_to_output_weights(loco::Node *node) { at(8)->node(node); } + + loco::Node *fw_cell_to_input_weights(void) const { return at(9)->node(); } + void fw_cell_to_input_weights(loco::Node *node) { at(9)->node(node); } + loco::Node *fw_cell_to_forget_weights(void) const { return at(10)->node(); } + void fw_cell_to_forget_weights(loco::Node *node) { at(10)->node(node); } + loco::Node *fw_cell_to_output_weights(void) const { return at(11)->node(); } + void fw_cell_to_output_weights(loco::Node *node) { at(11)->node(node); } + + loco::Node *fw_input_gate_bias(void) const { return at(12)->node(); } + void fw_input_gate_bias(loco::Node *node) { at(12)->node(node); } + loco::Node *fw_forget_gate_bias(void) const { return at(13)->node(); } + void fw_forget_gate_bias(loco::Node *node) { at(13)->node(node); } + loco::Node *fw_cell_gate_bias(void) const { return at(14)->node(); } + void fw_cell_gate_bias(loco::Node *node) { at(14)->node(node); } + loco::Node *fw_output_gate_bias(void) const { return at(15)->node(); } + void fw_output_gate_bias(loco::Node *node) { at(15)->node(node); } + + loco::Node *fw_projection_weights(void) const { return at(16)->node(); } + void fw_projection_weights(loco::Node *node) { at(16)->node(node); } + loco::Node *fw_projection_bias(void) const { return at(17)->node(); } + void fw_projection_bias(loco::Node *node) { at(17)->node(node); } + + loco::Node *bw_input_to_input_weights(void) const { return at(18)->node(); } + void bw_input_to_input_weights(loco::Node *node) { at(18)->node(node); } + loco::Node *bw_input_to_forget_weights(void) const { return at(19)->node(); } + void bw_input_to_forget_weights(loco::Node *node) { at(19)->node(node); } + loco::Node *bw_input_to_cell_weights(void) const { return at(20)->node(); } + void bw_input_to_cell_weights(loco::Node *node) { at(20)->node(node); } + loco::Node *bw_input_to_output_weights(void) const { return at(21)->node(); } + void bw_input_to_output_weights(loco::Node *node) { at(21)->node(node); } + + loco::Node *bw_recurrent_to_input_weights(void) const { return at(22)->node(); } + void bw_recurrent_to_input_weights(loco::Node *node) { at(22)->node(node); } + loco::Node *bw_recurrent_to_forget_weights(void) const { return at(23)->node(); } + void bw_recurrent_to_forget_weights(loco::Node *node) { at(23)->node(node); } + loco::Node *bw_recurrent_to_cell_weights(void) const { return at(24)->node(); } + void bw_recurrent_to_cell_weights(loco::Node *node) { at(24)->node(node); } + loco::Node *bw_recurrent_to_output_weights(void) const { return at(25)->node(); } + void bw_recurrent_to_output_weights(loco::Node *node) { at(25)->node(node); } + + loco::Node *bw_cell_to_input_weights(void) const { return at(26)->node(); } + void bw_cell_to_input_weights(loco::Node *node) { at(26)->node(node); } + loco::Node *bw_cell_to_forget_weights(void) const { return at(27)->node(); } + void bw_cell_to_forget_weights(loco::Node *node) { at(27)->node(node); } + loco::Node *bw_cell_to_output_weights(void) const { return at(28)->node(); } + void bw_cell_to_output_weights(loco::Node *node) { at(28)->node(node); } + + loco::Node *bw_input_gate_bias(void) const { return at(29)->node(); } + void bw_input_gate_bias(loco::Node *node) { at(29)->node(node); } + loco::Node *bw_forget_gate_bias(void) const { return at(30)->node(); } + void bw_forget_gate_bias(loco::Node *node) { at(30)->node(node); } + loco::Node *bw_cell_gate_bias(void) const { return at(31)->node(); } + void bw_cell_gate_bias(loco::Node *node) { at(31)->node(node); } + loco::Node *bw_output_gate_bias(void) const { return at(32)->node(); } + void bw_output_gate_bias(loco::Node *node) { at(32)->node(node); } + + loco::Node *bw_projection_weights(void) const { return at(33)->node(); } + void bw_projection_weights(loco::Node *node) { at(33)->node(node); } + loco::Node *bw_projection_bias(void) const { return at(34)->node(); } + void bw_projection_bias(loco::Node *node) { at(34)->node(node); } + + loco::Node *fw_activation_state(void) const { return at(35)->node(); } + void fw_activation_state(loco::Node *node) { at(35)->node(node); } + loco::Node *fw_cell_state(void) const { return at(36)->node(); } + void fw_cell_state(loco::Node *node) { at(36)->node(node); } + + loco::Node *bw_activation_state(void) const { return at(37)->node(); } + void bw_activation_state(loco::Node *node) { at(37)->node(node); } + loco::Node *bw_cell_state(void) const { return at(38)->node(); } + void bw_cell_state(loco::Node *node) { at(38)->node(node); } + + loco::Node *auxillary_input(void) const { return at(39)->node(); } + void auxillary_input(loco::Node *node) { at(39)->node(node); } + loco::Node *fw_auxillary_input_to_input_weights(void) const { return at(40)->node(); } + void fw_auxillary_input_to_input_weights(loco::Node *node) { at(40)->node(node); } + loco::Node *fw_auxillary_input_to_forget_weights(void) const { return at(41)->node(); } + void fw_auxillary_input_to_forget_weights(loco::Node *node) { at(41)->node(node); } + loco::Node *fw_auxillary_input_to_cell_weights(void) const { return at(42)->node(); } + void fw_auxillary_input_to_cell_weights(loco::Node *node) { at(42)->node(node); } + loco::Node *fw_auxillary_input_to_output_weights(void) const { return at(43)->node(); } + void fw_auxillary_input_to_output_weights(loco::Node *node) { at(43)->node(node); } + loco::Node *bw_auxillary_input_to_input_weights(void) const { return at(44)->node(); } + void bw_auxillary_input_to_input_weights(loco::Node *node) { at(44)->node(node); } + loco::Node *bw_auxillary_input_to_forget_weights(void) const { return at(45)->node(); } + void bw_auxillary_input_to_forget_weights(loco::Node *node) { at(45)->node(node); } + loco::Node *bw_auxillary_input_to_cell_weights(void) const { return at(46)->node(); } + void bw_auxillary_input_to_cell_weights(loco::Node *node) { at(46)->node(node); } + loco::Node *bw_auxillary_input_to_output_weights(void) const { return at(47)->node(); } + void bw_auxillary_input_to_output_weights(loco::Node *node) { at(47)->node(node); } + +public: + float cell_clip(void) const { return _cell_clip; } + void cell_clip(float cell_clip) { _cell_clip = cell_clip; } + float proj_clip(void) const { return _proj_clip; } + void proj_clip(float proj_clip) { _proj_clip = proj_clip; } + bool merge_outputs(void) const { return _merge_outputs; } + void merge_outputs(bool merge_outputs) { _merge_outputs = merge_outputs; } + bool time_major(void) const { return _time_major; } + void time_major(bool time_major) { _time_major = time_major; } + bool asymmetric_quantize_inputs(void) const { return _asymmetric_quantize_inputs; } + void asymmetric_quantize_inputs(bool asymmetric_quantize_inputs) + { + _asymmetric_quantize_inputs = asymmetric_quantize_inputs; + } + +private: + float _cell_clip{0.0f}; + float _proj_clip{0.0f}; + bool _merge_outputs{false}; + bool _time_major{false}; + bool _asymmetric_quantize_inputs{false}; +}; + +} // namespace luci + +#endif // __LUCI_IR_CIRCLEBIDIRECTIONALSEQUENCE_LSTM_H__ diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleBidirectionalSequenceLSTMOut.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleBidirectionalSequenceLSTMOut.h new file mode 100644 index 000000000..fb2eb0831 --- /dev/null +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleBidirectionalSequenceLSTMOut.h @@ -0,0 +1,48 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_IR_CIRCLE_BIDIRECTIONAL_SEQUENCE_LSTM_OUT_H__ +#define __LUCI_IR_CIRCLE_BIDIRECTIONAL_SEQUENCE_LSTM_OUT_H__ + +#include "luci/IR/CircleNodeDecl.h" +#include "luci/IR/CircleOpcode.h" + +#include "luci/IR/CircleNodeMixins.h" + +namespace luci +{ + +/** + * @brief Virtual CIRCLEBIDIRECTIONAL_SEQUENCE_LSTM_OUT in Circle + */ +class CircleBidirectionalSequenceLSTMOut final + : public FixedArityNode<1, CircleNodeImpl<CircleOpcode::CIRCLEBIDIRECTIONAL_SEQUENCE_LSTM_OUT>> +{ +public: + loco::Node *input(void) const { return at(0)->node(); } + void input(loco::Node *node) { at(0)->node(node); } + +public: + int32_t index(void) const { return _index; } + void index(int32_t index) { _index = index; } + +private: + int32_t _index{-1}; +}; + +} // namespace luci + +#endif // __LUCI_IR_CIRCLE_BIDIRECTIONAL_SEQUENCE_LSTM_OUT_H__ diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleCast.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleCast.h index 9a89d0b2b..0b793607f 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleCast.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleCast.h @@ -20,7 +20,7 @@ #include "luci/IR/CircleNodeDecl.h" #include "luci/IR/CircleOpcode.h" -#include "luci/IR/LuciNodeMixins.h" +#include "luci/IR/CircleNodeMixins.h" namespace luci { diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleCeil.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleCeil.h index 8a8715dcf..3d7a7ebc7 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleCeil.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleCeil.h @@ -20,7 +20,7 @@ #include "luci/IR/CircleNodeDecl.h" #include "luci/IR/CircleOpcode.h" -#include "luci/IR/LuciNodeMixins.h" +#include "luci/IR/CircleNodeMixins.h" namespace luci { diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleConcatenation.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleConcatenation.h index dea1a4613..2746a0a2e 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleConcatenation.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleConcatenation.h @@ -21,7 +21,7 @@ #include "luci/IR/CircleOpcode.h" #include "luci/IR/AttrFusedActFunc.h" -#include "luci/IR/LuciNodeMixins.h" +#include "luci/IR/CircleNodeMixins.h" #include "luci/IR/VariadicArityNode.h" #include <cassert> @@ -33,12 +33,12 @@ namespace luci * @brief CONCATENATION in Circle */ class CircleConcatenation final - : public VariadicArityNode<CircleNodeImpl<CircleOpcode::CONCATENATION>>, - public LuciNodeMixin<LuciNodeTrait::FusedActFunc> + : public VariadicArityNode<CircleNodeImpl<CircleOpcode::CONCATENATION>>, + public CircleNodeMixin<CircleNodeTrait::FusedActFunc> { public: CircleConcatenation(uint32_t arity) - : VariadicArityNode<CircleNodeImpl<CircleOpcode::CONCATENATION>>(arity) + : VariadicArityNode<CircleNodeImpl<CircleOpcode::CONCATENATION>>(arity) { // TODO Support when arity is 0 assert(arity >= 1); diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleConst.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleConst.h index 250282049..e44363d14 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleConst.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleConst.h @@ -20,7 +20,7 @@ #include "luci/IR/CircleNodeDecl.h" #include "luci/IR/CircleOpcode.h" -#include "luci/IR/LuciNodeMixins.h" +#include "luci/IR/CircleNodeMixins.h" #include <loco/IR/DataTypeTraits.h> @@ -34,9 +34,6 @@ namespace luci class CircleConst final : public FixedArityNode<0, CircleNodeImpl<CircleOpcode::CIRCLECONST>> { public: - CircleConst() = default; - -public: template <loco::DataType DT> uint32_t size(void) const; template <loco::DataType DT> void size(uint32_t size); template <loco::DataType DT> const typename loco::DataTypeImpl<DT>::Type &at(uint32_t n) const; diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleConv2D.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleConv2D.h index 13657cee4..7c390940e 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleConv2D.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleConv2D.h @@ -24,7 +24,7 @@ #include "luci/IR/AttrStride.h" #include "luci/IR/AttrDilation.h" #include "luci/IR/AttrFusedActFunc.h" -#include "luci/IR/LuciNodeMixins.h" +#include "luci/IR/CircleNodeMixins.h" namespace luci { @@ -33,8 +33,8 @@ namespace luci * @brief CONV_2D in Circle */ class CircleConv2D final : public FixedArityNode<3, CircleNodeImpl<CircleOpcode::CONV_2D>>, - public LuciNodeMixin<LuciNodeTrait::FusedActFunc>, - public LuciNodeMixin<LuciNodeTrait::Bias> + public CircleNodeMixin<CircleNodeTrait::FusedActFunc>, + public CircleNodeMixin<CircleNodeTrait::Bias> { public: loco::Node *input(void) const { return at(0)->node(); } @@ -57,7 +57,7 @@ public: Dilation *dilation(void) { return &_dilation; } private: - Padding _padding = Padding::UNDEFINED; + Padding _padding{Padding::UNDEFINED}; Stride _stride; Dilation _dilation; }; diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleCos.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleCos.h index 07ced620a..cff04906d 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleCos.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleCos.h @@ -20,7 +20,7 @@ #include "luci/IR/CircleNodeDecl.h" #include "luci/IR/CircleOpcode.h" -#include "luci/IR/LuciNodeMixins.h" +#include "luci/IR/CircleNodeMixins.h" namespace luci { diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleCustom.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleCustom.h index 6c722b766..b21cc679f 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleCustom.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleCustom.h @@ -29,19 +29,23 @@ namespace luci class CircleCustom final : public VariadicArityNode<CircleNodeImpl<CircleOpcode::CUSTOM>> { public: - CircleCustom(uint32_t arity) : VariadicArityNode<CircleNodeImpl<CircleOpcode::CUSTOM>>(arity) + CircleCustom(uint32_t arity, uint32_t out) + : VariadicArityNode<CircleNodeImpl<CircleOpcode::CUSTOM>>(arity), _output_count(out) { // TODO Support when arity is 0 assert(arity >= 1); + assert(out > 0); } public: uint32_t numInputs(void) const { return arity(); } + uint32_t numOutputs(void) const { return _output_count; } public: Node *inputs(uint32_t index) const { return at(index)->node(); } void inputs(uint32_t index, Node *node) { at(index)->node(node); } +public: const std::vector<uint8_t> &custom_options(void) const { return _custom_options; } void custom_options(const std::vector<uint8_t> &custom_options) { @@ -54,6 +58,7 @@ public: private: std::vector<uint8_t> _custom_options; std::string _custom_code; + uint32_t _output_count{0}; }; } // namespace luci diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleCustomOut.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleCustomOut.h index 36b8e4aed..91a89c151 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleCustomOut.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleCustomOut.h @@ -20,7 +20,7 @@ #include "luci/IR/CircleNodeDecl.h" #include "luci/IR/CircleOpcode.h" -#include "luci/IR/LuciNodeMixins.h" +#include "luci/IR/CircleNodeMixins.h" namespace luci { @@ -29,12 +29,9 @@ namespace luci * @brief Virtual CIRCLECUSTOMOUT in Circle */ class CircleCustomOut final - : public FixedArityNode<1, CircleNodeImpl<CircleOpcode::CIRCLECUSTOMOUT>> + : public FixedArityNode<1, CircleNodeImpl<CircleOpcode::CIRCLECUSTOMOUT>> { public: - CircleCustomOut() = default; - -public: loco::Node *input(void) const { return at(0)->node(); } void input(loco::Node *node) { at(0)->node(node); } diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleDepthToSpace.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleDepthToSpace.h index e19282b97..85b567fb7 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleDepthToSpace.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleDepthToSpace.h @@ -20,7 +20,7 @@ #include "luci/IR/CircleNodeDecl.h" #include "luci/IR/CircleOpcode.h" -#include "luci/IR/LuciNodeMixins.h" +#include "luci/IR/CircleNodeMixins.h" namespace luci { @@ -29,18 +29,18 @@ namespace luci * @brief DEPTH_TO_SPACE in Circle */ class CircleDepthToSpace final - : public FixedArityNode<1, CircleNodeImpl<CircleOpcode::DEPTH_TO_SPACE>> + : public FixedArityNode<1, CircleNodeImpl<CircleOpcode::DEPTH_TO_SPACE>> { public: loco::Node *input(void) const { return at(0)->node(); } void input(loco::Node *node) { at(0)->node(node); } public: - int block_size(void) const { return _block_size; } - void block_size(int block_size) { _block_size = block_size; } + int32_t block_size(void) const { return _block_size; } + void block_size(int32_t block_size) { _block_size = block_size; } private: - int _block_size{0}; + int32_t _block_size{0}; }; } // namespace luci diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleDepthwiseConv2D.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleDepthwiseConv2D.h index eb058cec1..046aa5908 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleDepthwiseConv2D.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleDepthwiseConv2D.h @@ -25,7 +25,7 @@ #include "luci/IR/AttrPadding.h" #include "luci/IR/AttrStride.h" #include "luci/IR/AttrFusedActFunc.h" -#include "luci/IR/LuciNodeMixins.h" +#include "luci/IR/CircleNodeMixins.h" namespace luci { @@ -34,9 +34,9 @@ namespace luci * @brief DEPTHWISE_CONV_2D in Circle */ class CircleDepthwiseConv2D final - : public FixedArityNode<3, CircleNodeImpl<CircleOpcode::DEPTHWISE_CONV_2D>>, - public LuciNodeMixin<LuciNodeTrait::FusedActFunc>, - public LuciNodeMixin<LuciNodeTrait::Bias> + : public FixedArityNode<3, CircleNodeImpl<CircleOpcode::DEPTHWISE_CONV_2D>>, + public CircleNodeMixin<CircleNodeTrait::FusedActFunc>, + public CircleNodeMixin<CircleNodeTrait::Bias> { public: loco::Node *input(void) const { return at(0)->node(); } @@ -62,9 +62,9 @@ public: Dilation *dilation(void) { return &_dilation; } private: - Padding _padding = Padding::UNDEFINED; + Padding _padding{Padding::UNDEFINED}; Stride _stride; - int32_t _depth_multiplier = 0; + int32_t _depth_multiplier{0}; Dilation _dilation; }; diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleDequantize.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleDequantize.h index 847c5dfc5..c3ee44253 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleDequantize.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleDequantize.h @@ -20,7 +20,7 @@ #include "luci/IR/CircleNodeDecl.h" #include "luci/IR/CircleOpcode.h" -#include "luci/IR/LuciNodeMixins.h" +#include "luci/IR/CircleNodeMixins.h" namespace luci { diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleDiv.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleDiv.h index 1d4d3a239..fcc3f427c 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleDiv.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleDiv.h @@ -24,7 +24,7 @@ #include "luci/IR/AttrPadding.h" #include "luci/IR/AttrStride.h" #include "luci/IR/AttrFusedActFunc.h" -#include "luci/IR/LuciNodeMixins.h" +#include "luci/IR/CircleNodeMixins.h" namespace luci { @@ -33,12 +33,9 @@ namespace luci * @brief DIV in Circle */ class CircleDiv final : public FixedArityNode<2, CircleNodeImpl<CircleOpcode::DIV>>, - public LuciNodeMixin<LuciNodeTrait::FusedActFunc> + public CircleNodeMixin<CircleNodeTrait::FusedActFunc> { public: - CircleDiv() = default; - -public: loco::Node *x(void) const { return at(0)->node(); } void x(loco::Node *node) { at(0)->node(node); } diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleElu.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleElu.h index fbb2f3533..721edd9ae 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleElu.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleElu.h @@ -20,7 +20,7 @@ #include "luci/IR/CircleNodeDecl.h" #include "luci/IR/CircleOpcode.h" -#include "luci/IR/LuciNodeMixins.h" +#include "luci/IR/CircleNodeMixins.h" namespace luci { @@ -31,9 +31,6 @@ namespace luci class CircleElu final : public FixedArityNode<1, CircleNodeImpl<CircleOpcode::ELU>> { public: - CircleElu() = default; - -public: loco::Node *features(void) const { return at(0)->node(); } void features(loco::Node *node) { at(0)->node(node); } }; diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleEqual.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleEqual.h index 2087d097a..69697ac7e 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleEqual.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleEqual.h @@ -20,7 +20,7 @@ #include "luci/IR/CircleNodeDecl.h" #include "luci/IR/CircleOpcode.h" -#include "luci/IR/LuciNodeMixins.h" +#include "luci/IR/CircleNodeMixins.h" namespace luci { diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleExp.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleExp.h index 97aecb30a..b8a5d4561 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleExp.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleExp.h @@ -20,7 +20,7 @@ #include "luci/IR/CircleNodeDecl.h" #include "luci/IR/CircleOpcode.h" -#include "luci/IR/LuciNodeMixins.h" +#include "luci/IR/CircleNodeMixins.h" namespace luci { diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleExpandDims.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleExpandDims.h index f70219614..15bfe6a29 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleExpandDims.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleExpandDims.h @@ -20,7 +20,7 @@ #include "luci/IR/CircleNodeDecl.h" #include "luci/IR/CircleOpcode.h" -#include "luci/IR/LuciNodeMixins.h" +#include "luci/IR/CircleNodeMixins.h" namespace luci { @@ -31,9 +31,6 @@ namespace luci class CircleExpandDims final : public FixedArityNode<2, CircleNodeImpl<CircleOpcode::EXPAND_DIMS>> { public: - CircleExpandDims() = default; - -public: loco::Node *input(void) const { return at(0)->node(); } void input(loco::Node *node) { at(0)->node(node); } diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleFakeQuant.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleFakeQuant.h new file mode 100644 index 000000000..9e3159685 --- /dev/null +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleFakeQuant.h @@ -0,0 +1,60 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_IR_CIRCLE_FAKE_QUANT_H__ +#define __LUCI_IR_CIRCLE_FAKE_QUANT_H__ + +#include "luci/IR/CircleNodeDecl.h" +#include "luci/IR/CircleOpcode.h" + +#include "luci/IR/CircleNodeMixins.h" + +namespace luci +{ + +/** + * @brief FAKE_QUANT in Circle + * @note 'inputs' came from TF.quantize.fake_quant_from_min_max_vars + */ +class CircleFakeQuant final : public FixedArityNode<1, CircleNodeImpl<CircleOpcode::FAKE_QUANT>> +{ +public: + loco::Node *inputs(void) const { return at(0)->node(); } + void inputs(loco::Node *node) { at(0)->node(node); } + +public: + float min(void) const { return _min; } + void min(float min) { _min = min; } + + float max(void) const { return _max; } + void max(float max) { _max = max; } + + int32_t num_bits(void) const { return _num_bits; } + void num_bits(int32_t num_bits) { _num_bits = num_bits; } + + bool narrow_range(void) const { return _narrow_range; } + void narrow_range(bool narrow_range) { _narrow_range = narrow_range; } + +private: + float _min{0.0f}; + float _max{0.0f}; + int32_t _num_bits{0}; + bool _narrow_range{false}; +}; + +} // namespace luci + +#endif // __LUCI_IR_CIRCLEGATHER_H__ diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleFill.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleFill.h index bfc65274a..183794d41 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleFill.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleFill.h @@ -20,7 +20,7 @@ #include "luci/IR/CircleNodeDecl.h" #include "luci/IR/CircleOpcode.h" -#include "luci/IR/LuciNodeMixins.h" +#include "luci/IR/CircleNodeMixins.h" namespace luci { diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleFloor.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleFloor.h index 7e10547b6..ce6807e98 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleFloor.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleFloor.h @@ -20,7 +20,7 @@ #include "luci/IR/CircleNodeDecl.h" #include "luci/IR/CircleOpcode.h" -#include "luci/IR/LuciNodeMixins.h" +#include "luci/IR/CircleNodeMixins.h" namespace luci { diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleFloorDiv.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleFloorDiv.h index ba9db010c..bf76e37b6 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleFloorDiv.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleFloorDiv.h @@ -20,7 +20,7 @@ #include "luci/IR/CircleNodeDecl.h" #include "luci/IR/CircleOpcode.h" -#include "luci/IR/LuciNodeMixins.h" +#include "luci/IR/CircleNodeMixins.h" namespace luci { diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleFloorMod.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleFloorMod.h index 4d13717a0..1af0af758 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleFloorMod.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleFloorMod.h @@ -20,7 +20,7 @@ #include "luci/IR/CircleNodeDecl.h" #include "luci/IR/CircleOpcode.h" -#include "luci/IR/LuciNodeMixins.h" +#include "luci/IR/CircleNodeMixins.h" namespace luci { diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleFullyConnected.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleFullyConnected.h index 952befc87..2862cadb2 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleFullyConnected.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleFullyConnected.h @@ -21,7 +21,7 @@ #include "luci/IR/CircleOpcode.h" #include "luci/IR/AttrFusedActFunc.h" -#include "luci/IR/LuciNodeMixins.h" +#include "luci/IR/CircleNodeMixins.h" namespace luci { @@ -30,9 +30,9 @@ namespace luci * @brief FULLY_CONNECTED in Circle */ class CircleFullyConnected final - : public FixedArityNode<3, CircleNodeImpl<CircleOpcode::FULLY_CONNECTED>>, - public LuciNodeMixin<LuciNodeTrait::FusedActFunc>, - public LuciNodeMixin<LuciNodeTrait::Bias> + : public FixedArityNode<3, CircleNodeImpl<CircleOpcode::FULLY_CONNECTED>>, + public CircleNodeMixin<CircleNodeTrait::FusedActFunc>, + public CircleNodeMixin<CircleNodeTrait::Bias> { public: enum class WeightsFormat diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleGather.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleGather.h index 1e8c4982a..78fa2fc28 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleGather.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleGather.h @@ -20,7 +20,7 @@ #include "luci/IR/CircleNodeDecl.h" #include "luci/IR/CircleOpcode.h" -#include "luci/IR/LuciNodeMixins.h" +#include "luci/IR/CircleNodeMixins.h" namespace luci { @@ -42,7 +42,7 @@ public: void axis(int32_t axis) { _axis = axis; } private: - int32_t _axis = 0; + int32_t _axis{0}; }; } // namespace luci diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleGatherNd.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleGatherNd.h index 3423a8216..d6f34f1ea 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleGatherNd.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleGatherNd.h @@ -20,7 +20,7 @@ #include "luci/IR/CircleNodeDecl.h" #include "luci/IR/CircleOpcode.h" -#include "luci/IR/LuciNodeMixins.h" +#include "luci/IR/CircleNodeMixins.h" namespace luci { diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleGreater.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleGreater.h index 040a4e338..a03b6c749 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleGreater.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleGreater.h @@ -20,7 +20,7 @@ #include "luci/IR/CircleNodeDecl.h" #include "luci/IR/CircleOpcode.h" -#include "luci/IR/LuciNodeMixins.h" +#include "luci/IR/CircleNodeMixins.h" namespace luci { diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleGreaterEqual.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleGreaterEqual.h index 82bdab212..e435320b2 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleGreaterEqual.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleGreaterEqual.h @@ -20,7 +20,7 @@ #include "luci/IR/CircleNodeDecl.h" #include "luci/IR/CircleOpcode.h" -#include "luci/IR/LuciNodeMixins.h" +#include "luci/IR/CircleNodeMixins.h" namespace luci { @@ -29,7 +29,7 @@ namespace luci * @brief GREATER EQUAL in Circle */ class CircleGreaterEqual final - : public FixedArityNode<2, CircleNodeImpl<CircleOpcode::GREATER_EQUAL>> + : public FixedArityNode<2, CircleNodeImpl<CircleOpcode::GREATER_EQUAL>> { public: loco::Node *x(void) const { return at(0)->node(); } diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleIf.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleIf.h index 2f9eac211..1c037a406 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleIf.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleIf.h @@ -34,7 +34,7 @@ class CircleIf final : public VariadicArityNode<CircleNodeImpl<CircleOpcode::IF> { public: CircleIf(uint32_t arity, uint32_t out) - : VariadicArityNode<CircleNodeImpl<CircleOpcode::IF>>(arity + 1), _output_count(out) + : VariadicArityNode<CircleNodeImpl<CircleOpcode::IF>>(arity + 1), _output_count(out) { assert(arity > 0); assert(out > 0); diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleIfOut.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleIfOut.h index 3654e943b..5adaaa447 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleIfOut.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleIfOut.h @@ -20,7 +20,7 @@ #include "luci/IR/CircleNodeDecl.h" #include "luci/IR/CircleOpcode.h" -#include "luci/IR/LuciNodeMixins.h" +#include "luci/IR/CircleNodeMixins.h" namespace luci { @@ -31,9 +31,6 @@ namespace luci class CircleIfOut final : public FixedArityNode<1, CircleNodeImpl<CircleOpcode::CIRCLEIFOUT>> { public: - CircleIfOut() = default; - -public: loco::Node *input(void) const { return at(0)->node(); } void input(loco::Node *node) { at(0)->node(node); } diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleInput.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleInput.h index 4a7d36a4e..e0be9aa6e 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleInput.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleInput.h @@ -20,7 +20,7 @@ #include "luci/IR/CircleNodeDecl.h" #include "luci/IR/CircleOpcode.h" -#include "luci/IR/LuciNodeMixins.h" +#include "luci/IR/CircleNodeMixins.h" #include <loco/IR/DataTypeTraits.h> #include <loco/IR/GraphInputIndex.h> @@ -35,16 +35,13 @@ namespace luci class CircleInput final : public FixedArityNode<0, CircleNodeImpl<CircleOpcode::CIRCLEINPUT>> { public: - CircleInput() = default; - -public: void index(const loco::GraphInputIndex &index); loco::GraphInputIndex index(void) const; bool indexed(void) const { return _index != -1; } private: - int64_t _index = -1; // Uninitialized + int64_t _index{-1}; // Uninitialized }; } // namespace luci diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleInstanceNorm.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleInstanceNorm.h index db0faa05e..65c34194d 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleInstanceNorm.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleInstanceNorm.h @@ -21,7 +21,7 @@ #include "luci/IR/CircleOpcode.h" #include "luci/IR/AttrFusedActFunc.h" -#include "luci/IR/LuciNodeMixins.h" +#include "luci/IR/CircleNodeMixins.h" namespace luci { @@ -30,8 +30,8 @@ namespace luci * @brief INSTANCE_NORM in Circle */ class CircleInstanceNorm final - : public FixedArityNode<3, CircleNodeImpl<CircleOpcode::INSTANCE_NORM>>, - public LuciNodeMixin<LuciNodeTrait::FusedActFunc> + : public FixedArityNode<3, CircleNodeImpl<CircleOpcode::INSTANCE_NORM>>, + public CircleNodeMixin<CircleNodeTrait::FusedActFunc> { public: /// @note Currently only support FLOAT32 as input node @@ -44,11 +44,12 @@ public: loco::Node *beta(void) const { return at(2)->node(); } void beta(loco::Node *node) { at(2)->node(node); } +public: float epsilon() const { return _epsilon; } void epsilon(float epsilon) { _epsilon = epsilon; } private: - float _epsilon = 1e-05; + float _epsilon{1e-05}; }; } // namespace luci diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleL2Normalize.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleL2Normalize.h index efa932d95..eb2b372ce 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleL2Normalize.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleL2Normalize.h @@ -21,7 +21,7 @@ #include "luci/IR/CircleOpcode.h" #include "luci/IR/AttrFusedActFunc.h" -#include "luci/IR/LuciNodeMixins.h" +#include "luci/IR/CircleNodeMixins.h" namespace luci { @@ -30,8 +30,8 @@ namespace luci * @brief L2_NORMALIZATION in Circle */ class CircleL2Normalize final - : public FixedArityNode<1, CircleNodeImpl<CircleOpcode::L2_NORMALIZATION>>, - public LuciNodeMixin<LuciNodeTrait::FusedActFunc> + : public FixedArityNode<1, CircleNodeImpl<CircleOpcode::L2_NORMALIZATION>>, + public CircleNodeMixin<CircleNodeTrait::FusedActFunc> { public: loco::Node *x(void) const { return at(0)->node(); } diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleL2Pool2D.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleL2Pool2D.h index 7c76ee5d0..624d29e9e 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleL2Pool2D.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleL2Pool2D.h @@ -24,7 +24,7 @@ #include "luci/IR/AttrPadding.h" #include "luci/IR/AttrStride.h" #include "luci/IR/AttrFusedActFunc.h" -#include "luci/IR/LuciNodeMixins.h" +#include "luci/IR/CircleNodeMixins.h" namespace luci { @@ -33,15 +33,13 @@ namespace luci * @brief L2_POOL_2D in Circle */ class CircleL2Pool2D final : public FixedArityNode<1, CircleNodeImpl<CircleOpcode::L2_POOL_2D>>, - public LuciNodeMixin<LuciNodeTrait::FusedActFunc> + public CircleNodeMixin<CircleNodeTrait::FusedActFunc> { public: - CircleL2Pool2D() : _padding(Padding::UNDEFINED) { /* empty */} - -public: loco::Node *value(void) const { return at(0)->node(); } void value(loco::Node *node) { at(0)->node(node); } +public: Padding padding() const { return _padding; } void padding(Padding padding) { _padding = padding; } @@ -52,7 +50,7 @@ public: Stride *stride(void) { return &_stride; } private: - Padding _padding; + Padding _padding{Padding::UNDEFINED}; Stride _stride; Filter _filter; }; diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleLeakyRelu.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleLeakyRelu.h index d6ac97fc0..c8e93af91 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleLeakyRelu.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleLeakyRelu.h @@ -20,7 +20,7 @@ #include "luci/IR/CircleNodeDecl.h" #include "luci/IR/CircleOpcode.h" -#include "luci/IR/LuciNodeMixins.h" +#include "luci/IR/CircleNodeMixins.h" namespace luci { @@ -31,17 +31,15 @@ namespace luci class CircleLeakyRelu final : public FixedArityNode<1, CircleNodeImpl<CircleOpcode::LEAKY_RELU>> { public: - CircleLeakyRelu() = default; - -public: loco::Node *features(void) const { return at(0)->node(); } void features(loco::Node *node) { at(0)->node(node); } +public: float alpha() const { return _alpha; } void alpha(float alpha) { _alpha = alpha; } private: - float _alpha = 0.2f; + float _alpha{0.2f}; }; } // namespace luci diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleLess.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleLess.h index cd6cf1872..7adf67842 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleLess.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleLess.h @@ -20,7 +20,7 @@ #include "luci/IR/CircleNodeDecl.h" #include "luci/IR/CircleOpcode.h" -#include "luci/IR/LuciNodeMixins.h" +#include "luci/IR/CircleNodeMixins.h" namespace luci { diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleLessEqual.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleLessEqual.h index 4c7c6a49b..eb8962494 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleLessEqual.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleLessEqual.h @@ -20,7 +20,7 @@ #include "luci/IR/CircleNodeDecl.h" #include "luci/IR/CircleOpcode.h" -#include "luci/IR/LuciNodeMixins.h" +#include "luci/IR/CircleNodeMixins.h" namespace luci { diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleLocalResponseNormalization.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleLocalResponseNormalization.h index 8ad2b40fd..4d324700e 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleLocalResponseNormalization.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleLocalResponseNormalization.h @@ -20,7 +20,7 @@ #include "luci/IR/CircleNodeDecl.h" #include "luci/IR/CircleOpcode.h" -#include "luci/IR/LuciNodeMixins.h" +#include "luci/IR/CircleNodeMixins.h" namespace luci { @@ -29,7 +29,7 @@ namespace luci * @brief LOCAL_RESPONSE_NORMALIZATION in Circle */ class CircleLocalResponseNormalization final - : public FixedArityNode<1, CircleNodeImpl<CircleOpcode::LOCAL_RESPONSE_NORMALIZATION>> + : public FixedArityNode<1, CircleNodeImpl<CircleOpcode::LOCAL_RESPONSE_NORMALIZATION>> { public: loco::Node *input(void) const { return at(0)->node(); } diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleLog.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleLog.h index aeb13fed9..2cc57ce2d 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleLog.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleLog.h @@ -20,7 +20,7 @@ #include "luci/IR/CircleNodeDecl.h" #include "luci/IR/CircleOpcode.h" -#include "luci/IR/LuciNodeMixins.h" +#include "luci/IR/CircleNodeMixins.h" namespace luci { diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleLogSoftmax.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleLogSoftmax.h index 5dfd2c1f9..b73ff7c2a 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleLogSoftmax.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleLogSoftmax.h @@ -20,7 +20,7 @@ #include "luci/IR/CircleNodeDecl.h" #include "luci/IR/CircleOpcode.h" -#include "luci/IR/LuciNodeMixins.h" +#include "luci/IR/CircleNodeMixins.h" namespace luci { diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleLogicalAnd.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleLogicalAnd.h index 975f6dbc7..9943c71cd 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleLogicalAnd.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleLogicalAnd.h @@ -20,7 +20,7 @@ #include "luci/IR/CircleNodeDecl.h" #include "luci/IR/CircleOpcode.h" -#include "luci/IR/LuciNodeMixins.h" +#include "luci/IR/CircleNodeMixins.h" namespace luci { diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleLogicalNot.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleLogicalNot.h index 749dbe518..369a3e7bf 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleLogicalNot.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleLogicalNot.h @@ -20,7 +20,7 @@ #include "luci/IR/CircleNodeDecl.h" #include "luci/IR/CircleOpcode.h" -#include "luci/IR/LuciNodeMixins.h" +#include "luci/IR/CircleNodeMixins.h" namespace luci { diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleLogicalOr.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleLogicalOr.h index 570be57af..c54ec3ebf 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleLogicalOr.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleLogicalOr.h @@ -20,7 +20,7 @@ #include "luci/IR/CircleNodeDecl.h" #include "luci/IR/CircleOpcode.h" -#include "luci/IR/LuciNodeMixins.h" +#include "luci/IR/CircleNodeMixins.h" namespace luci { diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleLogistic.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleLogistic.h index 8328cb328..1f95e0f77 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleLogistic.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleLogistic.h @@ -20,7 +20,7 @@ #include "luci/IR/CircleNodeDecl.h" #include "luci/IR/CircleOpcode.h" -#include "luci/IR/LuciNodeMixins.h" +#include "luci/IR/CircleNodeMixins.h" namespace luci { @@ -31,9 +31,6 @@ namespace luci class CircleLogistic final : public FixedArityNode<1, CircleNodeImpl<CircleOpcode::LOGISTIC>> { public: - CircleLogistic() = default; - -public: loco::Node *x(void) const { return at(0)->node(); } void x(loco::Node *node) { at(0)->node(node); } }; diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleMatrixDiag.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleMatrixDiag.h index dca6538c3..f8bf259f9 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleMatrixDiag.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleMatrixDiag.h @@ -20,7 +20,7 @@ #include "luci/IR/CircleNodeDecl.h" #include "luci/IR/CircleOpcode.h" -#include "luci/IR/LuciNodeMixins.h" +#include "luci/IR/CircleNodeMixins.h" namespace luci { diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleMatrixSetDiag.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleMatrixSetDiag.h index c1f5f3023..76aeaff40 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleMatrixSetDiag.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleMatrixSetDiag.h @@ -20,7 +20,7 @@ #include "luci/IR/CircleNodeDecl.h" #include "luci/IR/CircleOpcode.h" -#include "luci/IR/LuciNodeMixins.h" +#include "luci/IR/CircleNodeMixins.h" namespace luci { @@ -29,7 +29,7 @@ namespace luci * @brief MATRIX_SET_DIAG in Circle */ class CircleMatrixSetDiag final - : public FixedArityNode<2, CircleNodeImpl<CircleOpcode::MATRIX_SET_DIAG>> + : public FixedArityNode<2, CircleNodeImpl<CircleOpcode::MATRIX_SET_DIAG>> { public: loco::Node *input(void) const { return at(0)->node(); } diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleMaxPool2D.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleMaxPool2D.h index 1eb6532ff..557240d54 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleMaxPool2D.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleMaxPool2D.h @@ -24,7 +24,7 @@ #include "luci/IR/AttrPadding.h" #include "luci/IR/AttrStride.h" #include "luci/IR/AttrFusedActFunc.h" -#include "luci/IR/LuciNodeMixins.h" +#include "luci/IR/CircleNodeMixins.h" namespace luci { @@ -33,15 +33,13 @@ namespace luci * @brief MAX_POOL_2D in Circle */ class CircleMaxPool2D final : public FixedArityNode<1, CircleNodeImpl<CircleOpcode::MAX_POOL_2D>>, - public LuciNodeMixin<LuciNodeTrait::FusedActFunc> + public CircleNodeMixin<CircleNodeTrait::FusedActFunc> { public: - CircleMaxPool2D() : _padding(Padding::UNDEFINED) { /* empty */} - -public: loco::Node *value(void) const { return at(0)->node(); } void value(loco::Node *node) { at(0)->node(node); } +public: Padding padding() const { return _padding; } void padding(Padding padding) { _padding = padding; } @@ -52,7 +50,7 @@ public: Stride *stride(void) { return &_stride; } private: - Padding _padding; + Padding _padding{Padding::UNDEFINED}; Stride _stride; Filter _filter; }; diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleMaximum.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleMaximum.h index 6f789bc14..317cea308 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleMaximum.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleMaximum.h @@ -20,7 +20,7 @@ #include "luci/IR/CircleNodeDecl.h" #include "luci/IR/CircleOpcode.h" -#include "luci/IR/LuciNodeMixins.h" +#include "luci/IR/CircleNodeMixins.h" namespace luci { diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleMean.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleMean.h index 7f8aeb5aa..f56e4f4c0 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleMean.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleMean.h @@ -20,7 +20,7 @@ #include "luci/IR/CircleNodeDecl.h" #include "luci/IR/CircleOpcode.h" -#include "luci/IR/LuciNodeMixins.h" +#include "luci/IR/CircleNodeMixins.h" namespace luci { @@ -42,7 +42,7 @@ public: void keep_dims(bool keep_dims) { _keep_dims = keep_dims; } private: - bool _keep_dims = false; + bool _keep_dims{false}; }; } // namespace luci diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleMinimum.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleMinimum.h index 79d5a6f17..959d9c93b 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleMinimum.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleMinimum.h @@ -20,7 +20,7 @@ #include "luci/IR/CircleNodeDecl.h" #include "luci/IR/CircleOpcode.h" -#include "luci/IR/LuciNodeMixins.h" +#include "luci/IR/CircleNodeMixins.h" namespace luci { diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleMirrorPad.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleMirrorPad.h index 68db8f6f3..c69e8f7c1 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleMirrorPad.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleMirrorPad.h @@ -20,7 +20,7 @@ #include "luci/IR/CircleNodeDecl.h" #include "luci/IR/CircleOpcode.h" -#include "luci/IR/LuciNodeMixins.h" +#include "luci/IR/CircleNodeMixins.h" #include "luci/IR/AttrMirrorPadMode.h" namespace luci @@ -32,9 +32,6 @@ namespace luci class CircleMirrorPad final : public FixedArityNode<2, CircleNodeImpl<CircleOpcode::MIRROR_PAD>> { public: - CircleMirrorPad() = default; - -public: loco::Node *input(void) const { return at(0)->node(); } void input(loco::Node *node) { at(0)->node(node); } diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleMul.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleMul.h index 67e897170..85ed694b3 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleMul.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleMul.h @@ -21,7 +21,7 @@ #include "luci/IR/CircleOpcode.h" #include "luci/IR/AttrFusedActFunc.h" -#include "luci/IR/LuciNodeMixins.h" +#include "luci/IR/CircleNodeMixins.h" namespace luci { @@ -30,7 +30,7 @@ namespace luci * @brief MUL in Circle */ class CircleMul final : public FixedArityNode<2, CircleNodeImpl<CircleOpcode::MUL>>, - public LuciNodeMixin<LuciNodeTrait::FusedActFunc> + public CircleNodeMixin<CircleNodeTrait::FusedActFunc> { public: loco::Node *x(void) const { return at(0)->node(); } diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleNeg.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleNeg.h index 4149ac4a7..adea3fb83 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleNeg.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleNeg.h @@ -20,7 +20,7 @@ #include "luci/IR/CircleNodeDecl.h" #include "luci/IR/CircleOpcode.h" -#include "luci/IR/LuciNodeMixins.h" +#include "luci/IR/CircleNodeMixins.h" namespace luci { diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleNonMaxSuppressionV4.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleNonMaxSuppressionV4.h index 69f3368c0..b47404bb0 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleNonMaxSuppressionV4.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleNonMaxSuppressionV4.h @@ -20,7 +20,7 @@ #include "luci/IR/CircleNodeDecl.h" #include "luci/IR/CircleOpcode.h" -#include "luci/IR/LuciNodeMixins.h" +#include "luci/IR/CircleNodeMixins.h" namespace luci { @@ -29,7 +29,7 @@ namespace luci * @brief NON_MAX_SUPPRESSION_V4 in Circle */ class CircleNonMaxSuppressionV4 final - : public FixedArityNode<5, CircleNodeImpl<CircleOpcode::NON_MAX_SUPPRESSION_V4>> + : public FixedArityNode<5, CircleNodeImpl<CircleOpcode::NON_MAX_SUPPRESSION_V4>> { public: loco::Node *boxes(void) const { return at(0)->node(); } diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleNonMaxSuppressionV4Out.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleNonMaxSuppressionV4Out.h index a24dc3e9c..7e6923b5e 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleNonMaxSuppressionV4Out.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleNonMaxSuppressionV4Out.h @@ -20,7 +20,7 @@ #include "luci/IR/CircleNodeDecl.h" #include "luci/IR/CircleOpcode.h" -#include "luci/IR/LuciNodeMixins.h" +#include "luci/IR/CircleNodeMixins.h" namespace luci { @@ -29,12 +29,9 @@ namespace luci * @brief Virtual NONMAXSUPPRESSIONV4OUT in Circle */ class CircleNonMaxSuppressionV4Out final - : public FixedArityNode<1, CircleNodeImpl<CircleOpcode::CIRCLENONMAXSUPPRESSIONV4OUT>> + : public FixedArityNode<1, CircleNodeImpl<CircleOpcode::CIRCLENONMAXSUPPRESSIONV4OUT>> { public: - CircleNonMaxSuppressionV4Out() = default; - -public: loco::Node *input(void) const { return at(0)->node(); } void input(loco::Node *node) { at(0)->node(node); } diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleNonMaxSuppressionV5.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleNonMaxSuppressionV5.h index 52d682147..77086ede7 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleNonMaxSuppressionV5.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleNonMaxSuppressionV5.h @@ -20,7 +20,7 @@ #include "luci/IR/CircleNodeDecl.h" #include "luci/IR/CircleOpcode.h" -#include "luci/IR/LuciNodeMixins.h" +#include "luci/IR/CircleNodeMixins.h" namespace luci { @@ -29,7 +29,7 @@ namespace luci * @brief NON_MAX_SUPPRESSION_V5 in Circle */ class CircleNonMaxSuppressionV5 final - : public FixedArityNode<6, CircleNodeImpl<CircleOpcode::NON_MAX_SUPPRESSION_V5>> + : public FixedArityNode<6, CircleNodeImpl<CircleOpcode::NON_MAX_SUPPRESSION_V5>> { public: loco::Node *boxes(void) const { return at(0)->node(); } diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleNonMaxSuppressionV5Out.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleNonMaxSuppressionV5Out.h index 0c6989cc7..63d061f11 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleNonMaxSuppressionV5Out.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleNonMaxSuppressionV5Out.h @@ -20,7 +20,7 @@ #include "luci/IR/CircleNodeDecl.h" #include "luci/IR/CircleOpcode.h" -#include "luci/IR/LuciNodeMixins.h" +#include "luci/IR/CircleNodeMixins.h" namespace luci { @@ -29,12 +29,9 @@ namespace luci * @brief Virtual NONMAXSUPPRESSIONV5OUT in Circle */ class CircleNonMaxSuppressionV5Out final - : public FixedArityNode<1, CircleNodeImpl<CircleOpcode::CIRCLENONMAXSUPPRESSIONV5OUT>> + : public FixedArityNode<1, CircleNodeImpl<CircleOpcode::CIRCLENONMAXSUPPRESSIONV5OUT>> { public: - CircleNonMaxSuppressionV5Out() = default; - -public: loco::Node *input(void) const { return at(0)->node(); } void input(loco::Node *node) { at(0)->node(node); } diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleNotEqual.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleNotEqual.h index cca7a5e22..add6a0747 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleNotEqual.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleNotEqual.h @@ -20,7 +20,7 @@ #include "luci/IR/CircleNodeDecl.h" #include "luci/IR/CircleOpcode.h" -#include "luci/IR/LuciNodeMixins.h" +#include "luci/IR/CircleNodeMixins.h" namespace luci { diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleOneHot.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleOneHot.h index 665e01d48..b3eb0f436 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleOneHot.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleOneHot.h @@ -20,7 +20,7 @@ #include "luci/IR/CircleNodeDecl.h" #include "luci/IR/CircleOpcode.h" -#include "luci/IR/LuciNodeMixins.h" +#include "luci/IR/CircleNodeMixins.h" namespace luci { @@ -48,7 +48,7 @@ public: void axis(int32_t axis) { _axis = axis; } private: - int32_t _axis = -1; + int32_t _axis{-1}; }; } // namespace luci diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleOutput.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleOutput.h index 67e55f1a1..eb02f824e 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleOutput.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleOutput.h @@ -20,7 +20,7 @@ #include "luci/IR/CircleNodeDecl.h" #include "luci/IR/CircleOpcode.h" -#include "luci/IR/LuciNodeMixins.h" +#include "luci/IR/CircleNodeMixins.h" #include <loco/IR/GraphOutputIndex.h> @@ -34,8 +34,6 @@ namespace luci class CircleOutput final : public FixedArityNode<1, CircleNodeImpl<CircleOpcode::CIRCLEOUTPUT>> { public: - CircleOutput() = default; - void index(const loco::GraphOutputIndex &index); loco::GraphOutputIndex index(void) const; @@ -46,7 +44,7 @@ public: void from(loco::Node *node) { at(0)->node(node); } private: - int64_t _index = -1; // Uninitialized + int64_t _index{-1}; // Uninitialized }; /** @@ -54,7 +52,7 @@ private: */ // TODO remove CircleOutputDummy class CircleOutputDummy final - : public FixedArityNode<0, CircleNodeImpl<CircleOpcode::CIRCLEOUTPUTDUMMY>> + : public FixedArityNode<0, CircleNodeImpl<CircleOpcode::CIRCLEOUTPUTDUMMY>> { public: CircleOutputDummy() = default; @@ -64,7 +62,7 @@ public: * @brief CircleOutputExclude is used to specifying not exported nodes */ class CircleOutputExclude final - : public FixedArityNode<0, CircleNodeImpl<CircleOpcode::CIRCLEOUTPUTEXCLUDE>> + : public FixedArityNode<0, CircleNodeImpl<CircleOpcode::CIRCLEOUTPUTEXCLUDE>> { public: CircleOutputExclude() = default; diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CirclePRelu.h b/compiler/luci/lang/include/luci/IR/Nodes/CirclePRelu.h index 693777512..3c5559db2 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CirclePRelu.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CirclePRelu.h @@ -20,7 +20,7 @@ #include "luci/IR/CircleNodeDecl.h" #include "luci/IR/CircleOpcode.h" -#include "luci/IR/LuciNodeMixins.h" +#include "luci/IR/CircleNodeMixins.h" namespace luci { @@ -31,9 +31,6 @@ namespace luci class CirclePRelu final : public FixedArityNode<2, CircleNodeImpl<CircleOpcode::PRELU>> { public: - CirclePRelu() = default; - -public: loco::Node *input(void) const { return at(0)->node(); } void input(loco::Node *node) { at(0)->node(node); } diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CirclePad.h b/compiler/luci/lang/include/luci/IR/Nodes/CirclePad.h index 31599bda0..ede217789 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CirclePad.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CirclePad.h @@ -20,7 +20,7 @@ #include "luci/IR/CircleNodeDecl.h" #include "luci/IR/CircleOpcode.h" -#include "luci/IR/LuciNodeMixins.h" +#include "luci/IR/CircleNodeMixins.h" namespace luci { @@ -31,9 +31,6 @@ namespace luci class CirclePad final : public FixedArityNode<2, CircleNodeImpl<CircleOpcode::PAD>> { public: - CirclePad() = default; - -public: loco::Node *input(void) const { return at(0)->node(); } void input(loco::Node *node) { at(0)->node(node); } diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CirclePadV2.h b/compiler/luci/lang/include/luci/IR/Nodes/CirclePadV2.h index 563cfd9a4..644e2bb27 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CirclePadV2.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CirclePadV2.h @@ -20,7 +20,7 @@ #include "luci/IR/CircleNodeDecl.h" #include "luci/IR/CircleOpcode.h" -#include "luci/IR/LuciNodeMixins.h" +#include "luci/IR/CircleNodeMixins.h" namespace luci { @@ -31,9 +31,6 @@ namespace luci class CirclePadV2 final : public FixedArityNode<3, CircleNodeImpl<CircleOpcode::PADV2>> { public: - CirclePadV2() = default; - -public: loco::Node *input(void) const { return at(0)->node(); } void input(loco::Node *node) { at(0)->node(node); } diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CirclePow.h b/compiler/luci/lang/include/luci/IR/Nodes/CirclePow.h index 006e3dd86..40c5a829d 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CirclePow.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CirclePow.h @@ -20,7 +20,7 @@ #include "luci/IR/CircleNodeDecl.h" #include "luci/IR/CircleOpcode.h" -#include "luci/IR/LuciNodeMixins.h" +#include "luci/IR/CircleNodeMixins.h" namespace luci { @@ -31,9 +31,6 @@ namespace luci class CirclePow final : public FixedArityNode<2, CircleNodeImpl<CircleOpcode::POW>> { public: - CirclePow() = default; - -public: loco::Node *x(void) const { return at(0)->node(); } void x(loco::Node *node) { at(0)->node(node); } diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleRange.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleRange.h index 977a37a52..56f8a2eba 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleRange.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleRange.h @@ -20,7 +20,7 @@ #include "luci/IR/CircleNodeDecl.h" #include "luci/IR/CircleOpcode.h" -#include "luci/IR/LuciNodeMixins.h" +#include "luci/IR/CircleNodeMixins.h" namespace luci { diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleRank.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleRank.h index ba6d67f69..034f251bc 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleRank.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleRank.h @@ -20,7 +20,7 @@ #include "luci/IR/CircleNodeDecl.h" #include "luci/IR/CircleOpcode.h" -#include "luci/IR/LuciNodeMixins.h" +#include "luci/IR/CircleNodeMixins.h" namespace luci { diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleReduceAny.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleReduceAny.h index 0456be863..c64dbbdf8 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleReduceAny.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleReduceAny.h @@ -20,7 +20,7 @@ #include "luci/IR/CircleNodeDecl.h" #include "luci/IR/CircleOpcode.h" -#include "luci/IR/LuciNodeMixins.h" +#include "luci/IR/CircleNodeMixins.h" namespace luci { @@ -42,7 +42,7 @@ public: void keep_dims(bool keep_dims) { _keep_dims = keep_dims; } private: - bool _keep_dims = false; + bool _keep_dims{false}; }; } // namespace luci diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleReduceMax.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleReduceMax.h index 925c977e5..97cbecd08 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleReduceMax.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleReduceMax.h @@ -20,7 +20,7 @@ #include "luci/IR/CircleNodeDecl.h" #include "luci/IR/CircleOpcode.h" -#include "luci/IR/LuciNodeMixins.h" +#include "luci/IR/CircleNodeMixins.h" namespace luci { @@ -42,7 +42,7 @@ public: void keep_dims(bool keep_dims) { _keep_dims = keep_dims; } private: - bool _keep_dims = false; + bool _keep_dims{false}; }; } // namespace luci diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleReduceMin.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleReduceMin.h index fd789ae5e..33708928f 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleReduceMin.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleReduceMin.h @@ -20,7 +20,7 @@ #include "luci/IR/CircleNodeDecl.h" #include "luci/IR/CircleOpcode.h" -#include "luci/IR/LuciNodeMixins.h" +#include "luci/IR/CircleNodeMixins.h" namespace luci { @@ -42,7 +42,7 @@ public: void keep_dims(bool keep_dims) { _keep_dims = keep_dims; } private: - bool _keep_dims = false; + bool _keep_dims{false}; }; } // namespace luci diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleReduceProd.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleReduceProd.h index b7d226255..3689ee532 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleReduceProd.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleReduceProd.h @@ -20,7 +20,7 @@ #include "luci/IR/CircleNodeDecl.h" #include "luci/IR/CircleOpcode.h" -#include "luci/IR/LuciNodeMixins.h" +#include "luci/IR/CircleNodeMixins.h" namespace luci { @@ -42,7 +42,7 @@ public: void keep_dims(bool keep_dims) { _keep_dims = keep_dims; } private: - bool _keep_dims = false; + bool _keep_dims{false}; }; } // namespace luci diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleRelu.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleRelu.h index 91272d2bf..6148caa03 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleRelu.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleRelu.h @@ -20,7 +20,7 @@ #include "luci/IR/CircleNodeDecl.h" #include "luci/IR/CircleOpcode.h" -#include "luci/IR/LuciNodeMixins.h" +#include "luci/IR/CircleNodeMixins.h" namespace luci { @@ -31,9 +31,6 @@ namespace luci class CircleRelu final : public FixedArityNode<1, CircleNodeImpl<CircleOpcode::RELU>> { public: - CircleRelu() = default; - -public: loco::Node *features(void) const { return at(0)->node(); } void features(loco::Node *node) { at(0)->node(node); } }; diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleRelu6.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleRelu6.h index b4274ded9..0fa25e873 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleRelu6.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleRelu6.h @@ -20,7 +20,7 @@ #include "luci/IR/CircleNodeDecl.h" #include "luci/IR/CircleOpcode.h" -#include "luci/IR/LuciNodeMixins.h" +#include "luci/IR/CircleNodeMixins.h" namespace luci { @@ -31,9 +31,6 @@ namespace luci class CircleRelu6 final : public FixedArityNode<1, CircleNodeImpl<CircleOpcode::RELU6>> { public: - CircleRelu6() = default; - -public: loco::Node *features(void) const { return at(0)->node(); } void features(loco::Node *node) { at(0)->node(node); } }; diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleReluN1To1.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleReluN1To1.h index a5c5710c2..13c0d166f 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleReluN1To1.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleReluN1To1.h @@ -20,7 +20,7 @@ #include "luci/IR/CircleNodeDecl.h" #include "luci/IR/CircleOpcode.h" -#include "luci/IR/LuciNodeMixins.h" +#include "luci/IR/CircleNodeMixins.h" namespace luci { @@ -31,9 +31,6 @@ namespace luci class CircleReluN1To1 final : public FixedArityNode<1, CircleNodeImpl<CircleOpcode::RELU_N1_TO_1>> { public: - CircleReluN1To1() = default; - -public: loco::Node *features(void) const { return at(0)->node(); } void features(loco::Node *node) { at(0)->node(node); } }; diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleReshape.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleReshape.h index b13144f7e..090df4044 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleReshape.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleReshape.h @@ -20,7 +20,7 @@ #include "luci/IR/CircleNodeDecl.h" #include "luci/IR/CircleOpcode.h" -#include "luci/IR/LuciNodeMixins.h" +#include "luci/IR/CircleNodeMixins.h" namespace luci { @@ -31,14 +31,11 @@ namespace luci class CircleReshape final : public FixedArityNode<2, CircleNodeImpl<CircleOpcode::RESHAPE>> { public: - CircleReshape() = default; - -public: loco::Node *tensor(void) const { return at(0)->node(); } void tensor(loco::Node *node) { at(0)->node(node); } // NOTE shape is optional and can be CircleConst or any other type - // and also can be CircleOutputDummy when reshape option does not exist + // and also should be CircleOutputDummy when reshape option does not exist loco::Node *shape(void) const { return at(1)->node(); } void shape(loco::Node *node) { at(1)->node(node); } diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleResizeBilinear.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleResizeBilinear.h index 3c8223338..091916a2b 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleResizeBilinear.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleResizeBilinear.h @@ -20,7 +20,7 @@ #include "luci/IR/CircleNodeDecl.h" #include "luci/IR/CircleOpcode.h" -#include "luci/IR/LuciNodeMixins.h" +#include "luci/IR/CircleNodeMixins.h" namespace luci { @@ -29,18 +29,16 @@ namespace luci * @brief RESIZE_BILINEAR in Circle */ class CircleResizeBilinear final - : public FixedArityNode<2, CircleNodeImpl<CircleOpcode::RESIZE_BILINEAR>> + : public FixedArityNode<2, CircleNodeImpl<CircleOpcode::RESIZE_BILINEAR>> { public: - CircleResizeBilinear() = default; - -public: loco::Node *input(void) const { return at(0)->node(); } void input(loco::Node *node) { at(0)->node(node); } loco::Node *size(void) const { return at(1)->node(); } void size(loco::Node *node) { at(1)->node(node); } +public: bool align_corners() const { return _align_corners; } void align_corners(bool value) { _align_corners = value; } @@ -48,8 +46,8 @@ public: void half_pixel_centers(bool value) { _half_pixel_centers = value; } private: - bool _align_corners = false; - bool _half_pixel_centers = false; + bool _align_corners{false}; + bool _half_pixel_centers{false}; }; } // namespace luci diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleResizeNearestNeighbor.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleResizeNearestNeighbor.h index dc32ebee7..ab880d767 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleResizeNearestNeighbor.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleResizeNearestNeighbor.h @@ -20,7 +20,7 @@ #include "luci/IR/CircleNodeDecl.h" #include "luci/IR/CircleOpcode.h" -#include "luci/IR/LuciNodeMixins.h" +#include "luci/IR/CircleNodeMixins.h" namespace luci { @@ -29,23 +29,21 @@ namespace luci * @brief RESIZE_NEAREST_NEIGHBOR in Circle */ class CircleResizeNearestNeighbor final - : public FixedArityNode<2, CircleNodeImpl<CircleOpcode::RESIZE_NEAREST_NEIGHBOR>> + : public FixedArityNode<2, CircleNodeImpl<CircleOpcode::RESIZE_NEAREST_NEIGHBOR>> { public: - CircleResizeNearestNeighbor() = default; - -public: loco::Node *input(void) const { return at(0)->node(); } void input(loco::Node *node) { at(0)->node(node); } loco::Node *size(void) const { return at(1)->node(); } void size(loco::Node *node) { at(1)->node(node); } +public: bool align_corners() const { return _align_corners; } void align_corners(bool value) { _align_corners = value; } private: - bool _align_corners = false; + bool _align_corners{false}; }; } // namespace luci diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleReverseSequence.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleReverseSequence.h index b0766dd3e..5f089a768 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleReverseSequence.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleReverseSequence.h @@ -20,7 +20,7 @@ #include "luci/IR/CircleNodeDecl.h" #include "luci/IR/CircleOpcode.h" -#include "luci/IR/LuciNodeMixins.h" +#include "luci/IR/CircleNodeMixins.h" namespace luci { @@ -29,12 +29,9 @@ namespace luci * @brief REVERSE_SEQUENCE in Circle */ class CircleReverseSequence final - : public FixedArityNode<2, CircleNodeImpl<CircleOpcode::REVERSE_SEQUENCE>> + : public FixedArityNode<2, CircleNodeImpl<CircleOpcode::REVERSE_SEQUENCE>> { public: - CircleReverseSequence() = default; - -public: loco::Node *input(void) const { return at(0)->node(); } void input(loco::Node *node) { at(0)->node(node); } @@ -42,15 +39,15 @@ public: void seq_lengths(loco::Node *node) { at(1)->node(node); } public: - int seq_axis(void) const { return _seq_axis; } - void seq_axis(int seq_axis) { _seq_axis = seq_axis; } + int32_t seq_axis(void) const { return _seq_axis; } + void seq_axis(int32_t seq_axis) { _seq_axis = seq_axis; } - int batch_axis(void) const { return _batch_axis; } - void batch_axis(int batch_axis) { _batch_axis = batch_axis; } + int32_t batch_axis(void) const { return _batch_axis; } + void batch_axis(int32_t batch_axis) { _batch_axis = batch_axis; } private: - int _seq_axis{0}; - int _batch_axis{0}; + int32_t _seq_axis{0}; + int32_t _batch_axis{0}; }; } // namespace luci diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleReverseV2.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleReverseV2.h index 71d9f65aa..96b6a793d 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleReverseV2.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleReverseV2.h @@ -20,7 +20,7 @@ #include "luci/IR/CircleNodeDecl.h" #include "luci/IR/CircleOpcode.h" -#include "luci/IR/LuciNodeMixins.h" +#include "luci/IR/CircleNodeMixins.h" namespace luci { diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleRound.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleRound.h index 30296ce9e..e340266ed 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleRound.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleRound.h @@ -20,7 +20,7 @@ #include "luci/IR/CircleNodeDecl.h" #include "luci/IR/CircleOpcode.h" -#include "luci/IR/LuciNodeMixins.h" +#include "luci/IR/CircleNodeMixins.h" namespace luci { @@ -31,9 +31,6 @@ namespace luci class CircleRound final : public FixedArityNode<1, CircleNodeImpl<CircleOpcode::ROUND>> { public: - CircleRound() = default; - -public: loco::Node *x(void) const { return at(0)->node(); } void x(loco::Node *node) { at(0)->node(node); } }; diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleRsqrt.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleRsqrt.h index 873397bce..7907f326b 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleRsqrt.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleRsqrt.h @@ -20,7 +20,7 @@ #include "luci/IR/CircleNodeDecl.h" #include "luci/IR/CircleOpcode.h" -#include "luci/IR/LuciNodeMixins.h" +#include "luci/IR/CircleNodeMixins.h" namespace luci { @@ -31,9 +31,6 @@ namespace luci class CircleRsqrt final : public FixedArityNode<1, CircleNodeImpl<CircleOpcode::RSQRT>> { public: - CircleRsqrt() = default; - -public: loco::Node *x(void) const { return at(0)->node(); } void x(loco::Node *node) { at(0)->node(node); } }; diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleScatterNd.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleScatterNd.h index 9f93a0a80..fda3abafc 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleScatterNd.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleScatterNd.h @@ -20,7 +20,7 @@ #include "luci/IR/CircleNodeDecl.h" #include "luci/IR/CircleOpcode.h" -#include "luci/IR/LuciNodeMixins.h" +#include "luci/IR/CircleNodeMixins.h" namespace luci { diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleSegmentSum.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleSegmentSum.h index 416d617b2..e7227e9ee 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleSegmentSum.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleSegmentSum.h @@ -20,7 +20,7 @@ #include "luci/IR/CircleNodeDecl.h" #include "luci/IR/CircleOpcode.h" -#include "luci/IR/LuciNodeMixins.h" +#include "luci/IR/CircleNodeMixins.h" namespace luci { @@ -31,9 +31,6 @@ namespace luci class CircleSegmentSum final : public FixedArityNode<2, CircleNodeImpl<CircleOpcode::SEGMENT_SUM>> { public: - CircleSegmentSum() = default; - -public: loco::Node *input(void) const { return at(0)->node(); } void input(loco::Node *node) { at(0)->node(node); } diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleSelect.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleSelect.h index 727647168..6f778d72d 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleSelect.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleSelect.h @@ -20,7 +20,7 @@ #include "luci/IR/CircleNodeDecl.h" #include "luci/IR/CircleOpcode.h" -#include "luci/IR/LuciNodeMixins.h" +#include "luci/IR/CircleNodeMixins.h" namespace luci { @@ -31,9 +31,6 @@ namespace luci class CircleSelect final : public FixedArityNode<3, CircleNodeImpl<CircleOpcode::SELECT>> { public: - CircleSelect() = default; - -public: loco::Node *condition(void) const { return at(0)->node(); } void condition(loco::Node *node) { at(0)->node(node); } diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleSelectV2.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleSelectV2.h index 7ac3c0524..7969cc2aa 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleSelectV2.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleSelectV2.h @@ -20,7 +20,7 @@ #include "luci/IR/CircleNodeDecl.h" #include "luci/IR/CircleOpcode.h" -#include "luci/IR/LuciNodeMixins.h" +#include "luci/IR/CircleNodeMixins.h" namespace luci { @@ -31,9 +31,6 @@ namespace luci class CircleSelectV2 final : public FixedArityNode<3, CircleNodeImpl<CircleOpcode::SELECT_V2>> { public: - CircleSelectV2() = default; - -public: loco::Node *condition(void) const { return at(0)->node(); } void condition(loco::Node *node) { at(0)->node(node); } diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleShape.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleShape.h index ff20ce684..903894dbd 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleShape.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleShape.h @@ -20,7 +20,7 @@ #include "luci/IR/CircleNodeDecl.h" #include "luci/IR/CircleOpcode.h" -#include "luci/IR/LuciNodeMixins.h" +#include "luci/IR/CircleNodeMixins.h" namespace luci { @@ -31,9 +31,6 @@ namespace luci class CircleShape final : public FixedArityNode<1, CircleNodeImpl<CircleOpcode::SHAPE>> { public: - CircleShape() = default; - -public: loco::Node *input(void) const { return at(0)->node(); } void input(loco::Node *node) { at(0)->node(node); } diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleSin.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleSin.h index 5624db253..25dc18b0d 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleSin.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleSin.h @@ -20,7 +20,7 @@ #include "luci/IR/CircleNodeDecl.h" #include "luci/IR/CircleOpcode.h" -#include "luci/IR/LuciNodeMixins.h" +#include "luci/IR/CircleNodeMixins.h" namespace luci { diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleSlice.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleSlice.h index a2113643d..98556d7a6 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleSlice.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleSlice.h @@ -20,7 +20,7 @@ #include "luci/IR/CircleNodeDecl.h" #include "luci/IR/CircleOpcode.h" -#include "luci/IR/LuciNodeMixins.h" +#include "luci/IR/CircleNodeMixins.h" namespace luci { diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleSoftmax.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleSoftmax.h index 7166a329b..d10cb1682 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleSoftmax.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleSoftmax.h @@ -20,7 +20,7 @@ #include "luci/IR/CircleNodeDecl.h" #include "luci/IR/CircleOpcode.h" -#include "luci/IR/LuciNodeMixins.h" +#include "luci/IR/CircleNodeMixins.h" namespace luci { diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleSpaceToBatchND.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleSpaceToBatchND.h index 042ebffcd..ef715c6d0 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleSpaceToBatchND.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleSpaceToBatchND.h @@ -20,7 +20,7 @@ #include "luci/IR/CircleNodeDecl.h" #include "luci/IR/CircleOpcode.h" -#include "luci/IR/LuciNodeMixins.h" +#include "luci/IR/CircleNodeMixins.h" namespace luci { @@ -29,7 +29,7 @@ namespace luci * @brief SPACE_TO_BATCH_ND in Circle */ class CircleSpaceToBatchND final - : public FixedArityNode<3, CircleNodeImpl<CircleOpcode::SPACE_TO_BATCH_ND>> + : public FixedArityNode<3, CircleNodeImpl<CircleOpcode::SPACE_TO_BATCH_ND>> { public: loco::Node *input(void) const { return at(0)->node(); } diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleSpaceToDepth.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleSpaceToDepth.h index 420a4cb96..387e0d80f 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleSpaceToDepth.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleSpaceToDepth.h @@ -20,7 +20,7 @@ #include "luci/IR/CircleNodeDecl.h" #include "luci/IR/CircleOpcode.h" -#include "luci/IR/LuciNodeMixins.h" +#include "luci/IR/CircleNodeMixins.h" namespace luci { @@ -29,18 +29,18 @@ namespace luci * @brief SPACE_TO_DEPTH in Circle */ class CircleSpaceToDepth final - : public FixedArityNode<1, CircleNodeImpl<CircleOpcode::SPACE_TO_DEPTH>> + : public FixedArityNode<1, CircleNodeImpl<CircleOpcode::SPACE_TO_DEPTH>> { public: loco::Node *input(void) const { return at(0)->node(); } void input(loco::Node *node) { at(0)->node(node); } public: - int block_size(void) const { return _block_size; } - void block_size(int block_size) { _block_size = block_size; } + int32_t block_size(void) const { return _block_size; } + void block_size(int32_t block_size) { _block_size = block_size; } private: - int _block_size{0}; + int32_t _block_size{0}; }; } // namespace luci diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleSparseToDense.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleSparseToDense.h index 7e80304b0..94a20c064 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleSparseToDense.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleSparseToDense.h @@ -20,7 +20,7 @@ #include "luci/IR/CircleNodeDecl.h" #include "luci/IR/CircleOpcode.h" -#include "luci/IR/LuciNodeMixins.h" +#include "luci/IR/CircleNodeMixins.h" namespace luci { @@ -29,7 +29,7 @@ namespace luci * @brief SPARSE_TO_DENSE in Circle */ class CircleSparseToDense final - : public FixedArityNode<4, CircleNodeImpl<CircleOpcode::SPARSE_TO_DENSE>> + : public FixedArityNode<4, CircleNodeImpl<CircleOpcode::SPARSE_TO_DENSE>> { public: loco::Node *indices(void) const { return at(0)->node(); } diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleSplit.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleSplit.h index 0eda19501..0cb953131 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleSplit.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleSplit.h @@ -20,7 +20,7 @@ #include "luci/IR/CircleNodeDecl.h" #include "luci/IR/CircleOpcode.h" -#include "luci/IR/LuciNodeMixins.h" +#include "luci/IR/CircleNodeMixins.h" namespace luci { diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleSplitOut.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleSplitOut.h index 6bf4a9fef..a507740e4 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleSplitOut.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleSplitOut.h @@ -20,7 +20,7 @@ #include "luci/IR/CircleNodeDecl.h" #include "luci/IR/CircleOpcode.h" -#include "luci/IR/LuciNodeMixins.h" +#include "luci/IR/CircleNodeMixins.h" namespace luci { @@ -31,9 +31,6 @@ namespace luci class CircleSplitOut final : public FixedArityNode<1, CircleNodeImpl<CircleOpcode::CIRCLESPLITOUT>> { public: - CircleSplitOut() = default; - -public: loco::Node *input(void) const { return at(0)->node(); } void input(loco::Node *node) { at(0)->node(node); } diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleSplitV.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleSplitV.h index 1b7d55534..cb02cbbcf 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleSplitV.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleSplitV.h @@ -20,7 +20,7 @@ #include "luci/IR/CircleNodeDecl.h" #include "luci/IR/CircleOpcode.h" -#include "luci/IR/LuciNodeMixins.h" +#include "luci/IR/CircleNodeMixins.h" namespace luci { diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleSplitVOut.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleSplitVOut.h index d3b2f1e5a..adf79f30c 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleSplitVOut.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleSplitVOut.h @@ -20,7 +20,7 @@ #include "luci/IR/CircleNodeDecl.h" #include "luci/IR/CircleOpcode.h" -#include "luci/IR/LuciNodeMixins.h" +#include "luci/IR/CircleNodeMixins.h" namespace luci { @@ -29,12 +29,9 @@ namespace luci * @brief Virtual CIRCLESPLITVOUT in Circle */ class CircleSplitVOut final - : public FixedArityNode<1, CircleNodeImpl<CircleOpcode::CIRCLESPLITVOUT>> + : public FixedArityNode<1, CircleNodeImpl<CircleOpcode::CIRCLESPLITVOUT>> { public: - CircleSplitVOut() = default; - -public: loco::Node *input(void) const { return at(0)->node(); } void input(loco::Node *node) { at(0)->node(node); } diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleSqrt.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleSqrt.h index c96ca8498..b76bd1ad5 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleSqrt.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleSqrt.h @@ -20,7 +20,7 @@ #include "luci/IR/CircleNodeDecl.h" #include "luci/IR/CircleOpcode.h" -#include "luci/IR/LuciNodeMixins.h" +#include "luci/IR/CircleNodeMixins.h" namespace luci { @@ -31,9 +31,6 @@ namespace luci class CircleSqrt final : public FixedArityNode<1, CircleNodeImpl<CircleOpcode::SQRT>> { public: - CircleSqrt() = default; - -public: loco::Node *x(void) const { return at(0)->node(); } void x(loco::Node *node) { at(0)->node(node); } }; diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleSquare.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleSquare.h index a29edfe82..3f9228b3b 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleSquare.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleSquare.h @@ -20,7 +20,7 @@ #include "luci/IR/CircleNodeDecl.h" #include "luci/IR/CircleOpcode.h" -#include "luci/IR/LuciNodeMixins.h" +#include "luci/IR/CircleNodeMixins.h" namespace luci { @@ -31,9 +31,6 @@ namespace luci class CircleSquare final : public FixedArityNode<1, CircleNodeImpl<CircleOpcode::SQUARE>> { public: - CircleSquare() = default; - -public: loco::Node *x(void) const { return at(0)->node(); } void x(loco::Node *node) { at(0)->node(node); } }; diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleSquaredDifference.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleSquaredDifference.h index b5b39f920..355c9f3d3 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleSquaredDifference.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleSquaredDifference.h @@ -20,7 +20,7 @@ #include "luci/IR/CircleNodeDecl.h" #include "luci/IR/CircleOpcode.h" -#include "luci/IR/LuciNodeMixins.h" +#include "luci/IR/CircleNodeMixins.h" namespace luci { @@ -29,12 +29,9 @@ namespace luci * @brief SQUARED_DIFFERENCE in Circle */ class CircleSquaredDifference final - : public FixedArityNode<2, CircleNodeImpl<CircleOpcode::SQUARED_DIFFERENCE>> + : public FixedArityNode<2, CircleNodeImpl<CircleOpcode::SQUARED_DIFFERENCE>> { public: - CircleSquaredDifference() = default; - -public: loco::Node *x(void) const { return at(0)->node(); } void x(loco::Node *node) { at(0)->node(node); } diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleSqueeze.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleSqueeze.h index f175f1411..ba71ff217 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleSqueeze.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleSqueeze.h @@ -20,7 +20,7 @@ #include "luci/IR/CircleNodeDecl.h" #include "luci/IR/CircleOpcode.h" -#include "luci/IR/LuciNodeMixins.h" +#include "luci/IR/CircleNodeMixins.h" namespace luci { @@ -31,9 +31,6 @@ namespace luci class CircleSqueeze final : public FixedArityNode<1, CircleNodeImpl<CircleOpcode::SQUEEZE>> { public: - CircleSqueeze() = default; - -public: loco::Node *input(void) const { return at(0)->node(); } void input(loco::Node *node) { at(0)->node(node); } diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleStridedSlice.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleStridedSlice.h index 98799fec1..6a4155ef1 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleStridedSlice.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleStridedSlice.h @@ -20,7 +20,7 @@ #include "luci/IR/CircleNodeDecl.h" #include "luci/IR/CircleOpcode.h" -#include "luci/IR/LuciNodeMixins.h" +#include "luci/IR/CircleNodeMixins.h" namespace luci { @@ -29,7 +29,7 @@ namespace luci * @brief STRIDED_SLICE in Circle */ class CircleStridedSlice final - : public FixedArityNode<4, CircleNodeImpl<CircleOpcode::STRIDED_SLICE>> + : public FixedArityNode<4, CircleNodeImpl<CircleOpcode::STRIDED_SLICE>> { public: loco::Node *input(void) const { return at(0)->node(); } diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleSub.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleSub.h index 08208f942..d9aaa44e5 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleSub.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleSub.h @@ -21,7 +21,7 @@ #include "luci/IR/CircleOpcode.h" #include "luci/IR/AttrFusedActFunc.h" -#include "luci/IR/LuciNodeMixins.h" +#include "luci/IR/CircleNodeMixins.h" namespace luci { @@ -30,12 +30,9 @@ namespace luci * @brief SUB in Circle */ class CircleSub final : public FixedArityNode<2, CircleNodeImpl<CircleOpcode::SUB>>, - public LuciNodeMixin<LuciNodeTrait::FusedActFunc> + public CircleNodeMixin<CircleNodeTrait::FusedActFunc> { public: - CircleSub() = default; - -public: loco::Node *x(void) const { return at(0)->node(); } void x(loco::Node *node) { at(0)->node(node); } diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleSum.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleSum.h index 21faa76fe..a72e18f54 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleSum.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleSum.h @@ -20,7 +20,7 @@ #include "luci/IR/CircleNodeDecl.h" #include "luci/IR/CircleOpcode.h" -#include "luci/IR/LuciNodeMixins.h" +#include "luci/IR/CircleNodeMixins.h" namespace luci { diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleTanh.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleTanh.h index f7444921f..2036a7301 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleTanh.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleTanh.h @@ -20,7 +20,7 @@ #include "luci/IR/CircleNodeDecl.h" #include "luci/IR/CircleOpcode.h" -#include "luci/IR/LuciNodeMixins.h" +#include "luci/IR/CircleNodeMixins.h" namespace luci { @@ -31,9 +31,6 @@ namespace luci class CircleTanh final : public FixedArityNode<1, CircleNodeImpl<CircleOpcode::TANH>> { public: - CircleTanh() = default; - -public: loco::Node *x(void) const { return at(0)->node(); } void x(loco::Node *node) { at(0)->node(node); } }; diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleTile.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleTile.h index 96e1f69c6..1ec2f5e82 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleTile.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleTile.h @@ -20,7 +20,7 @@ #include "luci/IR/CircleNodeDecl.h" #include "luci/IR/CircleOpcode.h" -#include "luci/IR/LuciNodeMixins.h" +#include "luci/IR/CircleNodeMixins.h" namespace luci { @@ -31,9 +31,6 @@ namespace luci class CircleTile final : public FixedArityNode<2, CircleNodeImpl<CircleOpcode::TILE>> { public: - CircleTile() = default; - -public: loco::Node *input(void) const { return at(0)->node(); } void input(loco::Node *node) { at(0)->node(node); } diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleTopKV2.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleTopKV2.h index 3b2b5abb7..0bf78c3ee 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleTopKV2.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleTopKV2.h @@ -20,7 +20,7 @@ #include "luci/IR/CircleNodeDecl.h" #include "luci/IR/CircleOpcode.h" -#include "luci/IR/LuciNodeMixins.h" +#include "luci/IR/CircleNodeMixins.h" namespace luci { @@ -31,9 +31,6 @@ namespace luci class CircleTopKV2 final : public FixedArityNode<2, CircleNodeImpl<CircleOpcode::TOPK_V2>> { public: - CircleTopKV2() = default; - -public: loco::Node *input(void) const { return at(0)->node(); } void input(loco::Node *node) { at(0)->node(node); } diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleTopKV2Out.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleTopKV2Out.h index 5a6dd0c02..f1a6b4a41 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleTopKV2Out.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleTopKV2Out.h @@ -20,7 +20,7 @@ #include "luci/IR/CircleNodeDecl.h" #include "luci/IR/CircleOpcode.h" -#include "luci/IR/LuciNodeMixins.h" +#include "luci/IR/CircleNodeMixins.h" namespace luci { @@ -29,12 +29,9 @@ namespace luci * @brief Virtual CIRCLETOPKV2OUT in Circle */ class CircleTopKV2Out final - : public FixedArityNode<1, CircleNodeImpl<CircleOpcode::CIRCLETOPKV2OUT>> + : public FixedArityNode<1, CircleNodeImpl<CircleOpcode::CIRCLETOPKV2OUT>> { public: - CircleTopKV2Out() = default; - -public: loco::Node *input(void) const { return at(0)->node(); } void input(loco::Node *node) { at(0)->node(node); } diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleTranspose.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleTranspose.h index 095cd6746..72ce0738c 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleTranspose.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleTranspose.h @@ -20,7 +20,7 @@ #include "luci/IR/CircleNodeDecl.h" #include "luci/IR/CircleOpcode.h" -#include "luci/IR/LuciNodeMixins.h" +#include "luci/IR/CircleNodeMixins.h" namespace luci { @@ -31,13 +31,7 @@ namespace luci class CircleTranspose final : public FixedArityNode<2, CircleNodeImpl<CircleOpcode::TRANSPOSE>> { public: - CircleTranspose() = default; - -public: - /// @brief Get the input node to transpose loco::Node *a(void) const { return at(0)->node(); } - - /// @brief Set the input node to transpose void a(loco::Node *node) { at(0)->node(node); } loco::Node *perm(void) const { return at(1)->node(); } diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleTransposeConv.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleTransposeConv.h index e355102d6..5ae41c0c4 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleTransposeConv.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleTransposeConv.h @@ -22,7 +22,7 @@ #include "luci/IR/AttrPadding.h" #include "luci/IR/AttrStride.h" -#include "luci/IR/LuciNodeMixins.h" +#include "luci/IR/CircleNodeMixins.h" namespace luci { @@ -34,8 +34,8 @@ namespace luci * 'out' acutally means 'out' and 'in' of the this node. */ class CircleTransposeConv final - : public FixedArityNode<4, CircleNodeImpl<CircleOpcode::TRANSPOSE_CONV>>, - public LuciNodeMixin<LuciNodeTrait::Bias> + : public FixedArityNode<4, CircleNodeImpl<CircleOpcode::TRANSPOSE_CONV>>, + public CircleNodeMixin<CircleNodeTrait::Bias> { public: loco::Node *inputSizes(void) const { return at(0)->node(); } diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleUnidirectionalSequenceLSTM.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleUnidirectionalSequenceLSTM.h index 4352b045b..faf0ec94d 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleUnidirectionalSequenceLSTM.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleUnidirectionalSequenceLSTM.h @@ -21,7 +21,7 @@ #include "luci/IR/CircleOpcode.h" #include "luci/IR/AttrFusedActFunc.h" -#include "luci/IR/LuciNodeMixins.h" +#include "luci/IR/CircleNodeMixins.h" namespace luci { @@ -30,8 +30,8 @@ namespace luci * @brief UNIDIRECTIONAL_SEQUENCE_LSTM in Circle */ class CircleUnidirectionalSequenceLSTM final - : public FixedArityNode<24, CircleNodeImpl<CircleOpcode::UNIDIRECTIONAL_SEQUENCE_LSTM>>, - public LuciNodeMixin<LuciNodeTrait::FusedActFunc> + : public FixedArityNode<24, CircleNodeImpl<CircleOpcode::UNIDIRECTIONAL_SEQUENCE_LSTM>>, + public CircleNodeMixin<CircleNodeTrait::FusedActFunc> { public: loco::Node *input(void) const { return at(0)->node(); } @@ -104,10 +104,10 @@ public: } private: - float _cell_clip = 0.0f; - float _proj_clip = 0.0f; - bool _time_major = false; - bool _asymmetric_quantize_inputs = false; + float _cell_clip{0.0f}; + float _proj_clip{0.0f}; + bool _time_major{false}; + bool _asymmetric_quantize_inputs{false}; }; } // namespace luci diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleUnique.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleUnique.h index 719a72362..2dd48b2f9 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleUnique.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleUnique.h @@ -20,7 +20,7 @@ #include "luci/IR/CircleNodeDecl.h" #include "luci/IR/CircleOpcode.h" -#include "luci/IR/LuciNodeMixins.h" +#include "luci/IR/CircleNodeMixins.h" namespace luci { @@ -36,7 +36,7 @@ public: public: loco::DataType idx_out_type(void) const { return _idx_out_type; } - void output_type(loco::DataType ot) { _idx_out_type = ot; } + void idx_out_type(loco::DataType ot) { _idx_out_type = ot; } private: loco::DataType _idx_out_type{loco::DataType::S32}; diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleUniqueOut.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleUniqueOut.h index f846403e0..233351860 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleUniqueOut.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleUniqueOut.h @@ -20,7 +20,7 @@ #include "luci/IR/CircleNodeDecl.h" #include "luci/IR/CircleOpcode.h" -#include "luci/IR/LuciNodeMixins.h" +#include "luci/IR/CircleNodeMixins.h" namespace luci { @@ -29,12 +29,9 @@ namespace luci * @brief Virtual CIRCLEUNIQUEOUT in Circle */ class CircleUniqueOut final - : public FixedArityNode<1, CircleNodeImpl<CircleOpcode::CIRCLEUNIQUEOUT>> + : public FixedArityNode<1, CircleNodeImpl<CircleOpcode::CIRCLEUNIQUEOUT>> { public: - CircleUniqueOut() = default; - -public: loco::Node *input(void) const { return at(0)->node(); } void input(loco::Node *node) { at(0)->node(node); } diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleUnpack.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleUnpack.h index cb91d7e6a..fd0c66ce0 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleUnpack.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleUnpack.h @@ -20,7 +20,7 @@ #include "luci/IR/CircleNodeDecl.h" #include "luci/IR/CircleOpcode.h" -#include "luci/IR/LuciNodeMixins.h" +#include "luci/IR/CircleNodeMixins.h" namespace luci { @@ -31,9 +31,6 @@ namespace luci class CircleUnpack final : public FixedArityNode<1, CircleNodeImpl<CircleOpcode::UNPACK>> { public: - CircleUnpack() = default; - -public: loco::Node *value(void) const { return at(0)->node(); } void value(loco::Node *node) { at(0)->node(node); } diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleUnpackOut.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleUnpackOut.h index 6f24578a1..640d2f1bb 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleUnpackOut.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleUnpackOut.h @@ -20,7 +20,7 @@ #include "luci/IR/CircleNodeDecl.h" #include "luci/IR/CircleOpcode.h" -#include "luci/IR/LuciNodeMixins.h" +#include "luci/IR/CircleNodeMixins.h" namespace luci { @@ -29,12 +29,9 @@ namespace luci * @brief Virtual CIRCLEUNPACKOUT in Circle */ class CircleUnpackOut final - : public FixedArityNode<1, CircleNodeImpl<CircleOpcode::CIRCLEUNPACKOUT>> + : public FixedArityNode<1, CircleNodeImpl<CircleOpcode::CIRCLEUNPACKOUT>> { public: - CircleUnpackOut() = default; - -public: loco::Node *input(void) const { return at(0)->node(); } void input(loco::Node *node) { at(0)->node(node); } diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleWhere.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleWhere.h index 51eda3d6e..8895bcbbd 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleWhere.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleWhere.h @@ -20,7 +20,7 @@ #include "luci/IR/CircleNodeDecl.h" #include "luci/IR/CircleOpcode.h" -#include "luci/IR/LuciNodeMixins.h" +#include "luci/IR/CircleNodeMixins.h" #include <cassert> @@ -33,9 +33,6 @@ namespace luci class CircleWhere final : public FixedArityNode<1, CircleNodeImpl<CircleOpcode::WHERE>> { public: - CircleWhere() = default; - -public: loco::Node *condition() const { return at(0)->node(); } void condition(loco::Node *node) { at(0)->node(node); } }; diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleWhile.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleWhile.h index 40ec96414..f4154d3ab 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleWhile.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleWhile.h @@ -34,7 +34,7 @@ class CircleWhile final : public VariadicArityNode<CircleNodeImpl<CircleOpcode:: { public: CircleWhile(uint32_t arity, uint32_t out) - : VariadicArityNode<CircleNodeImpl<CircleOpcode::WHILE>>(arity), _output_count(out) + : VariadicArityNode<CircleNodeImpl<CircleOpcode::WHILE>>(arity), _output_count(out) { assert(arity > 0); assert(out > 0); diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleWhileOut.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleWhileOut.h index cdf617848..98efc21e5 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleWhileOut.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleWhileOut.h @@ -20,7 +20,7 @@ #include "luci/IR/CircleNodeDecl.h" #include "luci/IR/CircleOpcode.h" -#include "luci/IR/LuciNodeMixins.h" +#include "luci/IR/CircleNodeMixins.h" namespace luci { @@ -31,9 +31,6 @@ namespace luci class CircleWhileOut final : public FixedArityNode<1, CircleNodeImpl<CircleOpcode::CIRCLEWHILEOUT>> { public: - CircleWhileOut() = default; - -public: loco::Node *input(void) const { return at(0)->node(); } void input(loco::Node *node) { at(0)->node(node); } diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleZerosLike.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleZerosLike.h index d3b6d272a..9302facd0 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleZerosLike.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleZerosLike.h @@ -20,7 +20,7 @@ #include "luci/IR/CircleNodeDecl.h" #include "luci/IR/CircleOpcode.h" -#include "luci/IR/LuciNodeMixins.h" +#include "luci/IR/CircleNodeMixins.h" namespace luci { @@ -31,13 +31,7 @@ namespace luci class CircleZerosLike final : public FixedArityNode<1, CircleNodeImpl<CircleOpcode::ZEROS_LIKE>> { public: - CircleZerosLike() = default; - -public: - /// @brief Get the input node loco::Node *input(void) const { return at(0)->node(); } - - /// @brief Set the input node void input(loco::Node *node) { at(0)->node(node); } }; diff --git a/compiler/luci/lang/include/luci/IR/SparsityParam.h b/compiler/luci/lang/include/luci/IR/SparsityParam.h index f471e5ef9..6cfff67e1 100644 --- a/compiler/luci/lang/include/luci/IR/SparsityParam.h +++ b/compiler/luci/lang/include/luci/IR/SparsityParam.h @@ -44,7 +44,7 @@ class SparseIndexVector public: SparseIndexVector() = default; SparseIndexVector(const SparseIndexVectorType &type, const std::vector<int32_t> &sparse_index_vec) - : _type{type} + : _type{type} { switch (type) { @@ -53,7 +53,7 @@ public: case SparseIndexVectorType::I32: { _vec_ptr = static_cast<void *>( - new std::vector<int32_t>(sparse_index_vec.begin(), sparse_index_vec.end())); + new std::vector<int32_t>(sparse_index_vec.begin(), sparse_index_vec.end())); break; } case SparseIndexVectorType::U16: @@ -90,21 +90,21 @@ public: case SparseIndexVectorType::I32: { const std::vector<int32_t> *vec = - static_cast<const std::vector<int32_t> *>(sparse_index_vec); + static_cast<const std::vector<int32_t> *>(sparse_index_vec); _vec_ptr = static_cast<void *>(new std::vector<int32_t>(vec->begin(), vec->end())); break; } case SparseIndexVectorType::U16: { const std::vector<uint16_t> *vec = - static_cast<const std::vector<uint16_t> *>(sparse_index_vec); + static_cast<const std::vector<uint16_t> *>(sparse_index_vec); _vec_ptr = static_cast<void *>(new std::vector<uint16_t>(vec->begin(), vec->end())); break; } case SparseIndexVectorType::U8: { const std::vector<uint8_t> *vec = - static_cast<const std::vector<uint8_t> *>(sparse_index_vec); + static_cast<const std::vector<uint8_t> *>(sparse_index_vec); _vec_ptr = static_cast<void *>(new std::vector<uint8_t>(vec->begin(), vec->end())); break; } @@ -114,12 +114,12 @@ public: } SparseIndexVector(const SparseIndexVector &sparse_index_vec) - : SparseIndexVector(sparse_index_vec._type, sparse_index_vec._vec_ptr) + : SparseIndexVector(sparse_index_vec._type, sparse_index_vec._vec_ptr) { } SparseIndexVector(SparseIndexVector &&sparse_index_vec) - : _type{sparse_index_vec._type}, _vec_ptr{std::exchange(sparse_index_vec._vec_ptr, nullptr)} + : _type{sparse_index_vec._type}, _vec_ptr{std::exchange(sparse_index_vec._vec_ptr, nullptr)} { } @@ -178,8 +178,8 @@ public: const std::vector<uint16_t> *as_uint16_vector(void) const { return _type == SparseIndexVectorType::U16 - ? static_cast<const std::vector<uint16_t> *>(_vec_ptr) - : nullptr; + ? static_cast<const std::vector<uint16_t> *>(_vec_ptr) + : nullptr; } const std::vector<uint8_t> *as_uint8_vector(void) const { @@ -202,8 +202,8 @@ public: } DimMetaData(DimensionType format, int32_t dense_size, const SparseIndexVector &array_segments, const SparseIndexVector &array_indices) - : _format{format}, _dense_size{dense_size}, _array_segments{array_segments}, - _array_indices{array_indices} + : _format{format}, _dense_size{dense_size}, _array_segments{array_segments}, _array_indices{ + array_indices} { // DO NOTHING } diff --git a/compiler/luci/lang/src/CircleDialect.cpp b/compiler/luci/lang/src/CircleDialect.cpp index 42ca3c917..0d315fc55 100644 --- a/compiler/luci/lang/src/CircleDialect.cpp +++ b/compiler/luci/lang/src/CircleDialect.cpp @@ -15,6 +15,7 @@ */ #include "luci/IR/CircleDialect.h" +#include "luci/IR/DeadNodeQueryService.h" #include "luci/IR/Nodes/CircleInput.h" #include "luci/IR/Nodes/CircleOutput.h" @@ -22,8 +23,6 @@ #include <loco/IR/GraphInputIndex.h> #include <loco/IR/GraphOutputIndex.h> -#include "DeadNodeQueryService.h" - #include <cassert> #include <memory> diff --git a/compiler/luci/lang/src/LuciNodeMixins.cpp b/compiler/luci/lang/src/CircleNodeMixins.cpp index 660cbe1a5..f72178df5 100644 --- a/compiler/luci/lang/src/LuciNodeMixins.cpp +++ b/compiler/luci/lang/src/CircleNodeMixins.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,5 +14,5 @@ * limitations under the License. */ -// This is to validate LuciNodeMixins.h -#include "luci/IR/LuciNodeMixins.h" +// This is to validate CircleNodeMixins.h +#include "luci/IR/CircleNodeMixins.h" diff --git a/compiler/luci/lang/src/CircleNodes.cpp b/compiler/luci/lang/src/CircleNodes.cpp index c77c06861..2c2688c9e 100644 --- a/compiler/luci/lang/src/CircleNodes.cpp +++ b/compiler/luci/lang/src/CircleNodes.cpp @@ -23,31 +23,6 @@ namespace luci { -void set_new_shape(CircleReshape *node, int32_t *base, uint32_t size) -{ - // Check node does not have both of new shape infos - LUCI_ASSERT(node->shape() == nullptr, "node already has shape input"); - LUCI_ASSERT(node->newShape()->rank() == 0, "node already has newShape attribute"); - - const loco::DataType S32 = loco::DataType::S32; - - // Set 2nd input as CircleConst - auto const_shape_node = node->graph()->nodes()->create<CircleConst>(); - const_shape_node->rank(1); - const_shape_node->dim(0) = size; - const_shape_node->dtype(S32); - const_shape_node->size<S32>(size); - const_shape_node->shape_status(luci::ShapeStatus::VALID); - for (uint32_t axis = 0; axis < size; ++axis) - const_shape_node->at<S32>(axis) = base[axis]; - node->shape(const_shape_node); - - // Set newShape attribute - node->newShape()->rank(size); - for (uint32_t axis = 0; axis < size; ++axis) - node->newShape()->dim(axis) = base[axis]; -} - void link(loco::GraphOutput *output, CircleOutput *node) { node->index(output->index()); } CircleOutput *output_node(loco::Graph *g, const loco::GraphOutputIndex &index) diff --git a/compiler/luci/lang/src/DeadNodeQueryService.cpp b/compiler/luci/lang/src/DeadNodeQueryService.cpp index a22574c94..7dac08b5f 100644 --- a/compiler/luci/lang/src/DeadNodeQueryService.cpp +++ b/compiler/luci/lang/src/DeadNodeQueryService.cpp @@ -14,9 +14,8 @@ * limitations under the License. */ -#include "DeadNodeQueryService.h" - #include "luci/IR/CircleNodeVisitor.h" +#include "luci/IR/DeadNodeQueryService.h" #include <loco/IR/Graph.h> diff --git a/compiler/luci/lang/src/Nodes/CircleBatchMatMul.test.cpp b/compiler/luci/lang/src/Nodes/CircleBatchMatMul.test.cpp index d7712c8dd..3859d7fca 100644 --- a/compiler/luci/lang/src/Nodes/CircleBatchMatMul.test.cpp +++ b/compiler/luci/lang/src/Nodes/CircleBatchMatMul.test.cpp @@ -26,7 +26,7 @@ TEST(CircleBatchMatMulTest, constructor) luci::CircleBatchMatMul batchmatmul_node; ASSERT_EQ(luci::CircleDialect::get(), batchmatmul_node.dialect()); - ASSERT_EQ(luci::CircleOpcode::BATCHMATMUL, batchmatmul_node.opcode()); + ASSERT_EQ(luci::CircleOpcode::BATCH_MATMUL, batchmatmul_node.opcode()); ASSERT_EQ(nullptr, batchmatmul_node.x()); ASSERT_EQ(nullptr, batchmatmul_node.y()); diff --git a/compiler/luci/lang/src/Nodes/CircleBidrectionalSequenceLSTM.test.cpp b/compiler/luci/lang/src/Nodes/CircleBidrectionalSequenceLSTM.test.cpp new file mode 100644 index 000000000..3f13422e5 --- /dev/null +++ b/compiler/luci/lang/src/Nodes/CircleBidrectionalSequenceLSTM.test.cpp @@ -0,0 +1,130 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/IR/Nodes/CircleBidirectionalSequenceLSTM.h" + +#include "luci/IR/CircleDialect.h" +#include "luci/IR/CircleNodeVisitor.h" + +#include <gtest/gtest.h> + +TEST(CircleBidirectionalSequenceLSTMTest, constructor_P) +{ + luci::CircleBidirectionalSequenceLSTM trc_node; + + ASSERT_EQ(luci::CircleDialect::get(), trc_node.dialect()); + ASSERT_EQ(luci::CircleOpcode::BIDIRECTIONAL_SEQUENCE_LSTM, trc_node.opcode()); + + ASSERT_EQ(nullptr, trc_node.input()); + + ASSERT_EQ(nullptr, trc_node.fw_input_to_input_weights()); + ASSERT_EQ(nullptr, trc_node.fw_input_to_forget_weights()); + ASSERT_EQ(nullptr, trc_node.fw_input_to_cell_weights()); + ASSERT_EQ(nullptr, trc_node.fw_input_to_output_weights()); + + ASSERT_EQ(nullptr, trc_node.fw_recurrent_to_input_weights()); + ASSERT_EQ(nullptr, trc_node.fw_recurrent_to_forget_weights()); + ASSERT_EQ(nullptr, trc_node.fw_recurrent_to_cell_weights()); + ASSERT_EQ(nullptr, trc_node.fw_recurrent_to_output_weights()); + + ASSERT_EQ(nullptr, trc_node.fw_cell_to_input_weights()); + ASSERT_EQ(nullptr, trc_node.fw_cell_to_forget_weights()); + ASSERT_EQ(nullptr, trc_node.fw_cell_to_output_weights()); + + ASSERT_EQ(nullptr, trc_node.fw_input_gate_bias()); + ASSERT_EQ(nullptr, trc_node.fw_forget_gate_bias()); + ASSERT_EQ(nullptr, trc_node.fw_cell_gate_bias()); + ASSERT_EQ(nullptr, trc_node.fw_output_gate_bias()); + + ASSERT_EQ(nullptr, trc_node.fw_projection_weights()); + ASSERT_EQ(nullptr, trc_node.fw_projection_bias()); + + ASSERT_EQ(nullptr, trc_node.bw_input_to_input_weights()); + ASSERT_EQ(nullptr, trc_node.bw_input_to_forget_weights()); + ASSERT_EQ(nullptr, trc_node.bw_input_to_cell_weights()); + ASSERT_EQ(nullptr, trc_node.bw_input_to_output_weights()); + + ASSERT_EQ(nullptr, trc_node.bw_recurrent_to_input_weights()); + ASSERT_EQ(nullptr, trc_node.bw_recurrent_to_forget_weights()); + ASSERT_EQ(nullptr, trc_node.bw_recurrent_to_cell_weights()); + ASSERT_EQ(nullptr, trc_node.bw_recurrent_to_output_weights()); + + ASSERT_EQ(nullptr, trc_node.bw_cell_to_input_weights()); + ASSERT_EQ(nullptr, trc_node.bw_cell_to_forget_weights()); + ASSERT_EQ(nullptr, trc_node.bw_cell_to_output_weights()); + + ASSERT_EQ(nullptr, trc_node.bw_input_gate_bias()); + ASSERT_EQ(nullptr, trc_node.bw_forget_gate_bias()); + ASSERT_EQ(nullptr, trc_node.bw_cell_gate_bias()); + ASSERT_EQ(nullptr, trc_node.bw_output_gate_bias()); + + ASSERT_EQ(nullptr, trc_node.bw_projection_weights()); + ASSERT_EQ(nullptr, trc_node.bw_projection_bias()); + + ASSERT_EQ(nullptr, trc_node.fw_activation_state()); + ASSERT_EQ(nullptr, trc_node.fw_cell_state()); + ASSERT_EQ(nullptr, trc_node.bw_activation_state()); + ASSERT_EQ(nullptr, trc_node.bw_cell_state()); + + ASSERT_EQ(nullptr, trc_node.auxillary_input()); + ASSERT_EQ(nullptr, trc_node.fw_auxillary_input_to_input_weights()); + ASSERT_EQ(nullptr, trc_node.fw_auxillary_input_to_forget_weights()); + ASSERT_EQ(nullptr, trc_node.fw_auxillary_input_to_cell_weights()); + ASSERT_EQ(nullptr, trc_node.fw_auxillary_input_to_output_weights()); + ASSERT_EQ(nullptr, trc_node.bw_auxillary_input_to_input_weights()); + ASSERT_EQ(nullptr, trc_node.bw_auxillary_input_to_forget_weights()); + ASSERT_EQ(nullptr, trc_node.bw_auxillary_input_to_cell_weights()); + ASSERT_EQ(nullptr, trc_node.bw_auxillary_input_to_output_weights()); + + ASSERT_EQ(luci::FusedActFunc::UNDEFINED, trc_node.fusedActivationFunction()); + ASSERT_EQ(0.f, trc_node.cell_clip()); + ASSERT_EQ(0.f, trc_node.proj_clip()); + ASSERT_EQ(false, trc_node.merge_outputs()); + ASSERT_EQ(false, trc_node.time_major()); + ASSERT_EQ(false, trc_node.asymmetric_quantize_inputs()); +} + +TEST(CircleBidirectionalSequenceLSTMTest, arity_NEG) +{ + luci::CircleBidirectionalSequenceLSTM trc_node; + + ASSERT_NO_THROW(trc_node.arg(36)); + ASSERT_THROW(trc_node.arg(48), std::out_of_range); +} + +TEST(CircleBidirectionalSequenceLSTMTest, visit_mutable_NEG) +{ + struct TestVisitor final : public luci::CircleNodeMutableVisitor<void> + { + }; + + luci::CircleBidirectionalSequenceLSTM trc_node; + + TestVisitor tv; + ASSERT_THROW(trc_node.accept(&tv), std::exception); +} + +TEST(CircleBidirectionalSequenceLSTMTest, visit_NEG) +{ + struct TestVisitor final : public luci::CircleNodeVisitor<void> + { + }; + + luci::CircleBidirectionalSequenceLSTM trc_node; + + TestVisitor tv; + ASSERT_THROW(trc_node.accept(&tv), std::exception); +} diff --git a/compiler/luci/lang/src/Nodes/CircleConst.test.cpp b/compiler/luci/lang/src/Nodes/CircleConst.test.cpp new file mode 100644 index 000000000..a81f4b00d --- /dev/null +++ b/compiler/luci/lang/src/Nodes/CircleConst.test.cpp @@ -0,0 +1,53 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/IR/Nodes/CircleConst.h" + +#include "luci/IR/CircleDialect.h" +#include "luci/IR/CircleNodeVisitor.h" + +#include <gtest/gtest.h> + +TEST(CircleConstTest, constructor) +{ + luci::CircleConst const_node; + + ASSERT_EQ(luci::CircleDialect::get(), const_node.dialect()); + ASSERT_EQ(luci::CircleOpcode::CIRCLECONST, const_node.opcode()); +} + +TEST(CircleConstTest, dype_size) +{ + luci::CircleConst const_node; + + const_node.dtype(loco::DataType::S32); + const_node.size<loco::DataType::S32>(1); + + ASSERT_EQ(loco::DataType::S32, const_node.dtype()); + ASSERT_EQ(1, const_node.size<loco::DataType::S32>()); +} + +TEST(CircleConstTest, scalar) +{ + luci::CircleConst const_node; + + const_node.dtype(loco::DataType::S32); + const_node.size<loco::DataType::S32>(1); + const_node.scalar<loco::DataType::S32>() = 1; + + auto const &cs = const_node.scalar<loco::DataType::S32>(); + ASSERT_EQ(1, cs); +} diff --git a/compiler/luci/lang/src/Nodes/CircleCustom.test.cpp b/compiler/luci/lang/src/Nodes/CircleCustom.test.cpp index c07268cbf..76b70f38b 100644 --- a/compiler/luci/lang/src/Nodes/CircleCustom.test.cpp +++ b/compiler/luci/lang/src/Nodes/CircleCustom.test.cpp @@ -22,7 +22,7 @@ TEST(CircleCustomTest, constructor) { - luci::CircleCustom custom_node(2); + luci::CircleCustom custom_node(2, 1); ASSERT_EQ(luci::CircleDialect::get(), custom_node.dialect()); ASSERT_EQ(luci::CircleOpcode::CUSTOM, custom_node.opcode()); @@ -33,18 +33,19 @@ TEST(CircleCustomTest, constructor) ASSERT_EQ(2, custom_node.numInputs()); ASSERT_EQ(0, custom_node.custom_code().size()); + ASSERT_EQ(1, custom_node.numOutputs()); } TEST(CircleCustomTest, constructor_NEG) { - ASSERT_DEBUG_DEATH(luci::CircleCustom{0}, ""); + ASSERT_DEBUG_DEATH(luci::CircleCustom(0, 0), ""); SUCCEED(); } TEST(CircleCustomTest, invalidIndex_NEG) { - luci::CircleCustom custom_node(2); + luci::CircleCustom custom_node(2, 1); EXPECT_ANY_THROW(custom_node.arg(5)); } diff --git a/compiler/luci/lang/src/Nodes/CircleFakeQuant.test.cpp b/compiler/luci/lang/src/Nodes/CircleFakeQuant.test.cpp new file mode 100644 index 000000000..912e40570 --- /dev/null +++ b/compiler/luci/lang/src/Nodes/CircleFakeQuant.test.cpp @@ -0,0 +1,36 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/IR/Nodes/CircleFakeQuant.h" + +#include "luci/IR/CircleDialect.h" +#include "luci/IR/CircleNodeVisitor.h" + +#include <gtest/gtest.h> + +TEST(CircleFakeQuantTest, constructor_P) +{ + luci::CircleFakeQuant fakequant; + + ASSERT_EQ(fakequant.dialect(), luci::CircleDialect::get()); + ASSERT_EQ(fakequant.opcode(), luci::CircleOpcode::FAKE_QUANT); + + ASSERT_EQ(nullptr, fakequant.inputs()); + ASSERT_EQ(0.0f, fakequant.min()); + ASSERT_EQ(0.0f, fakequant.max()); + ASSERT_EQ(0, fakequant.num_bits()); + ASSERT_FALSE(fakequant.narrow_range()); +} diff --git a/compiler/luci/logex/src/FormattedGraph.cpp b/compiler/luci/logex/src/FormattedGraph.cpp index b2b9cb72b..f1337e3e6 100644 --- a/compiler/luci/logex/src/FormattedGraph.cpp +++ b/compiler/luci/logex/src/FormattedGraph.cpp @@ -146,7 +146,9 @@ std::string circle_opname(uint32_t opnum) #define CIRCLE_NODE(OPCODE, CLASS) \ case luci::CircleOpcode::OPCODE: \ return prefix + #OPCODE; +#define CIRCLE_VNODE CIRCLE_NODE #include <luci/IR/CircleNodes.lst> +#undef CIRCLE_VNODE #undef CIRCLE_NODE default: break; @@ -175,7 +177,9 @@ protected: s.state(locop::NodeSummary::State::PartiallyKnown); \ return true; \ } +#define CIRCLE_VNODE CIRCLE_NODE #include <luci/IR/CircleNodes.lst> +#undef CIRCLE_VNODE #undef CIRCLE_NODE protected: @@ -205,6 +209,7 @@ private: IMPLEMENT(luci::CircleAveragePool2D) IMPLEMENT(luci::CircleBatchMatMul) IMPLEMENT(luci::CircleBatchToSpaceND) + IMPLEMENT(luci::CircleBidirectionalSequenceLSTM) IMPLEMENT(luci::CircleCast) IMPLEMENT(luci::CircleCeil) IMPLEMENT(luci::CircleConcatenation) @@ -219,6 +224,7 @@ private: IMPLEMENT(luci::CircleElu) IMPLEMENT(luci::CircleExp) IMPLEMENT(luci::CircleExpandDims) + IMPLEMENT(luci::CircleFakeQuant) IMPLEMENT(luci::CircleFill) IMPLEMENT(luci::CircleFloor) IMPLEMENT(luci::CircleFloorDiv) @@ -433,6 +439,96 @@ bool summary_node(const locop::SymbolTable *tbl, const luci::CircleBatchToSpaceN return true; } +bool summary_node(const locop::SymbolTable *tbl, const luci::CircleBidirectionalSequenceLSTM *node, + locop::NodeSummary &s) +{ + s.args().append("input", tbl->lookup(node->input())); + + s.args().append("fw_input_to_input_weights", tbl->lookup(node->fw_input_to_input_weights())); + s.args().append("fw_input_to_forget_weights", tbl->lookup(node->fw_input_to_forget_weights())); + s.args().append("fw_input_to_cell_weights", tbl->lookup(node->fw_input_to_cell_weights())); + s.args().append("fw_input_to_output_weights", tbl->lookup(node->fw_input_to_output_weights())); + + s.args().append("fw_recurrent_to_input_weights", + tbl->lookup(node->fw_recurrent_to_input_weights())); + s.args().append("fw_recurrent_to_forget_weights", + tbl->lookup(node->fw_recurrent_to_forget_weights())); + s.args().append("fw_recurrent_to_cell_weights", + tbl->lookup(node->fw_recurrent_to_cell_weights())); + s.args().append("fw_recurrent_to_output_weights", + tbl->lookup(node->fw_recurrent_to_output_weights())); + + s.args().append("fw_cell_to_input_weights", tbl->lookup(node->fw_cell_to_input_weights())); + s.args().append("fw_cell_to_forget_weights", tbl->lookup(node->fw_cell_to_forget_weights())); + s.args().append("fw_cell_to_output_weights", tbl->lookup(node->fw_cell_to_output_weights())); + + s.args().append("fw_input_gate_bias", tbl->lookup(node->fw_input_gate_bias())); + s.args().append("fw_forget_gate_bias", tbl->lookup(node->fw_forget_gate_bias())); + s.args().append("fw_cell_gate_bias", tbl->lookup(node->fw_cell_gate_bias())); + s.args().append("fw_output_gate_bias", tbl->lookup(node->fw_output_gate_bias())); + + s.args().append("fw_projection_weights", tbl->lookup(node->fw_projection_weights())); + s.args().append("fw_projection_bias", tbl->lookup(node->fw_projection_bias())); + + s.args().append("bw_input_to_input_weights", tbl->lookup(node->bw_input_to_input_weights())); + s.args().append("bw_input_to_forget_weights", tbl->lookup(node->bw_input_to_forget_weights())); + s.args().append("bw_input_to_cell_weights", tbl->lookup(node->bw_input_to_cell_weights())); + s.args().append("bw_input_to_output_weights", tbl->lookup(node->bw_input_to_output_weights())); + + s.args().append("bw_recurrent_to_input_weights", + tbl->lookup(node->bw_recurrent_to_input_weights())); + s.args().append("bw_recurrent_to_forget_weights", + tbl->lookup(node->bw_recurrent_to_forget_weights())); + s.args().append("bw_recurrent_to_cell_weights", + tbl->lookup(node->bw_recurrent_to_cell_weights())); + s.args().append("bw_recurrent_to_output_weights", + tbl->lookup(node->bw_recurrent_to_output_weights())); + + s.args().append("bw_cell_to_input_weights", tbl->lookup(node->bw_cell_to_input_weights())); + s.args().append("bw_cell_to_forget_weights", tbl->lookup(node->bw_cell_to_forget_weights())); + s.args().append("bw_cell_to_output_weights", tbl->lookup(node->bw_cell_to_output_weights())); + + s.args().append("bw_input_gate_bias", tbl->lookup(node->bw_input_gate_bias())); + s.args().append("bw_forget_gate_bias", tbl->lookup(node->bw_forget_gate_bias())); + s.args().append("bw_cell_gate_bias", tbl->lookup(node->bw_cell_gate_bias())); + s.args().append("bw_output_gate_bias", tbl->lookup(node->bw_output_gate_bias())); + + s.args().append("bw_projection_weights", tbl->lookup(node->bw_projection_weights())); + s.args().append("bw_projection_bias", tbl->lookup(node->bw_projection_bias())); + + s.args().append("fw_activation_state", tbl->lookup(node->fw_activation_state())); + s.args().append("fw_cell_state", tbl->lookup(node->fw_cell_state())); + s.args().append("bw_activation_state", tbl->lookup(node->bw_activation_state())); + s.args().append("bw_cell_state", tbl->lookup(node->bw_cell_state())); + + s.args().append("auxillary_input", tbl->lookup(node->auxillary_input())); + s.args().append("fw_auxillary_input_to_input_weights", + tbl->lookup(node->fw_auxillary_input_to_input_weights())); + s.args().append("fw_auxillary_input_to_forget_weights", + tbl->lookup(node->fw_auxillary_input_to_forget_weights())); + s.args().append("fw_auxillary_input_to_cell_weights", + tbl->lookup(node->fw_auxillary_input_to_cell_weights())); + s.args().append("fw_auxillary_input_to_output_weights", + tbl->lookup(node->fw_auxillary_input_to_output_weights())); + s.args().append("bw_auxillary_input_to_input_weights", + tbl->lookup(node->bw_auxillary_input_to_input_weights())); + s.args().append("bw_auxillary_input_to_forget_weights", + tbl->lookup(node->bw_auxillary_input_to_forget_weights())); + s.args().append("bw_auxillary_input_to_cell_weights", + tbl->lookup(node->bw_auxillary_input_to_cell_weights())); + s.args().append("bw_auxillary_input_to_output_weights", + tbl->lookup(node->bw_auxillary_input_to_output_weights())); + + s.args().append("cell_clip", to_str(node->cell_clip())); + s.args().append("proj_clip", to_str(node->proj_clip())); + s.args().append("merge_outputs", to_str(node->merge_outputs())); + s.args().append("time_major", to_str(node->time_major())); + s.args().append("asymmetric_quantize_inputs", to_str(node->asymmetric_quantize_inputs())); + + s.state(locop::NodeSummary::State::Complete); + return true; +} + bool summary_node(const locop::SymbolTable *tbl, const luci::CircleCast *node, locop::NodeSummary &s) { @@ -521,6 +617,18 @@ bool summary_node(const locop::SymbolTable *tbl, const luci::CircleExpandDims *n return true; } +bool summary_node(const locop::SymbolTable *tbl, const luci::CircleFakeQuant *node, + locop::NodeSummary &s) +{ + s.args().append("inputs", tbl->lookup(node->inputs())); + s.args().append("min", pepper::str(node->min())); + s.args().append("max", pepper::str(node->max())); + s.args().append("num_bits", pepper::str(node->num_bits())); + s.args().append("narrow_range", node->narrow_range() ? "true" : "false"); + s.state(locop::NodeSummary::State::Complete); + return true; +} + bool summary_node(const locop::SymbolTable *tbl, const luci::CircleFill *node, locop::NodeSummary &s) { @@ -1189,7 +1297,9 @@ bool CircleNodeSummaryBuilderBase::build(const loco::Node *node, locop::NodeSumm s.comments().append("Mem = " + ptr_to_str(node)); \ return summary(dynamic_cast<const CLASS *>(node), s); \ } +#define CIRCLE_VNODE CIRCLE_NODE #include <luci/IR/CircleNodes.lst> +#undef CIRCLE_VNODE #undef CIRCLE_NODE return false; @@ -1238,6 +1348,12 @@ bool CircleNodeSummaryBuilder::summary(const luci::CircleBatchToSpaceND *node, return summary_node(tbl(), node, s); } +bool CircleNodeSummaryBuilder::summary(const luci::CircleBidirectionalSequenceLSTM *node, + locop::NodeSummary &s) const +{ + return summary_node(tbl(), node, s); +} + bool CircleNodeSummaryBuilder::summary(const luci::CircleCast *node, locop::NodeSummary &s) const { return summary_node(tbl(), node, s); @@ -1314,6 +1430,17 @@ bool CircleNodeSummaryBuilder::summary(const luci::CircleExpandDims *node, return summary_node(tbl(), node, s); } +bool CircleNodeSummaryBuilder::summary(const luci::CircleFakeQuant *node, + locop::NodeSummary &s) const +{ + return summary_node(tbl(), node, s); +} + +bool CircleNodeSummaryBuilder::summary(const luci::CircleFill *node, locop::NodeSummary &s) const +{ + return summary_node(tbl(), node, s); +} + bool CircleNodeSummaryBuilder::summary(const luci::CircleFloor *node, locop::NodeSummary &s) const { return use_x(tbl(), node, s); @@ -1331,11 +1458,6 @@ bool CircleNodeSummaryBuilder::summary(const luci::CircleFloorMod *node, return use_xy(tbl(), node, s); } -bool CircleNodeSummaryBuilder::summary(const luci::CircleFill *node, locop::NodeSummary &s) const -{ - return summary_node(tbl(), node, s); -} - bool CircleNodeSummaryBuilder::summary(const luci::CircleFullyConnected *node, locop::NodeSummary &s) const { diff --git a/compiler/luci/partition/CMakeLists.txt b/compiler/luci/partition/CMakeLists.txt new file mode 100644 index 000000000..838642b6e --- /dev/null +++ b/compiler/luci/partition/CMakeLists.txt @@ -0,0 +1,29 @@ +file(GLOB_RECURSE SOURCES "src/*.cpp") +file(GLOB_RECURSE TESTS "src/*.test.cpp") +list(REMOVE_ITEM SOURCES ${TESTS}) + +add_library(luci_partition SHARED ${SOURCES}) +target_include_directories(luci_partition PRIVATE src) +target_include_directories(luci_partition PUBLIC include) +target_link_libraries(luci_partition PUBLIC luci_lang) +target_link_libraries(luci_partition PRIVATE luci_service) +target_link_libraries(luci_partition PRIVATE luci_log) +target_link_libraries(luci_partition PRIVATE luci_logex) +target_link_libraries(luci_partition PRIVATE mio_circle) +target_link_libraries(luci_partition PRIVATE nncc_common) +target_link_libraries(luci_partition PRIVATE oops) + +install(TARGETS luci_partition DESTINATION lib) + +if(NOT ENABLE_TEST) + return() +endif(NOT ENABLE_TEST) + +nnas_find_package(GTest REQUIRED) + +GTest_AddTest(luci_partition_test ${TESTS}) +target_include_directories(luci_partition_test PRIVATE src) +target_link_libraries(luci_partition_test luci_lang) +target_link_libraries(luci_partition_test luci_partition) +target_link_libraries(luci_partition_test luci_testhelper) +target_link_libraries(luci_partition_test luci_service) diff --git a/compiler/luci/partition/README.md b/compiler/luci/partition/README.md new file mode 100644 index 000000000..40a46bc56 --- /dev/null +++ b/compiler/luci/partition/README.md @@ -0,0 +1,4 @@ +# luci-partition + +`luci-partition` provides partition of a model to two or more sub models and +its connection configuration having same computational results. diff --git a/compiler/luci/partition/include/luci/Partition.h b/compiler/luci/partition/include/luci/Partition.h new file mode 100644 index 000000000..cf90e448b --- /dev/null +++ b/compiler/luci/partition/include/luci/Partition.h @@ -0,0 +1,71 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_PARTITION_H__ +#define __LUCI_PARTITION_H__ + +#include <luci/IR/Module.h> + +#include <memory> +#include <string> +#include <unordered_map> +#include <vector> + +namespace luci +{ + +/** + * @brief PartitionTable holds partition information + */ +struct PartitionTable +{ + std::vector<std::string> groups; + std::string default_group; + + // assign by opcode name: OPCODENAME=group + std::unordered_map<std::string /* OPCODENAME */, std::string /* group */> byopcodes; + + // TODO add assign by OP name +}; + +/** + * @brief PartedModule holds partitioned module and group name + */ +struct PartedModule +{ + std::unique_ptr<Module> module; + // group name used to partition this module + std::string group; + + // unique name(filename) of this module + std::string name; +}; + +struct PartedModules +{ + std::vector<PartedModule> pmodules; + + // TODO add connections ? +}; + +/** + * @brief Method to do paritioning from module and PartitionTable to produce PartedModules + */ +PartedModules apply(Module *module, const PartitionTable &partition); + +} // namespace luci + +#endif // __LUCI_PARTITION_H__ diff --git a/compiler/luci/partition/src/CircleOpCode.cpp b/compiler/luci/partition/src/CircleOpCode.cpp new file mode 100644 index 000000000..86694fa40 --- /dev/null +++ b/compiler/luci/partition/src/CircleOpCode.cpp @@ -0,0 +1,79 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleOpCode.h" + +#include <luci/IR/CircleNodes.h> +#include <luci/IR/CircleNodeVisitor.h> + +#include <mio/circle/schema_generated.h> + +namespace +{ + +using namespace luci; +using namespace circle; + +class QueryOpCode final : public CircleNodeVisitor<BuiltinOperator> +{ +public: +// NOTE only circle operator may have BuiltinOperator_XXX +#define CIRCLE_NODE(OPCODE, CIRCLE_CLASS) \ + BuiltinOperator visit(const CIRCLE_CLASS *) final { return BuiltinOperator_##OPCODE; } +#define CIRCLE_VNODE(OPCODE, CIRCLE_CLASS) + +#include "luci/IR/CircleNodes.lst" +#undef CIRCLE_VNODE +#undef CIRCLE_NODE + + // NOTE only builtin operators should be called (NOT virtual nodes) +}; + +class QueryCircleName final : public luci::CircleNodeVisitor<const char *> +{ +public: +// NOTE provide names for circle virtual nodes +#define CIRCLE_NODE(OPCODE, CIRCLE_CLASS) +#define CIRCLE_VNODE(OPCODE, CIRCLE_CLASS) \ + const char *visit(const CIRCLE_CLASS *) final { return #OPCODE; } + +#include "luci/IR/CircleNodes.lst" +#undef CIRCLE_VNODE +#undef CIRCLE_NODE + + // default is null + const char *visit(const luci::CircleNode *) final { return nullptr; } +}; + +} // namespace + +namespace luci +{ + +std::string opcode_name(const CircleNode *node) +{ + QueryCircleName qcn; + auto cname = node->accept(&qcn); + if (cname != nullptr) + return std::string(cname); + + QueryOpCode qoc; + auto opcode = node->accept(&qoc); + auto name = circle::EnumNameBuiltinOperator(opcode); + return std::string(name); +} + +} // namespace luci diff --git a/compiler/luci/lang/src/CircleShapeSignature.cpp b/compiler/luci/partition/src/CircleOpCode.h index 970000203..d17b09261 100644 --- a/compiler/luci/lang/src/CircleShapeSignature.cpp +++ b/compiler/luci/partition/src/CircleOpCode.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,21 +14,18 @@ * limitations under the License. */ -#include "luci/IR/CircleShapeSignature.h" +#ifndef __LUCI_PARTITION_CIRCLE_OP_CODE_H__ +#define __LUCI_PARTITION_CIRCLE_OP_CODE_H__ -namespace luci -{ +#include <luci/IR/CircleNode.h> -bool operator==(const ShapeSignature &lhs, const ShapeSignature &rhs) -{ - if (lhs.rank() != rhs.rank()) - return false; +#include <string> - for (uint32_t i = 0; i < lhs.rank(); ++i) - if (lhs.dim(i) != rhs.dim(i)) - return false; +namespace luci +{ - return true; -} +std::string opcode_name(const CircleNode *node); } // namespace luci + +#endif // __LUCI_PARTITION_CIRCLE_OP_CODE_H__ diff --git a/compiler/luci/partition/src/CircleOpCode.test.cpp b/compiler/luci/partition/src/CircleOpCode.test.cpp new file mode 100644 index 000000000..d2524a2ef --- /dev/null +++ b/compiler/luci/partition/src/CircleOpCode.test.cpp @@ -0,0 +1,31 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleOpCode.h" + +// NOTE any node will do for testing +#include <luci/IR/Nodes/CircleSqrt.h> + +#include <gtest/gtest.h> + +TEST(CircleOpCodeTest, name) +{ + auto g = loco::make_graph(); + auto node = g->nodes()->create<luci::CircleSqrt>(); + + auto name = luci::opcode_name(node); + ASSERT_EQ(name, "SQRT"); +} diff --git a/compiler/luci/partition/src/ConnectNode.cpp b/compiler/luci/partition/src/ConnectNode.cpp new file mode 100644 index 000000000..336be7c57 --- /dev/null +++ b/compiler/luci/partition/src/ConnectNode.cpp @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include <oops/UserExn.h> + +namespace luci +{ + +void clone_connect(const luci::CircleNode *node, luci::CloneContext &clonecontext) +{ + ConnectNode cn(clonecontext); + node->accept(&cn); +} + +luci::CircleNode *ConnectNode::find_clone(const luci::CircleNode *node) +{ + auto it = _clonecontext.find(node); + if (it == _clonecontext.end()) + throw oops::UserExn("Invalid node in ConnectNode"); + return it->second; +} + +} // namespace luci diff --git a/compiler/luci/partition/src/ConnectNode.h b/compiler/luci/partition/src/ConnectNode.h new file mode 100644 index 000000000..017c587e5 --- /dev/null +++ b/compiler/luci/partition/src/ConnectNode.h @@ -0,0 +1,209 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_PARTITION_CONNECT_NODE_H__ +#define __LUCI_PARTITION_CONNECT_NODE_H__ + +#include <luci/IR/CircleNode.h> +#include <luci/IR/CircleNodeVisitor.h> + +namespace luci +{ + +/** + * @note MapNode2Clone is used as a map from original node to cloned node + * to find input of a cloned node + * + * (Original) (Clone) + * + * [A] [A'] + * | [B] | [B'] + * | | | | + * \ / \ / + * [C] [C'] + * + * From view of [C'] we need to find [A'] and [B']. We know [C] from [C'], + * then we can get from input of [C] as [A], [B] then [A]->[A'] and [B]->[B'] + * from the map. + */ +using MapNode2Clone = std::map<const CircleNode * /* ORG */, CircleNode * /* CLONE */>; + +struct CloneContext +{ + std::pair<MapNode2Clone::iterator, bool> emplace(const CircleNode *org, CircleNode *clone) + { + return node2clone.emplace(org, clone); + } + MapNode2Clone::iterator find(const CircleNode *org) { return node2clone.find(org); } + MapNode2Clone::iterator end(void) { return node2clone.end(); } + + MapNode2Clone node2clone; +}; + +class ConnectNode final : public luci::CircleNodeVisitor<void> +{ +public: + ConnectNode(luci::CloneContext &clonecontext) : _clonecontext(clonecontext){}; + +public: + // void visit(const luci::CircleAbs *) final; + void visit(const luci::CircleAdd *) final; + // void visit(const luci::CircleAddN *) final; + // void visit(const luci::CircleArgMax *) final; + // void visit(const luci::CircleArgMin *) final; + // void visit(const luci::CircleAveragePool2D *) final; + // void visit(const luci::CircleBatchMatMul *) final; + // void visit(const luci::CircleBatchToSpaceND *) final; + // void visit(const luci::CircleCast *) final; + // void visit(const luci::CircleCeil *) final; + // void visit(const luci::CircleConcatenation *) final; + void visit(const luci::CircleConst *) final; + // void visit(const luci::CircleConv2D *) final; + // void visit(const luci::CircleCos *) final; + // void visit(const luci::CircleCustom *) final; + // void visit(const luci::CircleDepthToSpace *) final; + // void visit(const luci::CircleDepthwiseConv2D *) final; + // void visit(const luci::CircleDequantize *) final; + void visit(const luci::CircleDiv *) final; + // void visit(const luci::CircleElu *) final; + // void visit(const luci::CircleEqual *) final; + // void visit(const luci::CircleExp *) final; + // void visit(const luci::CircleExpandDims *) final; + // void visit(const luci::CircleFakeQuant *) final; + // void visit(const luci::CircleFill *) final; + // void visit(const luci::CircleFloor *) final; + // void visit(const luci::CircleFloorDiv *) final; + // void visit(const luci::CircleFloorMod *) final; + // void visit(const luci::CircleFullyConnected *) final; + // void visit(const luci::CircleGather *) final; + // void visit(const luci::CircleGatherNd *) final; + // void visit(const luci::CircleGreater *) final; + // void visit(const luci::CircleGreaterEqual *) final; + // void visit(const luci::CircleIf *) final; + // void visit(const luci::CircleL2Normalize *) final; + // void visit(const luci::CircleL2Pool2D *) final; + // void visit(const luci::CircleLeakyRelu *) final; + // void visit(const luci::CircleLess *) final; + // void visit(const luci::CircleLessEqual *) final; + // void visit(const luci::CircleLocalResponseNormalization *) final; + // void visit(const luci::CircleLog *) final; + // void visit(const luci::CircleLogicalAnd *) final; + // void visit(const luci::CircleLogicalNot *) final; + // void visit(const luci::CircleLogicalOr *) final; + // void visit(const luci::CircleLogistic *) final; + // void visit(const luci::CircleLogSoftmax *) final; + // void visit(const luci::CircleMatrixDiag *) final; + // void visit(const luci::CircleMatrixSetDiag *) final; + // void visit(const luci::CircleMaximum *) final; + // void visit(const luci::CircleMaxPool2D *) final; + void visit(const luci::CircleMean *) final; + // void visit(const luci::CircleMinimum *) final; + // void visit(const luci::CircleMirrorPad *) final; + void visit(const luci::CircleMul *) final; + // void visit(const luci::CircleNeg *) final; + // void visit(const luci::CircleNonMaxSuppressionV4 *) final; + // void visit(const luci::CircleNonMaxSuppressionV5 *) final; + // void visit(const luci::CircleNotEqual *) final; + // void visit(const luci::CircleOneHot *) final; + // void visit(const luci::CirclePack *) final; + // void visit(const luci::CirclePad *) final; + // void visit(const luci::CirclePadV2 *) final; + void visit(const luci::CirclePow *) final; + // void visit(const luci::CirclePRelu *) final; + // void visit(const luci::CircleRange *) final; + // void visit(const luci::CircleRank *) final; + // void visit(const luci::CircleReduceAny *) final; + // void visit(const luci::CircleReduceMax *) final; + // void visit(const luci::CircleReduceMin *) final; + // void visit(const luci::CircleReduceProd *) final; + // void visit(const luci::CircleRelu *) final; + // void visit(const luci::CircleRelu6 *) final; + // void visit(const luci::CircleReluN1To1 *) final; + // void visit(const luci::CircleReshape *) final; + // void visit(const luci::CircleResizeBilinear *) final; + // void visit(const luci::CircleResizeNearestNeighbor *) final; + // void visit(const luci::CircleReverseSequence *) final; + // void visit(const luci::CircleReverseV2 *) final; + // void visit(const luci::CircleRound *) final; + void visit(const luci::CircleRsqrt *) final; + // void visit(const luci::CircleScatterNd *) final; + // void visit(const luci::CircleSegmentSum *) final; + // void visit(const luci::CircleSelect *) final; + // void visit(const luci::CircleSelectV2 *) final; + // void visit(const luci::CircleShape *) final; + // void visit(const luci::CircleSin *) final; + // void visit(const luci::CircleSlice *) final; + // void visit(const luci::CircleSoftmax *) final; + // void visit(const luci::CircleSpaceToBatchND *) final; + // void visit(const luci::CircleSpaceToDepth *) final; + // void visit(const luci::CircleSparseToDense *) final; + // void visit(const luci::CircleSplit *) final; + // void visit(const luci::CircleSplitV *) final; + void visit(const luci::CircleSqrt *) final; + // void visit(const luci::CircleSquare *) final; + void visit(const luci::CircleSquaredDifference *) final; + // void visit(const luci::CircleSqueeze *) final; + // void visit(const luci::CircleStridedSlice *) final; + void visit(const luci::CircleSub *) final; + // void visit(const luci::CircleSum *) final; + // void visit(const luci::CircleTanh *) final; + // void visit(const luci::CircleTile *) final; + // void visit(const luci::CircleTopKV2 *) final; + // void visit(const luci::CircleTranspose *) final; + // void visit(const luci::CircleTransposeConv *) final; + // void visit(const luci::CircleUnidirectionalSequenceLSTM *) final; + // void visit(const luci::CircleUnique *) final; + // void visit(const luci::CircleUnpack *) final; + // void visit(const luci::CircleWhere *) final; + // void visit(const luci::CircleWhile *) final; + // void visit(const luci::CircleZerosLike *) final; + + // Circle Only + // void visit(const luci::CircleBCQFullyConnected *) final; + // void visit(const luci::CircleBCQGather *) final; + // void visit(const luci::CircleInstanceNorm *) final; + + // Virtual + // void visit(const luci::CircleCustomOut *) final; + // void visit(const luci::CircleIfOut *) final; + // void visit(const luci::CircleInput *) final; + // void visit(const luci::CircleNonMaxSuppressionV4Out *) final; + // void visit(const luci::CircleNonMaxSuppressionV5Out *) final; + // void visit(const luci::CircleOutput *) final; + // void visit(const luci::CircleOutputDummy *) final; + // void visit(const luci::CircleOutputExclude *) final; + // void visit(const luci::CircleSplitOut *) final; + // void visit(const luci::CircleSplitVOut *) final; + // void visit(const luci::CircleTopKV2Out *) final; + // void visit(const luci::CircleUniqueOut *) final; + // void visit(const luci::CircleUnpackOut *) final; + // void visit(const luci::CircleWhileOut *) final; + +public: + luci::CircleNode *find_clone(const luci::CircleNode *node); + +protected: + luci::CloneContext &_clonecontext; +}; + +/** + * @brief Connect cloned node from input node + */ +void clone_connect(const luci::CircleNode *node, luci::CloneContext &clonecontext); + +} // namespace luci + +#endif // __LUCI_PARTITION_CONNECT_NODE_H__ diff --git a/compiler/luci/partition/src/ConnectNode.test.cpp b/compiler/luci/partition/src/ConnectNode.test.cpp new file mode 100644 index 000000000..a2009c654 --- /dev/null +++ b/compiler/luci/partition/src/ConnectNode.test.cpp @@ -0,0 +1,19 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.test.h" + +// This file validates "ConnectNode.test.h". Please DO NOT remove this file. diff --git a/compiler/luci/partition/src/ConnectNode.test.h b/compiler/luci/partition/src/ConnectNode.test.h new file mode 100644 index 000000000..f7333ff99 --- /dev/null +++ b/compiler/luci/partition/src/ConnectNode.test.h @@ -0,0 +1,146 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __CONNECT_NODE_TEST_H__ +#define __CONNECT_NODE_TEST_H__ + +#include "ConnectNode.h" + +#include <luci/Service/CircleNodeClone.h> +#include <luci/test/TestIOGraph.h> + +#include <loco/IR/Graph.h> + +#include <initializer_list> +#include <memory> +#include <stdexcept> +#include <vector> + +namespace luci +{ +namespace test +{ + +template <unsigned N> class TestIsOGraph : public TestIsGraphlet<N>, public TestOGraphlet +{ +public: + TestIsOGraph() = default; + +public: + virtual void init(const std::initializer_list<ShapeU32> shape_in, const ShapeU32 shape_out) + { + if (shape_in.size() != N) + throw std::runtime_error("Failed to init TestIsOGraph"); + + TestIsGraphlet<N>::init(TestIsGraphlet<N>::g(), shape_in); + TestOGraphlet::init(TestIsGraphlet<N>::g(), shape_out); + } +}; + +template <class T> class NodeGraphletT +{ +public: + virtual void init(loco::Graph *g) + { + _node = g->nodes()->create<T>(); + _node->dtype(loco::DataType::S32); + _node->name("node"); + } + + T *node(void) const { return _node; } + +protected: + T *_node{nullptr}; +}; + +template <class T> class NodeIsGraphletT +{ +public: + virtual void init(loco::Graph *g, uint32_t n) + { + _node = g->nodes()->create<T>(n); + _node->dtype(loco::DataType::S32); + _node->name("node"); + } + + T *node(void) const { return _node; } + +protected: + T *_node{nullptr}; +}; + +/** + * @brief ConnectionTestHelper provides common framework for testing + * cloned CircleNode connection + */ +class ConnectionTestHelper +{ +public: + ConnectionTestHelper() { _graph_clone = loco::make_graph(); } + +public: + template <unsigned N> void prepare_inputs(TestIsOGraph<N> *isograph) + { + assert(N == isograph->num_inputs()); + + for (uint32_t i = 0; i < N; ++i) + { + auto *input = _graph_clone->nodes()->create<luci::CircleInput>(); + luci::copy_common_attributes(isograph->input(i), input); + _clonectx.emplace(isograph->input(i), input); + _inputs.push_back(input); + } + } + + /** + * @note prepare_inputs_miss is for negative testing + */ + template <unsigned N> void prepare_inputs_miss(TestIsOGraph<N> *isograph) + { + assert(N == isograph->num_inputs()); + + for (uint32_t i = 0; i < N; ++i) + { + auto *input = _graph_clone->nodes()->create<luci::CircleInput>(); + luci::copy_common_attributes(isograph->input(i), input); + if (i != 0) + _clonectx.emplace(isograph->input(i), input); + _inputs.push_back(input); + } + } + + void clone_connect(luci::CircleNode *node, luci::CircleNode *clone) + { + _clonectx.emplace(node, clone); + + luci::clone_connect(node, _clonectx); + } + +public: + loco::Graph *graph_clone(void) { return _graph_clone.get(); } + + luci::CircleNode *inputs(uint32_t idx) { return _inputs.at(idx); } + +protected: + luci::CloneContext _clonectx; + std::vector<luci::CircleInput *> _inputs; + std::unique_ptr<loco::Graph> _graph_clone; // graph for clones +}; + +} // namespace test +} // namespace luci + +#endif // __CONNECT_NODE_TEST_H__ diff --git a/compiler/luci/partition/src/Nodes/CircleAdd.cpp b/compiler/luci/partition/src/Nodes/CircleAdd.cpp new file mode 100644 index 000000000..d393997e9 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleAdd.cpp @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleAdd *node) +{ + auto *cloned = loco::must_cast<luci::CircleAdd *>(cn->find_clone(node)); + + luci::CircleNode *x = loco::must_cast<luci::CircleNode *>(node->x()); + luci::CircleNode *y = loco::must_cast<luci::CircleNode *>(node->y()); + + cloned->x(cn->find_clone(x)); + cloned->y(cn->find_clone(y)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleAdd *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleAdd.test.cpp b/compiler/luci/partition/src/Nodes/CircleAdd.test.cpp new file mode 100644 index 000000000..e457b83d2 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleAdd.test.cpp @@ -0,0 +1,100 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CircleAdd> +{ +public: + NodeGraphlet() = default; + +public: + void init(loco::Graph *g) override + { + NodeGraphletT<luci::CircleAdd>::init(g); + + _node->fusedActivationFunction(luci::FusedActFunc::RELU); + } +}; + +class TestNodeGraph : public TestIsOGraph<2>, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIsOGraph<2>::init({shape, shape}, shape); + NodeGraphlet::init(g()); + + node()->x(input(0)); + node()->y(input(1)); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_Add) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleAdd *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleAdd *>(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(2, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); + ASSERT_EQ(cth.inputs(1), clone->arg(1)); +} + +TEST(ConnectNodeTest, connect_Add_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleAdd *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleAdd *>(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/service/src/Nodes/CircleInput.cpp b/compiler/luci/partition/src/Nodes/CircleConst.cpp index 24eab7bd6..118cd8de2 100644 --- a/compiler/luci/service/src/Nodes/CircleInput.cpp +++ b/compiler/luci/partition/src/Nodes/CircleConst.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,14 +14,14 @@ * limitations under the License. */ -#include <luci/Service/CircleShapeSignatureInference.h> +#include "ConnectNode.h" namespace luci { -ShapeSignature ssinf::Algorithm::visit(const luci::CircleInput *node) +void ConnectNode::visit(const luci::CircleConst *) { - return node->shape_signature(); + // Nothing to do } } // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleDiv.cpp b/compiler/luci/partition/src/Nodes/CircleDiv.cpp new file mode 100644 index 000000000..480338542 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleDiv.cpp @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleDiv *node) +{ + auto *cloned = loco::must_cast<luci::CircleDiv *>(cn->find_clone(node)); + + luci::CircleNode *x = loco::must_cast<luci::CircleNode *>(node->x()); + luci::CircleNode *y = loco::must_cast<luci::CircleNode *>(node->y()); + + cloned->x(cn->find_clone(x)); + cloned->y(cn->find_clone(y)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleDiv *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleDiv.test.cpp b/compiler/luci/partition/src/Nodes/CircleDiv.test.cpp new file mode 100644 index 000000000..226932337 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleDiv.test.cpp @@ -0,0 +1,100 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CircleDiv> +{ +public: + NodeGraphlet() = default; + +public: + void init(loco::Graph *g) override + { + NodeGraphletT<luci::CircleDiv>::init(g); + + _node->fusedActivationFunction(luci::FusedActFunc::RELU); + } +}; + +class TestNodeGraph : public TestIsOGraph<2>, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIsOGraph<2>::init({shape, shape}, shape); + NodeGraphlet::init(g()); + + node()->x(input(0)); + node()->y(input(1)); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_Div) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleDiv *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleDiv *>(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(2, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); + ASSERT_EQ(cth.inputs(1), clone->arg(1)); +} + +TEST(ConnectNodeTest, connect_Div_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleDiv *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleDiv *>(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleMean.cpp b/compiler/luci/partition/src/Nodes/CircleMean.cpp new file mode 100644 index 000000000..b634e5838 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleMean.cpp @@ -0,0 +1,41 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleMean *node) +{ + auto *cloned = loco::must_cast<luci::CircleMean *>(cn->find_clone(node)); + + luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(node->input()); + luci::CircleNode *reduction_indices = + loco::must_cast<luci::CircleNode *>(node->reduction_indices()); + + cloned->input(cn->find_clone(input)); + cloned->reduction_indices(cn->find_clone(reduction_indices)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleMean *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleMul.cpp b/compiler/luci/partition/src/Nodes/CircleMul.cpp new file mode 100644 index 000000000..2cd2b4038 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleMul.cpp @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleMul *node) +{ + auto *cloned = loco::must_cast<luci::CircleMul *>(cn->find_clone(node)); + + luci::CircleNode *x = loco::must_cast<luci::CircleNode *>(node->x()); + luci::CircleNode *y = loco::must_cast<luci::CircleNode *>(node->y()); + + cloned->x(cn->find_clone(x)); + cloned->y(cn->find_clone(y)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleMul *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleMul.test.cpp b/compiler/luci/partition/src/Nodes/CircleMul.test.cpp new file mode 100644 index 000000000..99cf0824d --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleMul.test.cpp @@ -0,0 +1,100 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CircleMul> +{ +public: + NodeGraphlet() = default; + +public: + void init(loco::Graph *g) + { + NodeGraphletT<luci::CircleMul>::init(g); + + _node->fusedActivationFunction(luci::FusedActFunc::RELU); + } +}; + +class TestNodeGraph : public TestIsOGraph<2>, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIsOGraph<2>::init({shape, shape}, shape); + NodeGraphlet::init(g()); + + node()->x(input(0)); + node()->y(input(1)); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_Mul) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleMul *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleMul *>(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(2, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); + ASSERT_EQ(cth.inputs(1), clone->arg(1)); +} + +TEST(ConnectNodeTest, connect_Mul_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleMul *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleMul *>(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CirclePow.cpp b/compiler/luci/partition/src/Nodes/CirclePow.cpp new file mode 100644 index 000000000..fb180ee69 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CirclePow.cpp @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CirclePow *node) +{ + auto *cloned = loco::must_cast<luci::CirclePow *>(cn->find_clone(node)); + + luci::CircleNode *x = loco::must_cast<luci::CircleNode *>(node->x()); + luci::CircleNode *y = loco::must_cast<luci::CircleNode *>(node->y()); + + cloned->x(cn->find_clone(x)); + cloned->y(cn->find_clone(y)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CirclePow *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleRsqrt.cpp b/compiler/luci/partition/src/Nodes/CircleRsqrt.cpp new file mode 100644 index 000000000..03e64aad0 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleRsqrt.cpp @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleRsqrt *node) +{ + auto *cloned = loco::must_cast<luci::CircleRsqrt *>(cn->find_clone(node)); + + luci::CircleNode *x = loco::must_cast<luci::CircleNode *>(node->x()); + + cloned->x(cn->find_clone(x)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleRsqrt *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleSqrt.cpp b/compiler/luci/partition/src/Nodes/CircleSqrt.cpp new file mode 100644 index 000000000..f737aac8d --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleSqrt.cpp @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleSqrt *node) +{ + auto *cloned = loco::must_cast<luci::CircleSqrt *>(cn->find_clone(node)); + + luci::CircleNode *x = loco::must_cast<luci::CircleNode *>(node->x()); + + cloned->x(cn->find_clone(x)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleSqrt *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleSquaredDifference.cpp b/compiler/luci/partition/src/Nodes/CircleSquaredDifference.cpp new file mode 100644 index 000000000..40dd31706 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleSquaredDifference.cpp @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleSquaredDifference *node) +{ + auto *cloned = loco::must_cast<luci::CircleSquaredDifference *>(cn->find_clone(node)); + + luci::CircleNode *x = loco::must_cast<luci::CircleNode *>(node->x()); + luci::CircleNode *y = loco::must_cast<luci::CircleNode *>(node->y()); + + cloned->x(cn->find_clone(x)); + cloned->y(cn->find_clone(y)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleSquaredDifference *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleSub.cpp b/compiler/luci/partition/src/Nodes/CircleSub.cpp new file mode 100644 index 000000000..8ac294b7b --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleSub.cpp @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleSub *node) +{ + auto *cloned = loco::must_cast<luci::CircleSub *>(cn->find_clone(node)); + + luci::CircleNode *x = loco::must_cast<luci::CircleNode *>(node->x()); + luci::CircleNode *y = loco::must_cast<luci::CircleNode *>(node->y()); + + cloned->x(cn->find_clone(x)); + cloned->y(cn->find_clone(y)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleSub *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleSub.test.cpp b/compiler/luci/partition/src/Nodes/CircleSub.test.cpp new file mode 100644 index 000000000..7c0d83745 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleSub.test.cpp @@ -0,0 +1,100 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CircleSub> +{ +public: + NodeGraphlet() = default; + +public: + void init(loco::Graph *g) + { + NodeGraphletT<luci::CircleSub>::init(g); + + _node->fusedActivationFunction(luci::FusedActFunc::RELU); + } +}; + +class TestNodeGraph : public TestIsOGraph<2>, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIsOGraph<2>::init({shape, shape}, shape); + NodeGraphlet::init(g()); + + node()->x(input(0)); + node()->y(input(1)); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_Sub) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleSub *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleSub *>(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(2, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); + ASSERT_EQ(cth.inputs(1), clone->arg(1)); +} + +TEST(ConnectNodeTest, connect_Sub_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleSub *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleSub *>(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Partition.cpp b/compiler/luci/partition/src/Partition.cpp new file mode 100644 index 000000000..cc7106ca9 --- /dev/null +++ b/compiler/luci/partition/src/Partition.cpp @@ -0,0 +1,61 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "PartitionIR.h" +#include "PartitionIRDump.h" +#include "PartitionPGroups.h" +#include "PartitionMerge.h" +#include "PartitionCleanup.h" +#include "PartitionPModules.h" +#include "PartitionPModulesDump.h" + +#include "luci/Partition.h" +#include "luci/Log.h" + +#include <cassert> + +namespace luci +{ + +/** + * @brief This will return Partitioned Modules object + */ +PartedModules apply(Module *source, const PartitionTable &partition) +{ + assert(source != nullptr); + + LOGGER(l); + + auto pgroups = produce_pgroups(source, partition); + INFO(l) << "--- Partition Graph (1)------------------------"; + INFO(l) << pgroups.get(); + + auto mpgroups = merge_pgroups(pgroups.get()); + INFO(l) << "--- Partition Graph (2)------------------------"; + INFO(l) << mpgroups.get(); + + remove_unused_inputoutputs(mpgroups.get(), source); + INFO(l) << "--- Partition Graph (3)------------------------"; + INFO(l) << mpgroups.get(); + + auto pmodules = produce_pmodules(mpgroups.get()); + INFO(l) << "--- Modules -----------------------------------"; + INFO(l) << &pmodules; + + return pmodules; +} + +} // namespace luci diff --git a/compiler/luci/partition/src/Partition.test.cpp b/compiler/luci/partition/src/Partition.test.cpp new file mode 100644 index 000000000..9e24c441c --- /dev/null +++ b/compiler/luci/partition/src/Partition.test.cpp @@ -0,0 +1,83 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Partition.h" + +#include <luci/test/TestIOGraph.h> + +#include <luci/IR/Nodes/CircleSqrt.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class SqrtGraphlet +{ +public: + SqrtGraphlet() = default; + +public: + void init(loco::Graph *g, const ShapeU32 input_shape) + { + _sqrt = g->nodes()->create<luci::CircleSqrt>(); + _sqrt->dtype(loco::DataType::S32); + _sqrt->name("sqrt"); + } + +protected: + luci::CircleSqrt *_sqrt = nullptr; +}; + +class SqrtGraph : public TestIOGraph, public SqrtGraphlet +{ +public: + SqrtGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIOGraph::init(shape, shape); + SqrtGraphlet::init(g(), shape); + + _sqrt->x(input()); + + output()->from(_sqrt); + } +}; + +} // namespace + +TEST(PartitionTest, simple_apply) +{ + luci::Module module; + + SqrtGraph g; + g.init({3, 3}); + g.transfer_to(&module); + + luci::PartitionTable pt; + pt.default_group = "A"; + + auto pms = apply(&module, pt); + + ASSERT_EQ(1, pms.pmodules.size()); + + auto &pm = *pms.pmodules.begin(); + ASSERT_NE(nullptr, pm.module->graph()); +} diff --git a/compiler/luci/partition/src/PartitionCleanup.cpp b/compiler/luci/partition/src/PartitionCleanup.cpp new file mode 100644 index 000000000..6545295df --- /dev/null +++ b/compiler/luci/partition/src/PartitionCleanup.cpp @@ -0,0 +1,139 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "PartitionCleanup.h" + +#include "luci/Log.h" + +namespace +{ + +using CircleNodes = std::vector<luci::CircleNode *>; + +/** + * @note Original source outputs should be outputs + */ +void gather_graph_outputs(CircleNodes &nodes, const luci::Module *source) +{ + // graph outputs are treated as used + auto graph = source->graph(); + for (uint32_t n = 0; n < graph->outputs()->size(); ++n) + { + auto output = luci::output_node(graph, n); // output is CircleOutput + assert(output != nullptr); + + auto node = loco::must_cast<luci::CircleNode *>(output->from()); + + nodes.push_back(node); + } + + // TODO add unused virtual outputs +} + +/** + * @note If one PGroup requires an input, that input should be an output + * from another PGroup + */ +void gather_pgroups_outputs(CircleNodes &nodes, const luci::PGroups *pgroups) +{ + // input of a pgroup is used output + for (auto &pgroup : pgroups->pgroups) + { + for (auto input : pgroup->inputs) + { + nodes.push_back(input); + } + } +} + +} // namespace + +namespace luci +{ + +void remove_unused_inputoutputs(luci::PGroups *pgroups, const luci::Module *source) +{ + assert(source != nullptr); + assert(pgroups != nullptr); + + LOGGER(l); + + // TODO support multiple subgraph + assert(source->size() == 1); + + INFO(l) << "--- Cleanup unused inputs/outputs"; + + // remove input within same pgroup + for (auto &pgroup : pgroups->pgroups) + { + bool changed; + do + { + changed = false; + for (auto it = pgroup->inputs.begin(); it != pgroup->inputs.end(); ++it) + { + auto input = *it; + if (pgroups->pgroup_of(input) == pgroup.get()) + { + INFO(l) << " Cleanup input " << input->name() << " from group " << pgroup->group; + pgroup->inputs.erase(it); + changed = true; + break; + } + // NOTE CircleConst is one of input type, as they are registered as + // input to some node and then (should be) merged. + // Remove if this input is CircleConst + if (dynamic_cast<CircleConst *>(input) != nullptr) + { + INFO(l) << " Cleanup CircleConst " << input->name() << " from group " << pgroup->group; + pgroup->inputs.erase(it); + changed = true; + break; + } + } + } while (changed); + } + + // remove unused output(s) + // 'used_outputs' will hold actual used outputs for all PGroups + CircleNodes used_outputs; + + gather_graph_outputs(used_outputs, source); + gather_pgroups_outputs(used_outputs, pgroups); + + for (auto &pgroup : pgroups->pgroups) + { + bool changed; + do + { + changed = false; + for (auto it = pgroup->outputs.begin(); it != pgroup->outputs.end(); ++it) + { + auto output = *it; + auto oit = std::find(used_outputs.begin(), used_outputs.end(), output); + if (oit == used_outputs.end()) + { + INFO(l) << " Cleanup output " << output->name() << " from group " << pgroup->group; + pgroup->outputs.erase(it); + changed = true; + break; + } + } + } while (changed); + } +} + +} // namespace luci diff --git a/compiler/luci/partition/src/PartitionCleanup.h b/compiler/luci/partition/src/PartitionCleanup.h new file mode 100644 index 000000000..f81b4a7cb --- /dev/null +++ b/compiler/luci/partition/src/PartitionCleanup.h @@ -0,0 +1,34 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_PARTITON_CLEANUP_H__ +#define __LUCI_PARTITON_CLEANUP_H__ + +#include "PartitionIR.h" + +#include <luci/IR/Module.h> + +namespace luci +{ + +/** + * @brief This will remove unused inputs/outputs in each pgroup of pgroups + */ +void remove_unused_inputoutputs(luci::PGroups *, const luci::Module *); + +} // namespace luci + +#endif // __LUCI_PARTITON_CLEANUP_H__ diff --git a/compiler/luci/partition/src/PartitionIR.cpp b/compiler/luci/partition/src/PartitionIR.cpp new file mode 100644 index 000000000..ebd6b25fa --- /dev/null +++ b/compiler/luci/partition/src/PartitionIR.cpp @@ -0,0 +1,101 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "PartitionIR.h" +#include "CircleOpCode.h" + +#include "luci/Log.h" + +#include <cassert> +#include <ostream> +#include <iostream> + +namespace luci +{ + +std::unique_ptr<PGroups> PGroups::make_copy(void) const +{ + auto d_pgroups = std::make_unique<luci::PGroups>(); + + for (auto &s_pgroup : pgroups) + { + // make a copy of s_pgroup to d_pgroup + std::unique_ptr<luci::PGroup> d_pgroup = std::make_unique<luci::PGroup>(); + + d_pgroup->group = s_pgroup->group; + d_pgroup->id = s_pgroup->id; + + for (auto &pnode : s_pgroup->pnodes) + { + auto pnodec = std::make_unique<luci::PNode>(); + pnodec->node = pnode->node; + pnodec->group = pnode->group; + pnodec->pgroup = d_pgroup.get(); + d_pgroup->pnodes.push_back(std::move(pnodec)); + } + + for (auto &input : s_pgroup->inputs) + d_pgroup->inputs.push_back(input); + + for (auto &output : s_pgroup->outputs) + d_pgroup->outputs.push_back(output); + + // copy node2group + for (auto it = node2group.begin(); it != node2group.end(); ++it) + d_pgroups->node2group[it->first] = it->second; + + // build id2pgroup + d_pgroups->id2pgroup[d_pgroup->id] = d_pgroup.get(); + + d_pgroups->pgroups.push_back(std::move(d_pgroup)); + // note: d_pgroup is now nullptr as it's moved + } + + return std::move(d_pgroups); +} + +std::string PGroups::group_of(luci::CircleNode *node) const +{ + assert(node != nullptr); + + LOGGER(l); + + auto it = node2group.find(node); + if (it == node2group.end()) + { + INFO(l) << "PGroups::group_of " << node << "(" << node->name() << ") not found" << std::endl; + return ""; + } + return it->second; +} + +const PGroup *PGroups::pgroup_of(luci::CircleNode *node) const +{ + assert(node != nullptr); + + for (auto &pgroup : pgroups) + { + for (auto &pnode : pgroup->pnodes) + { + if (node == pnode->node) + return pgroup.get(); + } + } + // node maybe graph input (CircleInput) + return nullptr; +} + +} // namespace luci diff --git a/compiler/luci/partition/src/PartitionIR.h b/compiler/luci/partition/src/PartitionIR.h new file mode 100644 index 000000000..852e38cc0 --- /dev/null +++ b/compiler/luci/partition/src/PartitionIR.h @@ -0,0 +1,91 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_PARTITION_IR_H__ +#define __LUCI_PARTITION_IR_H__ + +#include <luci/IR/CircleNodes.h> + +#include <map> +#include <memory> +#include <string> +#include <vector> + +namespace luci +{ + +struct PGroup; + +/** + * @brief Partition Node with CircleNode with group name + * @note node just points to source luci::CircleNode, NOT the cloned node + * CloneContext is used to find cloned node from source node + */ +struct PNode +{ + const luci::CircleNode *node = nullptr; + std::string group; + + const PGroup *pgroup = nullptr; +}; + +/** + * @brief Partition Group with Partition Nodes of same group and I/Os nodes + */ +struct PGroup +{ + std::vector<std::unique_ptr<PNode>> pnodes; + std::string group; + uint32_t id = 0; + + // I/O while partitioning + std::vector<luci::CircleNode *> inputs; + std::vector<luci::CircleNode *> outputs; +}; + +struct PGroups +{ + std::vector<std::unique_ptr<PGroup>> pgroups; + + // node2group is to find group key from source node + std::map<const luci::CircleNode *, std::string> node2group; + + // id2pngroup is to find *pngroup from pngroup id + std::map<uint32_t, PGroup *> id2pgroup; + + // default group key for reference + std::string default_group; + +public: + /** + * @brief return a copy of PGroups + */ + std::unique_ptr<PGroups> make_copy(void) const; + + /** + * @brief return group key of node, empty string if not found + */ + std::string group_of(luci::CircleNode *node) const; + + /** + * @brief return holding pgroup of node, nullptr if not found + */ + const PGroup *pgroup_of(luci::CircleNode *node) const; +}; + +} // namespace luci + +#endif // __LUCI_PARTITION_IR_H__ diff --git a/compiler/luci/partition/src/PartitionIR.test.cpp b/compiler/luci/partition/src/PartitionIR.test.cpp new file mode 100644 index 000000000..4c051a96d --- /dev/null +++ b/compiler/luci/partition/src/PartitionIR.test.cpp @@ -0,0 +1,75 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "PartitionIR.h" + +// NOTE any node will do for testing +#include <luci/IR/Nodes/CircleAdd.h> + +#include <gtest/gtest.h> + +#include <memory> + +TEST(PartitionIRTest, PNode_ctor) +{ + auto g = loco::make_graph(); + auto node = g->nodes()->create<luci::CircleAdd>(); + + luci::PNode pnode; + pnode.node = node; + + ASSERT_NE(nullptr, pnode.node); + ASSERT_EQ(nullptr, pnode.pgroup); +} + +// TODO add more tests with luci::PNode + +TEST(PartitionIRTest, PGroup_ctor) +{ + auto g = loco::make_graph(); + auto node = g->nodes()->create<luci::CircleAdd>(); + + luci::PGroup pgroup; + auto pnode = std::make_unique<luci::PNode>(); + pnode->node = node; + + pgroup.pnodes.push_back(std::move(pnode)); + + ASSERT_NE(pgroup.pnodes.end(), pgroup.pnodes.begin()); + ASSERT_EQ(0, pgroup.inputs.size()); + ASSERT_EQ(0, pgroup.outputs.size()); +} + +// TODO add more tests with luci::PGroup + +TEST(PartitionIRTest, PGroups_ctor) +{ + auto g = loco::make_graph(); + auto node = g->nodes()->create<luci::CircleAdd>(); + + auto pnode = std::make_unique<luci::PNode>(); + pnode->node = node; + + auto pgroup = std::make_unique<luci::PGroup>(); + pgroup->pnodes.push_back(std::move(pnode)); + + luci::PGroups pgroups; + pgroups.pgroups.push_back(std::move(pgroup)); + + ASSERT_NE(pgroups.pgroups.end(), pgroups.pgroups.begin()); +} + +// TODO add more tests with luci::PGroups diff --git a/compiler/luci/partition/src/PartitionIRDump.cpp b/compiler/luci/partition/src/PartitionIRDump.cpp new file mode 100644 index 000000000..4f2c26800 --- /dev/null +++ b/compiler/luci/partition/src/PartitionIRDump.cpp @@ -0,0 +1,70 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "PartitionIRDump.h" + +#include "CircleOpCode.h" + +#include <iostream> + +namespace luci +{ + +void dump(std::ostream &os, const PNode *pnode) +{ + os << "PNode: " << pnode->group << ", " << pnode->node << ":" << luci::opcode_name(pnode->node) + << ":" << pnode->node->name() << std::endl; +} + +void dump(std::ostream &os, const PGroup *pgroup) +{ + os << "--- PGroup: " << pgroup->group << std::endl; + os << "Input(s): "; + for (auto &node_in : pgroup->inputs) + os << node_in->name() << " "; + os << std::endl; + for (auto &pnode : pgroup->pnodes) + { + dump(os, pnode.get()); + } + os << "Output(s): "; + for (auto &node_out : pgroup->outputs) + os << node_out->name() << " "; + os << std::endl; +} + +void dump(std::ostream &os, const PGroups *pgroups) +{ + for (auto &pgroup : pgroups->pgroups) + { + dump(os, pgroup.get()); + } + os << "--- Node2Group items: " << std::endl; + for (auto it = pgroups->node2group.begin(); it != pgroups->node2group.end(); ++it) + { + auto node = it->first; + auto group = it->second; + os << " Node: " << node << "(" << node->name() << "): " << group << std::endl; + } +} + +} // namespace luci + +std::ostream &operator<<(std::ostream &os, const luci::PGroups *pgroups) +{ + luci::dump(os, pgroups); + return os; +} diff --git a/compiler/luci/partition/src/PartitionIRDump.h b/compiler/luci/partition/src/PartitionIRDump.h new file mode 100644 index 000000000..8a4b3f579 --- /dev/null +++ b/compiler/luci/partition/src/PartitionIRDump.h @@ -0,0 +1,35 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_PARTITION_IR_DUMP_H__ +#define __LUCI_PARTITION_IR_DUMP_H__ + +#include "PartitionIR.h" + +#include <iostream> + +namespace luci +{ + +void dump(std::ostream &os, const PNode *pnode); +void dump(std::ostream &os, const PGroup *pgroup); +void dump(std::ostream &os, const PGroups *pgroups); + +} // namespace luci + +std::ostream &operator<<(std::ostream &os, const luci::PGroups *pgroups); + +#endif // __LUCI_PARTITION_IR_DUMP_H__ diff --git a/compiler/luci/partition/src/PartitionMerge.cpp b/compiler/luci/partition/src/PartitionMerge.cpp new file mode 100644 index 000000000..038fc2a0c --- /dev/null +++ b/compiler/luci/partition/src/PartitionMerge.cpp @@ -0,0 +1,207 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "PartitionMerge.h" + +#include <algorithm> + +namespace +{ + +/** + * @brief return true if pgroup_i output is one of the inputs of pgroup + */ +bool is_input_of(const luci::PGroup *pgroup_i, const luci::PGroup *pgroup) +{ + for (auto *output : pgroup_i->outputs) + { + for (auto *input : pgroup->inputs) + { + if (input == output) + return true; + } + } + return false; +} + +/** + * @brief return true if there is only one input or all the inputs have same group + * @note pgroups is used to find group of pgroup + */ +bool is_input_same(const luci::PGroup *pgroup, const luci::PGroups *pgroups) +{ + assert(pgroups != nullptr); + assert(pgroup != nullptr); + + const luci::PGroup *input_pgroup = nullptr; + std::string group; + for (auto &input : pgroup->inputs) + { + auto input_group = pgroups->group_of(input); + // NOTE: all the nodes should be registered and return should be valid group. + // convert_to_proups() should ensure this. + // assert here to find if there is any problem with this. + assert(not input_group.empty()); + if (input_group.empty()) + input_group = pgroups->default_group; + + if (group.empty()) + group = input_group; + else + { + if (group != input_group) + return false; + } + // if there are multiple inputs, all the inputs should be in same pgroup + // https://github.com/Samsung/ONE/issues/6230#issuecomment-801618150 + // https://github.com/Samsung/ONE/issues/6230#issuecomment-801680531 + auto pgroup_input = pgroups->pgroup_of(input); + if (pgroup_input != nullptr) + { + if (input_pgroup == nullptr) + input_pgroup = pgroup_input; + else + { + if (input_pgroup != pgroup_input) + return false; + } + } + } + return true; +} + +/** + * @brief merge pgroup into pgroup_i + * @note output of pgroup_i should be input of pgroup + */ +void merge_into(luci::PGroup *pgroup, luci::PGroup *pgroup_i) +{ + for (auto &pnode : pgroup->pnodes) + { + // update pgroup for this pnode + pnode->pgroup = pgroup_i; + assert(pnode->group == pgroup_i->group); + + // we don't need to add this in topological order: + // all the nodes will be created first then connection will be held + pgroup_i->pnodes.push_back(std::move(pnode)); + // note: pnode is now nullptr as it's moved into pgroup_i->pnodes + } + + for (auto &input : pgroup->inputs) + { + // add inputs of pgroup to pgroup_i if not member of pgroup_i + bool found_in_pgroup_i = false; + for (auto &pnode : pgroup_i->pnodes) + { + if (input == pnode->node) + { + found_in_pgroup_i = true; + break; + } + } + // skip if this input is already in the inputs + auto fit = std::find(pgroup_i->inputs.begin(), pgroup_i->inputs.end(), input); + if (fit != pgroup_i->inputs.end()) + { + found_in_pgroup_i = true; + } + // note: if we force found_in_pgroup_i to false, for testing there will be + // unnecessary inputs + if (not found_in_pgroup_i) + { + // node input maybe in another pgroup + pgroup_i->inputs.push_back(input); + } + } + // add outputs of pgroup to pgroup_i outputs if not exist + for (auto &output : pgroup->outputs) + { + auto it = std::find(pgroup_i->outputs.begin(), pgroup_i->outputs.end(), output); + if (it == pgroup_i->outputs.end()) + { + pgroup_i->outputs.push_back(output); + } + } +} + +} // namespace + +namespace luci +{ + +/** + * @brief This will merge pgroups with same group values in topological order + */ +std::unique_ptr<luci::PGroups> merge_pgroups(const luci::PGroups *s_pgroups) +{ + // Make a copy of pgroups to apply merge action + // Q) do we really need a copy? + auto d_pgroups = s_pgroups->make_copy(); + + // Merge partition graphs + // - This is initial implementation that works for limited networks + // - if A and B is same group -> if A is input of B -> ... -> merge B into A + auto &pgroups = d_pgroups->pgroups; + bool changed; + do + { + changed = false; + for (auto &pgroup_i : pgroups) + { + bool merged = false; + for (auto it = pgroups.begin(); it != pgroups.end(); ++it) + { + auto &pgroup = *it; + + // skip if same object + if (pgroup->id == pgroup_i->id) + continue; + // skip if different group + if (pgroup->group != pgroup_i->group) + continue; + // skip if not connected + if (!is_input_of(pgroup_i.get(), pgroup.get())) + continue; + // skip if there are multiple inputs but inputs differ in group + if (!is_input_same(pgroup.get(), d_pgroups.get())) + continue; + // TODO add more condition may be needed + + merge_into(pgroup.get(), pgroup_i.get()); + + auto eit = d_pgroups->id2pgroup.find(pgroup->id); + assert(eit != d_pgroups->id2pgroup.end()); + d_pgroups->id2pgroup.erase(eit); + + // remove merged pgroup from pgroups + pgroups.erase(it); + + merged = true; + break; + } + if (merged) + { + changed = true; + break; + } + } + } while (changed); + + return std::move(d_pgroups); +} + +} // namespace luci diff --git a/compiler/luci/partition/src/PartitionMerge.h b/compiler/luci/partition/src/PartitionMerge.h new file mode 100644 index 000000000..5c9fec2d2 --- /dev/null +++ b/compiler/luci/partition/src/PartitionMerge.h @@ -0,0 +1,31 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_PARTITON_MERGE_H__ +#define __LUCI_PARTITON_MERGE_H__ + +#include "PartitionIR.h" + +#include <memory> + +namespace luci +{ + +std::unique_ptr<luci::PGroups> merge_pgroups(const luci::PGroups *s_pgroups); + +} // namespace luci + +#endif // __LUCI_PARTITON_MERGE_H__ diff --git a/compiler/luci/partition/src/PartitionPGroups.cpp b/compiler/luci/partition/src/PartitionPGroups.cpp new file mode 100644 index 000000000..594ed6c40 --- /dev/null +++ b/compiler/luci/partition/src/PartitionPGroups.cpp @@ -0,0 +1,139 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "PartitionPGroups.h" +#include "PartitionIR.h" +#include "CircleOpCode.h" + +#include "luci/Partition.h" +#include "luci/Log.h" +#include "luci/LogHelper.h" + +#include <luci/IR/CircleNodes.h> +#include <luci/IR/CircleNodeVisitor.h> + +#include <loco.h> + +namespace +{ + +class IsVirtualNode final : public luci::CircleNodeVisitor<bool> +{ +public: + bool visit(const luci::CircleInput *) final { return true; } + bool visit(const luci::CircleOutput *) final { return true; } + // TODO add all virtual nodes + + // default is false + bool visit(const luci::CircleNode *) final { return false; } +}; + +bool check_allocate_partition(const luci::CircleNode *node) +{ + IsVirtualNode query; + if (node->accept(&query)) + return false; + /** + * @note About CircleConst + * CirleConst acts like a part of some CircleNode and managing mulitiple + * used(referenced) CircleConst is a bit difficult if it's used across + * different PGroup. So we treat this different to other types. + * https://github.com/Samsung/ONE/issues/6230#issuecomment-809802813 + */ + if (dynamic_cast<const luci::CircleConst *>(node) != nullptr) + return false; + return true; +} + +} // namespace + +namespace luci +{ + +std::unique_ptr<luci::PGroups> produce_pgroups(const luci::Module *source, + const luci::PartitionTable &partition) +{ + assert(source != nullptr); + // TODO support multiple subgraphs + assert(source->size() == 1); + + LOGGER(l); + + auto pgroups = std::make_unique<luci::PGroups>(); + + pgroups->default_group = partition.default_group; + + // Create a PGroup per CircleNode: each PGroup will have one CircleNode + auto graph = source->graph(); + auto nodes = graph->nodes(); + for (uint32_t idx = 0; idx < nodes->size(); ++idx) + { + auto node = loco::must_cast<luci::CircleNode *>(nodes->at(idx)); + + // check if node is normal node that we are interested + if (check_allocate_partition(node)) + { + auto opcodename = luci::opcode_name(node); + assert(!opcodename.empty()); + + auto group = partition.default_group; + auto it = partition.byopcodes.find(opcodename); + if (it != partition.byopcodes.end()) + group = it->second; + + INFO(l) << "Op: " << node->name() << ": " << opcodename << ", " << node << ", " << group + << std::endl; + + auto pgroup = std::make_unique<luci::PGroup>(); + pgroup->group = group; + pgroup->id = idx + 1; + + auto pnode = std::make_unique<luci::PNode>(); + pnode->node = node; + pnode->group = group; + pnode->pgroup = pgroup.get(); + + pgroup->pnodes.push_back(std::move(pnode)); + + // Set input of PGroup + for (uint32_t in = 0; in < node->arity(); ++in) + { + auto input = loco::must_cast<luci::CircleNode *>(node->arg(in)); + // this input maybe CircleInput in source graph + // --> not confident this is safe + pgroup->inputs.push_back(input); + } + // Set output of PGroup: node itself or multiple virtual outputs + // TODO support multiple virtual outputs + pgroup->outputs.push_back(node); + + pgroups->node2group[node] = group; + pgroups->id2pgroup[pgroup->id] = pgroup.get(); + + pgroups->pgroups.push_back(std::move(pgroup)); + } + else + { + INFO(l) << "Skip Op: " << node->name() << std::endl; + // record as default group + pgroups->node2group[node] = partition.default_group; + } + } + + return std::move(pgroups); +} + +} // namespace luci diff --git a/compiler/luci/partition/src/PartitionPGroups.h b/compiler/luci/partition/src/PartitionPGroups.h new file mode 100644 index 000000000..998e11cbd --- /dev/null +++ b/compiler/luci/partition/src/PartitionPGroups.h @@ -0,0 +1,39 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_PARTITON_PGROUPS_H__ +#define __LUCI_PARTITON_PGROUPS_H__ + +#include "PartitionIR.h" + +#include "luci/Partition.h" + +#include <luci/IR/Module.h> + +namespace luci +{ + +/** + * @brief This will produce a PGroups from Module and PartitionTable. + * @note Each PGroup will hold one CircleNode and partition key value as group. + * Supports only single Graph in the Module for now. + */ +std::unique_ptr<luci::PGroups> produce_pgroups(const luci::Module *source, + const luci::PartitionTable &partition); + +} // namespace luci + +#endif // __LUCI_PARTITON_PGROUPS_H__ diff --git a/compiler/luci/partition/src/PartitionPGroups.test.cpp b/compiler/luci/partition/src/PartitionPGroups.test.cpp new file mode 100644 index 000000000..960f3cde9 --- /dev/null +++ b/compiler/luci/partition/src/PartitionPGroups.test.cpp @@ -0,0 +1,80 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "PartitionPGroups.h" + +#include <luci/test/TestIOGraph.h> + +#include <luci/IR/Nodes/CircleSqrt.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class SqrtGraphlet +{ +public: + SqrtGraphlet() = default; + +public: + void init(loco::Graph *g, const ShapeU32 input_shape) + { + _sqrt = g->nodes()->create<luci::CircleSqrt>(); + _sqrt->dtype(loco::DataType::S32); + _sqrt->name("sqrt"); + } + +protected: + luci::CircleSqrt *_sqrt = nullptr; +}; + +class SqrtGraph : public TestIOGraph, public SqrtGraphlet +{ +public: + SqrtGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIOGraph::init(shape, shape); + SqrtGraphlet::init(g(), shape); + + _sqrt->x(input()); + + output()->from(_sqrt); + } +}; + +} // namespace + +TEST(PartitionPGroupsTest, simple_produce) +{ + luci::Module module; + + SqrtGraph g; + g.init({3, 3}); + g.transfer_to(&module); + + luci::PartitionTable pt; + pt.default_group = "A"; + + auto pgs = produce_pgroups(&module, pt); + + ASSERT_EQ(1, pgs->pgroups.size()); +} diff --git a/compiler/luci/partition/src/PartitionPModules.cpp b/compiler/luci/partition/src/PartitionPModules.cpp new file mode 100644 index 000000000..36f4d47a4 --- /dev/null +++ b/compiler/luci/partition/src/PartitionPModules.cpp @@ -0,0 +1,203 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "PartitionPModules.h" +#include "ConnectNode.h" + +#include "luci/Service/CircleNodeClone.h" +#include "luci/Log.h" + +#include <loco.h> + +namespace +{ + +void add_graph_input(loco::Graph *graph, luci::CircleInput *input_node) +{ + assert(graph != nullptr); + assert(input_node != nullptr); + + auto graph_input = graph->inputs()->create(); + graph_input->name(input_node->name()); + + // Set GraphInputOutputIndex for graph + input_node->index(graph_input->index()); + + // Data type + graph_input->dtype(input_node->dtype()); + + // Shape of GraphInput + auto input_shape = std::make_unique<loco::TensorShape>(); + input_shape->rank(input_node->rank()); + for (uint32_t r = 0; r < input_node->rank(); ++r) + { + if (input_node->dim(r).known()) + input_shape->dim(r).set(input_node->dim(r).value()); + } + graph_input->shape(std::move(input_shape)); +} + +void add_graph_output(loco::Graph *graph, luci::CircleOutput *output_node) +{ + assert(graph != nullptr); + assert(output_node != nullptr); + + auto graph_output = graph->outputs()->create(); + graph_output->name(output_node->name()); + + // Set GraphInputOutputIndex for graph + output_node->index(graph_output->index()); + + // Data type + graph_output->dtype(output_node->dtype()); + + // Shape of GraphOutput + auto output_shape = std::make_unique<loco::TensorShape>(); + output_shape->rank(output_node->rank()); + for (uint32_t r = 0; r < output_node->rank(); ++r) + { + if (output_node->dim(r).known()) + output_shape->dim(r).set(output_node->dim(r).value()); + } + graph_output->shape(std::move(output_shape)); +} + +/** + * @brief Build loco::graph from pgroup into graph + */ +void build_graph(loco::Graph *graph, const luci::PGroup *pgroup) +{ + LOGGER(l); + + luci::CloneContext clonectx; + + // add input node(s) + for (auto *input : pgroup->inputs) + { + auto *input_clone = graph->nodes()->create<luci::CircleInput>(); + luci::copy_common_attributes(input, input_clone); + + add_graph_input(graph, input_clone); + clonectx.emplace(input, input_clone); + + INFO(l) << "MAP: " + << " input(" << input << ") -> " << input_clone << "(" << input_clone->name() << ")"; + } + + // add CircleConst for inputs + for (auto &pnode : pgroup->pnodes) + { + auto node = pnode->node; + uint32_t arity = node->arity(); + for (uint32_t a = 0; a < arity; ++a) + { + auto in_a_const = dynamic_cast<luci::CircleConst *>(node->arg(a)); + if (in_a_const != nullptr) + { + auto it = clonectx.find(in_a_const); + if (it == clonectx.end()) + { + auto *clone = clone_node(in_a_const, graph); + clonectx.emplace(in_a_const, clone); + + INFO(l) << "MAP: " + << " const(" << in_a_const << ") -> " << clone << "(" << clone->name() << ")"; + } + } + } + } + + // add nodes + for (auto &pnode : pgroup->pnodes) + { + auto *clone = clone_node(pnode->node, graph); + clonectx.emplace(pnode->node, clone); + + INFO(l) << "MAP: " + << " node(" << pnode->node << ") -> " << clone << "(" << clone->name() << ")"; + } + // connect nodes + for (auto &pnode : pgroup->pnodes) + { + clone_connect(pnode->node, clonectx); + } + + // add output node(s) + for (auto *output : pgroup->outputs) + { + auto *output_clone = graph->nodes()->create<luci::CircleOutput>(); + luci::copy_common_attributes(output, output_clone); + // note: we don't add output_clone to clonectx. + // logically, output is not used as an input to any other nodes. + + auto it = clonectx.find(output); + assert(it != clonectx.end()); + output_clone->from(it->second); + + add_graph_output(graph, output_clone); + + INFO(l) << "MAP: " + << "output(" << output << ") -> " << output_clone << "(" << output_clone->name() << ")" + << ": from " << it->second << "(" << it->second->name() << ")"; + } +} + +std::string make_name(const luci::PGroup *pgroup) +{ + auto &first_pnode = *pgroup->pnodes.begin(); + auto *first_node = first_pnode->node; + std::string name = first_node->graph()->name(); + name = name + "_" + pgroup->group; + return name; +} + +} // namespace + +namespace luci +{ + +/** + * @brief This will produce list of luci::Module as PartedModules from pgroups + */ +luci::PartedModules produce_pmodules(const luci::PGroups *pgroups) +{ + LOGGER(l); + + luci::PartedModules pms; + + for (auto &pgroup : pgroups->pgroups) + { + luci::PartedModule pm; + pm.module = std::make_unique<luci::Module>(); + pm.group = pgroup->group; + + auto graph = loco::make_graph(); + + auto graph_name = make_name(pgroup.get()); + graph->name(graph_name); + + INFO(l) << "--- Partition Graph build----------------------"; + INFO(l) << "--- name: " << graph_name; + build_graph(graph.get(), pgroup.get()); + + pm.module->add(std::move(graph)); + pms.pmodules.emplace_back(std::move(pm)); + } + + return pms; +} + +} // namespace luci diff --git a/compiler/luci/partition/src/PartitionPModules.h b/compiler/luci/partition/src/PartitionPModules.h new file mode 100644 index 000000000..628ada56c --- /dev/null +++ b/compiler/luci/partition/src/PartitionPModules.h @@ -0,0 +1,31 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_PARTITON_PMODULES_H__ +#define __LUCI_PARTITON_PMODULES_H__ + +#include "PartitionIR.h" + +#include "luci/Partition.h" + +namespace luci +{ + +luci::PartedModules produce_pmodules(const luci::PGroups *pgroups); + +} // namespace luci + +#endif // __LUCI_PARTITON_PMODULES_H__ diff --git a/compiler/luci/partition/src/PartitionPModules.test.cpp b/compiler/luci/partition/src/PartitionPModules.test.cpp new file mode 100644 index 000000000..99c39e839 --- /dev/null +++ b/compiler/luci/partition/src/PartitionPModules.test.cpp @@ -0,0 +1,82 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "PartitionPModules.h" +#include "PartitionPGroups.h" + +#include <luci/test/TestIOGraph.h> + +#include <luci/IR/Nodes/CircleSqrt.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class SqrtGraphlet +{ +public: + SqrtGraphlet() = default; + +public: + void init(loco::Graph *g, const ShapeU32 input_shape) + { + _sqrt = g->nodes()->create<luci::CircleSqrt>(); + _sqrt->dtype(loco::DataType::S32); + _sqrt->name("sqrt"); + } + +protected: + luci::CircleSqrt *_sqrt = nullptr; +}; + +class SqrtGraph : public TestIOGraph, public SqrtGraphlet +{ +public: + SqrtGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIOGraph::init(shape, shape); + SqrtGraphlet::init(g(), shape); + + _sqrt->x(input()); + + output()->from(_sqrt); + } +}; + +} // namespace + +TEST(PartitionPModulesTest, simple_convert) +{ + luci::Module module; + + SqrtGraph g; + g.init({3, 3}); + g.transfer_to(&module); + + luci::PartitionTable pt; + pt.default_group = "A"; + + auto pgs = produce_pgroups(&module, pt); + auto pms = produce_pmodules(pgs.get()); + + ASSERT_EQ(1, pms.pmodules.size()); +} diff --git a/compiler/luci/partition/src/PartitionPModulesDump.cpp b/compiler/luci/partition/src/PartitionPModulesDump.cpp new file mode 100644 index 000000000..ee50bc6fb --- /dev/null +++ b/compiler/luci/partition/src/PartitionPModulesDump.cpp @@ -0,0 +1,47 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "PartitionPModulesDump.h" + +#include "luci/LogHelper.h" + +#include <iostream> + +namespace luci +{ + +void dump(std::ostream &os, const PartedModule *pmodule) +{ + os << "--- PartedModule: " << pmodule->group << std::endl; + os << luci::fmt(pmodule->module->graph()); +} + +void dump(std::ostream &os, const PartedModules *pmodules) +{ + for (auto &pmodule : pmodules->pmodules) + { + dump(os, &pmodule); + } + os << std::endl; +} + +} // namespace luci + +std::ostream &operator<<(std::ostream &os, const luci::PartedModules *pmodules) +{ + luci::dump(os, pmodules); + return os; +} diff --git a/compiler/luci/partition/src/PartitionPModulesDump.h b/compiler/luci/partition/src/PartitionPModulesDump.h new file mode 100644 index 000000000..e77b235f4 --- /dev/null +++ b/compiler/luci/partition/src/PartitionPModulesDump.h @@ -0,0 +1,34 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_PARTITION_PMODULES_DUMP_H__ +#define __LUCI_PARTITION_PMODULES_DUMP_H__ + +#include "luci/Partition.h" + +#include <iostream> + +namespace luci +{ + +void dump(std::ostream &os, const PartedModule *pmodule); +void dump(std::ostream &os, const PartedModules *pmodules); + +} // namespace luci + +std::ostream &operator<<(std::ostream &os, const luci::PartedModules *pmodules); + +#endif // __LUCI_PARTITION_PMODULES_DUMP_H__ diff --git a/compiler/luci/pass/CMakeLists.txt b/compiler/luci/pass/CMakeLists.txt index 2c5fb3407..2977fbed7 100644 --- a/compiler/luci/pass/CMakeLists.txt +++ b/compiler/luci/pass/CMakeLists.txt @@ -12,6 +12,7 @@ target_link_libraries(luci_pass PRIVATE luci_lang) target_link_libraries(luci_pass PRIVATE luci_log) target_link_libraries(luci_pass PRIVATE luci_service) target_link_libraries(luci_pass PRIVATE luci_logex) +target_link_libraries(luci_pass PRIVATE luci_profile) target_link_libraries(luci_pass PRIVATE nncc_common) target_link_libraries(luci_pass PRIVATE oops) install(TARGETS luci_pass DESTINATION lib) @@ -26,4 +27,5 @@ GTest_AddTest(luci_pass_test ${TESTS}) target_include_directories(luci_pass_test PRIVATE src) target_link_libraries(luci_pass_test luci_pass) target_link_libraries(luci_pass_test luci_lang) +target_link_libraries(luci_pass_test luci_testhelper) #target_link_libraries(luci_pass_test oops) diff --git a/compiler/luci/pass/include/luci/CircleOptimizer.h b/compiler/luci/pass/include/luci/CircleOptimizer.h index 906760e0a..1f5e1c8b9 100644 --- a/compiler/luci/pass/include/luci/CircleOptimizer.h +++ b/compiler/luci/pass/include/luci/CircleOptimizer.h @@ -35,6 +35,8 @@ public: enum Algorithm { FuseAddWithTConv, + FuseBatchNormWithConv, + FuseBatchNormWithDwConv, FuseBatchNormWithTConv, FuseBCQ, FuseInstanceNorm, @@ -44,7 +46,11 @@ public: QuantizeDequantizeWeights, QuantizeWithMinMax, Requantize, + FoldAddV2, + FoldCast, FoldDequantize, + FoldSparseToDense, + ForwardReshapeToUnaryOp, SparsifyTensorPass, FusePreActivationBatchNorm, MakeBatchNormGammaPositive, @@ -53,6 +59,15 @@ public: RemoveRedundantTranspose, ReplaceMulAddWithDepthwiseConv, SubstitutePackToReshape, + SubstituteSqueezeToReshape, + ConvertNCHWToNHWC, + RemoveUnnecessarySlice, + RemoveUnnecessaryStridedSlice, + RemoveUnnecessarySplit, + RemoveUnnecessaryReshape, + TransformMinMaxToRelu6Pass, + SubstituteTransposeToReshape, + RemoveRedundantReshape, }; enum AlgorithmParameters @@ -68,6 +83,10 @@ public: Sparsify_format, Sparsify_block_size, Sparsify_block_map, + + // convert NCHW to NHWC + NCHW_to_NHWC_preserve_input_shape, + NCHW_to_NHWC_preserve_output_shape, }; virtual ~Options() = default; diff --git a/compiler/luci/pass/include/luci/Pass/ShapeInferencePass.h b/compiler/luci/pass/include/luci/Pass/CircleShapeInferencePass.h index e21ab4cce..21d6d09d6 100644 --- a/compiler/luci/pass/include/luci/Pass/ShapeInferencePass.h +++ b/compiler/luci/pass/include/luci/Pass/CircleShapeInferencePass.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef __LUCI_SHAPE_INFERENCE_PASS_H__ -#define __LUCI_SHAPE_INFERENCE_PASS_H__ +#ifndef __LUCI_CIRCLE_SHAPE_INFERENCE_PASS_H__ +#define __LUCI_CIRCLE_SHAPE_INFERENCE_PASS_H__ #include <loco.h> @@ -25,12 +25,12 @@ namespace luci { /** - * @brief Pass to infer shape of nodes + * @brief Pass to infer shape of circle nodes */ -class ShapeInferencePass : public luci::Pass +class CircleShapeInferencePass : public luci::Pass { public: - virtual const char *name(void) const { return "luci::ShapeInferencePass"; } + virtual const char *name(void) const { return "luci::CircleShapeInferencePass"; } public: bool run(luci::Module *m); @@ -39,4 +39,4 @@ public: } // namespace luci -#endif //__LUCI_SHAPE_INFERENCE_PASS_H__ +#endif //__LUCI_CIRCLE_SHAPE_INFERENCE_PASS_H__ diff --git a/compiler/luci/pass/include/luci/Pass/ConvertNCHWToNHWCPass.h b/compiler/luci/pass/include/luci/Pass/ConvertNCHWToNHWCPass.h new file mode 100644 index 000000000..ba2392596 --- /dev/null +++ b/compiler/luci/pass/include/luci/Pass/ConvertNCHWToNHWCPass.h @@ -0,0 +1,60 @@ +/* + * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_CONVERT_NCHW_TO_NHWC_PASS_H__ +#define __LUCI_CONVERT_NCHW_TO_NHWC_PASS_H__ + +#include <logo/Pass.h> + +namespace luci +{ + +/** + * @brief Class to convert NCHW Ops to NHWC + * + * @details Find operators that use NCHW layout and make them use NHWC. + * Strictly speaking, it is impossible to distinguish whether + * an operator is using NCHW or NHWC without programmers' annotations. + * But we guess the data layout of each operator as much as possible + * based on the assumptions described in the comments. + * Note that this Pass does not change the execution result even + * for the false-positive cases. + */ +struct ConvertNCHWToNHWCPass final : public logo::Pass +{ +public: + ConvertNCHWToNHWCPass(bool preserve_input, bool preserve_output) + : _preserve_input(preserve_input), _preserve_output(preserve_output) + { + // Do nothing + } + + ConvertNCHWToNHWCPass() = delete; + + virtual ~ConvertNCHWToNHWCPass() = default; + + const char *name(void) const final { return "luci::ConvertNCHWToNHWCPass"; } + + bool run(loco::Graph *g) final; + +private: + bool _preserve_input = false; + bool _preserve_output = false; +}; + +} // namespace luci + +#endif // __LUCI_CONVERT_NCHW_TO_NHWC_PASS_H__ diff --git a/compiler/luci/pass/include/luci/Pass/FoldAddV2Pass.h b/compiler/luci/pass/include/luci/Pass/FoldAddV2Pass.h new file mode 100644 index 000000000..cd260b916 --- /dev/null +++ b/compiler/luci/pass/include/luci/Pass/FoldAddV2Pass.h @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_FOLD_ADD_V2_PASS_H__ +#define __LUCI_FOLD_ADD_V2_PASS_H__ + +#include <logo/Pass.h> + +namespace luci +{ + +/** + * @brief Class to fold AddV2 to a constant tensor + * + */ +struct FoldAddV2Pass final : public logo::Pass +{ + const char *name(void) const final { return "luci::FoldAddV2Pass"; } + + bool run(loco::Graph *g) final; +}; + +} // namespace luci + +#endif // __LUCI_FOLD_ADD_V2_PASS_H__ diff --git a/compiler/luci/pass/include/luci/Pass/FoldCastPass.h b/compiler/luci/pass/include/luci/Pass/FoldCastPass.h new file mode 100644 index 000000000..5d7ce4ad3 --- /dev/null +++ b/compiler/luci/pass/include/luci/Pass/FoldCastPass.h @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_FOLD_CAST_PASS_H__ +#define __LUCI_FOLD_CAST_PASS_H__ + +#include <logo/Pass.h> + +namespace luci +{ + +/** + * @brief Class to fold Cast to a constant tensor + * + */ +struct FoldCastPass final : public logo::Pass +{ + const char *name(void) const final { return "luci::FoldCastPass"; } + + bool run(loco::Graph *g) final; +}; + +} // namespace luci + +#endif // __LUCI_FOLD_CAST_PASS_H__ diff --git a/compiler/luci/pass/include/luci/Pass/FoldSparseToDensePass.h b/compiler/luci/pass/include/luci/Pass/FoldSparseToDensePass.h new file mode 100644 index 000000000..00d2447a5 --- /dev/null +++ b/compiler/luci/pass/include/luci/Pass/FoldSparseToDensePass.h @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_FOLD_SPARSE_TO_DENSE_PASS_H__ +#define __LUCI_FOLD_SPARSE_TO_DENSE_PASS_H__ + +#include <logo/Pass.h> + +namespace luci +{ + +/** + * @brief Class to fold SparseToDense to a constant tensor + * + */ +struct FoldSparseToDensePass final : public logo::Pass +{ + const char *name(void) const final { return "luci::FoldSparseToDensePass"; } + + bool run(loco::Graph *g) final; +}; + +} // namespace luci + +#endif // __LUCI_FOLD_SPARSE_TO_DENSE_PASS_H__ diff --git a/compiler/luci/pass/include/luci/Pass/ForwardReshapeToUnaryOpPass.h b/compiler/luci/pass/include/luci/Pass/ForwardReshapeToUnaryOpPass.h new file mode 100644 index 000000000..4c308e531 --- /dev/null +++ b/compiler/luci/pass/include/luci/Pass/ForwardReshapeToUnaryOpPass.h @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_FORWARD_RESHAPE_TO_UNARYOP_PASS_H__ +#define __LUCI_FORWARD_RESHAPE_TO_UNARYOP_PASS_H__ + +#include <logo/Pass.h> + +namespace luci +{ + +/** + * @brief Class to Forward send Reshape after UnaryOp. + */ +struct ForwardReshapeToUnaryOpPass final : public logo::Pass +{ + const char *name(void) const final { return "luci::ForwardReshapeToUnaryOpPass"; } + + bool run(loco::Graph *g) final; +}; + +} // namespace luci + +#endif // __LUCI_FORWARD_RESHAPE_TO_UNARYOP_PASS_H__ diff --git a/compiler/luci/pass/include/luci/Pass/FuseBatchNormWithConvPass.h b/compiler/luci/pass/include/luci/Pass/FuseBatchNormWithConvPass.h new file mode 100644 index 000000000..1ed85447b --- /dev/null +++ b/compiler/luci/pass/include/luci/Pass/FuseBatchNormWithConvPass.h @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_FUSE_BATCH_NORM_WITH_CONV_PASS_H__ +#define __LUCI_FUSE_BATCH_NORM_WITH_CONV_PASS_H__ + +#include <logo/Pass.h> + +namespace luci +{ + +/** + * @brief Class to fuse Batch Normalization into CircleConv + */ +struct FuseBatchNormWithConvPass final : public logo::Pass +{ + const char *name(void) const final { return "luci::FuseBatchNormWithConvPass"; } + + bool run(loco::Graph *g) final; +}; + +} // namespace luci + +#endif // __LUCI_FUSE_BATCH_NORM_WITH_CONV_PASS_H__ diff --git a/compiler/luci/pass/include/luci/Pass/FuseBatchNormWithDwConvPass.h b/compiler/luci/pass/include/luci/Pass/FuseBatchNormWithDwConvPass.h new file mode 100644 index 000000000..32885c6b2 --- /dev/null +++ b/compiler/luci/pass/include/luci/Pass/FuseBatchNormWithDwConvPass.h @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_FUSE_BATCH_NORM_WITH_DWCONV_PASS_H__ +#define __LUCI_FUSE_BATCH_NORM_WITH_DWCONV_PASS_H__ + +#include <logo/Pass.h> + +namespace luci +{ + +/** + * @brief Class to fuse Batch Normalization into CircleDepthWiseConv2D + */ +struct FuseBatchNormWithDwConvPass final : public logo::Pass +{ + const char *name(void) const final { return "luci::FuseBatchNormWithDwConvPass"; } + + bool run(loco::Graph *g) final; +}; + +} // namespace luci + +#endif // __LUCI_FUSE_BATCH_NORM_WITH_DWCONV_PASS_H__ diff --git a/compiler/luci/pass/include/luci/Pass/FuseBatchNormWithTConv.h b/compiler/luci/pass/include/luci/Pass/FuseBatchNormWithTConvPass.h index d3e930a36..d3e930a36 100644 --- a/compiler/luci/pass/include/luci/Pass/FuseBatchNormWithTConv.h +++ b/compiler/luci/pass/include/luci/Pass/FuseBatchNormWithTConvPass.h diff --git a/compiler/luci/pass/include/luci/Pass/MigrateLegacyShapeDtypePass.h b/compiler/luci/pass/include/luci/Pass/MigrateLegacyShapeDtypePass.h deleted file mode 100644 index c0ebc4e5d..000000000 --- a/compiler/luci/pass/include/luci/Pass/MigrateLegacyShapeDtypePass.h +++ /dev/null @@ -1,44 +0,0 @@ -/* - * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef __LUCI_MIGRATE_LEGACY_SHAPE_DTYPE_PASS_H__ -#define __LUCI_MIGRATE_LEGACY_SHAPE_DTYPE_PASS_H__ - -#include <loco.h> - -#include <luci/ModulePass.h> - -namespace luci -{ - -/** - * @brief Pass to copy shape/dtype of loco to circle node - * - * CAUTION : This pass will be removed after refactoring is finished - */ -class MigrateLegacyShapeDtypePass : public luci::Pass -{ -public: - virtual const char *name(void) const { return "luci::MigrateLegacyShapeDtypePass"; } - -public: - bool run(luci::Module *m); - bool run(loco::Graph *graph); -}; - -} // namespace luci - -#endif //__LUCI_MIGRATE_LEGACY_SHAPE_DTYPE_PASS_H__ diff --git a/compiler/luci/pass/include/luci/Pass/QuantizeDequantizeWeightsPass.h b/compiler/luci/pass/include/luci/Pass/QuantizeDequantizeWeightsPass.h index 713b88f9d..78e7323f9 100644 --- a/compiler/luci/pass/include/luci/Pass/QuantizeDequantizeWeightsPass.h +++ b/compiler/luci/pass/include/luci/Pass/QuantizeDequantizeWeightsPass.h @@ -34,7 +34,7 @@ class QuantizeDequantizeWeightsPass : public logo::Pass public: QuantizeDequantizeWeightsPass(loco::DataType input_dtype, loco::DataType output_dtype, QuantizationGranularity granularity) - : _input_dtype{input_dtype}, _output_dtype{output_dtype}, _granularity{granularity} + : _input_dtype{input_dtype}, _output_dtype{output_dtype}, _granularity{granularity} { // DO NOTHING } diff --git a/compiler/luci/pass/include/luci/Pass/QuantizeWithMinMaxPass.h b/compiler/luci/pass/include/luci/Pass/QuantizeWithMinMaxPass.h index bb0d0ff40..9520910d5 100644 --- a/compiler/luci/pass/include/luci/Pass/QuantizeWithMinMaxPass.h +++ b/compiler/luci/pass/include/luci/Pass/QuantizeWithMinMaxPass.h @@ -34,7 +34,7 @@ class QuantizeWithMinMaxPass : public logo::Pass public: QuantizeWithMinMaxPass(loco::DataType input_dtype, loco::DataType output_dtype, QuantizationGranularity granularity) - : _input_dtype{input_dtype}, _output_dtype{output_dtype}, _granularity{granularity} + : _input_dtype{input_dtype}, _output_dtype{output_dtype}, _granularity{granularity} { // DO NOTHING } diff --git a/compiler/luci/pass/include/luci/Pass/RemoveRedundantReshapePass.h b/compiler/luci/pass/include/luci/Pass/RemoveRedundantReshapePass.h new file mode 100644 index 000000000..458ffc094 --- /dev/null +++ b/compiler/luci/pass/include/luci/Pass/RemoveRedundantReshapePass.h @@ -0,0 +1,39 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_REMOVE_REDUNDANT_RESHAPE_PASS_H__ +#define __LUCI_REMOVE_REDUNDANT_RESHAPE_PASS_H__ + +#include <logo/Pass.h> + +namespace luci +{ + +/** + * @brief Class to remove redundant Reshape node into 1 Reshape node. + * @details This class will update consecutive two Reshape node into single Reshape node. + * As Reshape operation just change shape, not buffer, former reshape could be unnecessary. + */ +struct RemoveRedundantReshapePass final : public logo::Pass +{ + const char *name(void) const final { return "luci::RemoveRedundantReshapePass"; } + + bool run(loco::Graph *g) final; +}; + +} // namespace luci + +#endif // __LUCI_REMOVE_REDUNDANT_RESHAPE_PASS_H__ diff --git a/compiler/luci/pass/include/luci/Pass/RemoveUnnecessaryReshapePass.h b/compiler/luci/pass/include/luci/Pass/RemoveUnnecessaryReshapePass.h new file mode 100644 index 000000000..8fca35e5b --- /dev/null +++ b/compiler/luci/pass/include/luci/Pass/RemoveUnnecessaryReshapePass.h @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_REMOVE_UNNECESSARY_RESHAPE_PASS_H__ +#define __LUCI_REMOVE_UNNECESSARY_RESHAPE_PASS_H__ + +#include <logo/Pass.h> + +namespace luci +{ + +/** + * @brief Class to Remove Unnecessary(input shape and output shape same) Reshape node. + */ +struct RemoveUnnecessaryReshapePass final : public logo::Pass +{ + const char *name(void) const final { return "luci::RemoveUnnecessaryReshapePass"; } + + bool run(loco::Graph *g) final; +}; + +} // namespace luci + +#endif // __LUCI_REMOVE_UNNECESSARY_RESHAPE_PASS_H__ diff --git a/compiler/luci/pass/include/luci/Pass/RemoveUnnecessarySlicePass.h b/compiler/luci/pass/include/luci/Pass/RemoveUnnecessarySlicePass.h new file mode 100644 index 000000000..a3b0f2f8c --- /dev/null +++ b/compiler/luci/pass/include/luci/Pass/RemoveUnnecessarySlicePass.h @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_REMOVE_NO_EFFECT_SLICE_PASS_H__ +#define __LUCI_REMOVE_NO_EFFECT_SLICE_PASS_H__ + +#include <logo/Pass.h> + +namespace luci +{ + +/** + * @brief Class to Remove Unnecessary(input and output are same) Slice node. + */ +struct RemoveUnnecessarySlicePass final : public logo::Pass +{ + const char *name(void) const final { return "luci::RemoveUnnecessarySlicePass"; } + + bool run(loco::Graph *g) final; +}; + +} // namespace luci + +#endif // __LUCI_REMOVE_NO_EFFECT_SLICE_PASS_H__ diff --git a/compiler/luci/pass/include/luci/Pass/RemoveUnnecessarySplitPass.h b/compiler/luci/pass/include/luci/Pass/RemoveUnnecessarySplitPass.h new file mode 100644 index 000000000..0d9330fe7 --- /dev/null +++ b/compiler/luci/pass/include/luci/Pass/RemoveUnnecessarySplitPass.h @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_REMOVE_UNNECESSARY_SPLIT_PASS_H__ +#define __LUCI_REMOVE_UNNECESSARY_SPLIT_PASS_H__ + +#include <logo/Pass.h> + +namespace luci +{ + +/** + * @brief Remove unnecessary Split OP + */ +struct RemoveUnnecessarySplitPass final : public logo::Pass +{ + const char *name(void) const final { return "luci::RemoveUnnecessarySplitPass"; } + + bool run(loco::Graph *g) final; +}; + +} // namespace luci + +#endif // __LUCI_REMOVE_UNNECESSARY_SPLIT_PASS_H__ diff --git a/compiler/luci/pass/include/luci/Pass/RemoveUnnecessaryStridedSlicePass.h b/compiler/luci/pass/include/luci/Pass/RemoveUnnecessaryStridedSlicePass.h new file mode 100644 index 000000000..0f6a61d43 --- /dev/null +++ b/compiler/luci/pass/include/luci/Pass/RemoveUnnecessaryStridedSlicePass.h @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_REMOVE_UNNECESSARY_STRIDED_SLICE_PASS_H__ +#define __LUCI_REMOVE_UNNECESSARY_STRIDED_SLICE_PASS_H__ + +#include <logo/Pass.h> + +namespace luci +{ + +/** + * @brief Class to Remove Unnecessary(input and output are same) StridedSlice node. + */ +struct RemoveUnnecessaryStridedSlicePass final : public logo::Pass +{ + const char *name(void) const final { return "luci::RemoveUnnecessaryStridedSlicePass"; } + + bool run(loco::Graph *g) final; +}; + +} // namespace luci + +#endif // __LUCI_REMOVE_UNNECESSARY_STRIDED_SLICE_PASS_H__ diff --git a/compiler/luci/pass/include/luci/Pass/RequantizePass.h b/compiler/luci/pass/include/luci/Pass/RequantizePass.h index 2442b24ea..c6c424f1b 100644 --- a/compiler/luci/pass/include/luci/Pass/RequantizePass.h +++ b/compiler/luci/pass/include/luci/Pass/RequantizePass.h @@ -33,7 +33,7 @@ class RequantizePass : public logo::Pass { public: RequantizePass(loco::DataType input_dtype, loco::DataType output_dtype) - : _input_dtype{input_dtype}, _output_dtype{output_dtype} + : _input_dtype{input_dtype}, _output_dtype{output_dtype} { // DO NOTHING } diff --git a/compiler/luci/pass/include/luci/Pass/SparsifyTensorPass.h b/compiler/luci/pass/include/luci/Pass/SparsifyTensorPass.h index 41f43bf88..0ce142c55 100644 --- a/compiler/luci/pass/include/luci/Pass/SparsifyTensorPass.h +++ b/compiler/luci/pass/include/luci/Pass/SparsifyTensorPass.h @@ -35,8 +35,8 @@ public: SparsifyTensorPass(const std::string &tensor_name, const std::vector<int32_t> &traversal_order, const std::vector<DimensionType> &format, const std::vector<int32_t> &block_size, const std::vector<int32_t> &block_map) - : _tensor_name{tensor_name}, _traversal_order{traversal_order}, _format{format}, - _block_size{block_size}, _block_map{block_map} + : _tensor_name{tensor_name}, _traversal_order{traversal_order}, _format{format}, + _block_size{block_size}, _block_map{block_map} { // DO NOTHING } diff --git a/compiler/luci/pass/include/luci/Pass/SubstituteSqueezeToReshapePass.h b/compiler/luci/pass/include/luci/Pass/SubstituteSqueezeToReshapePass.h new file mode 100644 index 000000000..d8df6ac3f --- /dev/null +++ b/compiler/luci/pass/include/luci/Pass/SubstituteSqueezeToReshapePass.h @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_SUBSTITUTE_SQUEEZE_TO_RESHAPE_PASS_H__ +#define __LUCI_SUBSTITUTE_SQUEEZE_TO_RESHAPE_PASS_H__ + +#include <logo/Pass.h> + +namespace luci +{ + +/** + * @brief Class to Substitute Squeeze to Reshape node for certain conditions. + */ +struct SubstituteSqueezeToReshapePass final : public logo::Pass +{ + const char *name(void) const final { return "luci::SubstituteSqueezeToReshapePass"; } + + bool run(loco::Graph *g) final; +}; + +} // namespace luci + +#endif // __LUCI_SUBSTITUTE_SQUEEZE_TO_RESHAPE_PASS_H__ diff --git a/compiler/luci/pass/include/luci/Pass/SubstituteTransposeToReshapePass.h b/compiler/luci/pass/include/luci/Pass/SubstituteTransposeToReshapePass.h new file mode 100644 index 000000000..ee708585a --- /dev/null +++ b/compiler/luci/pass/include/luci/Pass/SubstituteTransposeToReshapePass.h @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_SUBSTITUTE_TRANSPOSE_TO_RESHAPE_PASS_H__ +#define __LUCI_SUBSTITUTE_TRANSPOSE_TO_RESHAPE_PASS_H__ + +#include <logo/Pass.h> + +namespace luci +{ + +/** + * @brief Class to Substitute Transpose with certain input shape condition to single reshape node. + */ +struct SubstituteTransposeToReshapePass final : public logo::Pass +{ + const char *name(void) const final { return "luci::SubstituteTransposeToReshapePass"; } + + bool run(loco::Graph *g) final; +}; + +} // namespace luci + +#endif // __LUCI_SUBSTITUTE_TRANSPOSE_TO_RESHAPE_PASS_H__ diff --git a/compiler/luci/pass/include/luci/Pass/TransformMinMaxToRelu6Pass.h b/compiler/luci/pass/include/luci/Pass/TransformMinMaxToRelu6Pass.h new file mode 100644 index 000000000..9ea39ee4e --- /dev/null +++ b/compiler/luci/pass/include/luci/Pass/TransformMinMaxToRelu6Pass.h @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_TRANSFORM_MIN_MAX_TO_RELU6_PASS_H__ +#define __LUCI_TRANSFORM_MIN_MAX_TO_RELU6_PASS_H__ + +#include <logo/Pass.h> + +namespace luci +{ + +/** + * @brief Class to transform Maximum(Minimum(input, 6), 0) to Relu6 + */ +struct TransformMinMaxToRelu6Pass final : public logo::Pass +{ + const char *name(void) const final { return "luci::TransformMinMaxToRelu6Pass"; } + + bool run(loco::Graph *g) final; +}; + +} // namespace luci + +#endif // __LUCI_TRANSFORM_MIN_MAX_TO_RELU6_PASS_H__ diff --git a/compiler/luci/pass/src/BatchNormPatternFinder.cpp b/compiler/luci/pass/src/BatchNormPatternFinder.cpp new file mode 100644 index 000000000..c1a06bfda --- /dev/null +++ b/compiler/luci/pass/src/BatchNormPatternFinder.cpp @@ -0,0 +1,106 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "BatchNormPatternFinder.h" + +#include <luci/IR/CircleNodes.h> + +namespace luci +{ + +bool is_batchnorm_add(const luci::CircleAdd *add, luci::CircleMul *&mul, luci::CircleConst *&beta) +{ + auto x = loco::must_cast<luci::CircleNode *>(add->x()); + auto y = loco::must_cast<luci::CircleNode *>(add->y()); + + luci::CircleMul *pred = nullptr; + luci::CircleConst *constant = nullptr; + + if (x->opcode() == luci::CircleOpcode::CIRCLECONST && y->opcode() == luci::CircleOpcode::MUL) + { + pred = loco::must_cast<luci::CircleMul *>(y); + constant = loco::must_cast<luci::CircleConst *>(x); + } + else if (x->opcode() == luci::CircleOpcode::MUL && y->opcode() == luci::CircleOpcode::CIRCLECONST) + { + pred = loco::must_cast<luci::CircleMul *>(x); + constant = loco::must_cast<luci::CircleConst *>(y); + } + else + { + return false; + } + + if (constant->rank() != 1) + return false; + + auto channel_dim = constant->dim(0); + // Assumption: Layout is channel-last + if (!(channel_dim == add->dim(add->rank() - 1))) + return false; + + mul = pred; + beta = constant; + return true; +} + +bool is_batchnorm_add(const luci::CircleAdd *add) +{ + // for dummy mul and beta + luci::CircleMul *mul = nullptr; + luci::CircleConst *beta = nullptr; + + return is_batchnorm_add(add, mul, beta); +} + +bool is_batchnorm_mul(const luci::CircleMul *mul, luci::CircleNode *&pred_node, + luci::CircleConst *&gamma) +{ + auto x = dynamic_cast<luci::CircleConst *>(mul->x()); + auto y = dynamic_cast<luci::CircleConst *>(mul->y()); + + luci::CircleNode *pred = nullptr; + luci::CircleConst *constant = nullptr; + + if (x != nullptr && y == nullptr) + { + pred = loco::must_cast<luci::CircleNode *>(mul->y()); + constant = x; + } + else if (x == nullptr && y != nullptr) + { + pred = loco::must_cast<luci::CircleNode *>(mul->x()); + constant = y; + } + else + { + return false; + } + + if (constant->rank() != 1) + return false; + + auto channel_dim = constant->dim(0); + // Assumption: Layout is channel-last + if (!(channel_dim == mul->dim(mul->rank() - 1))) + return false; + + pred_node = pred; + gamma = constant; + return true; +} + +} // namespace luci diff --git a/compiler/luci/pass/src/BatchNormPatternFinder.h b/compiler/luci/pass/src/BatchNormPatternFinder.h new file mode 100644 index 000000000..58cdbb464 --- /dev/null +++ b/compiler/luci/pass/src/BatchNormPatternFinder.h @@ -0,0 +1,43 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_PASS_BATCH_NORM_PATTERN_FINDER_H__ +#define __LUCI_PASS_BATCH_NORM_PATTERN_FINDER_H__ + +#include <luci/IR/CircleNodes.h> + +namespace luci +{ + +/** + * @brief Find Mul-Add pattern and return Mul and beta as BatchNorm + */ +bool is_batchnorm_add(const luci::CircleAdd *add, luci::CircleMul *&mul, luci::CircleConst *&beta); + +/** + * @brief Find Mul-Add pattern + */ +bool is_batchnorm_add(const luci::CircleAdd *add); + +/** + * @brief Find Const-Mul pattern and return Node and gamma as BatchNorm + */ +bool is_batchnorm_mul(const luci::CircleMul *mul, luci::CircleNode *&pred_node, + luci::CircleConst *&gamma); + +} // namespace luci + +#endif // __LUCI_PASS_BATCH_NORM_PATTERN_FINDER_H__ diff --git a/compiler/luci/pass/src/BatchNormPatternFinder.test.cpp b/compiler/luci/pass/src/BatchNormPatternFinder.test.cpp new file mode 100644 index 000000000..08e7fac1c --- /dev/null +++ b/compiler/luci/pass/src/BatchNormPatternFinder.test.cpp @@ -0,0 +1,217 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "BatchNormPatternFinder.h" + +#include <luci/test/TestIOGraph.h> + +#include <luci/IR/CircleNodes.h> + +#include <gtest/gtest.h> + +namespace luci +{ +namespace test +{ + +/** + * @brief Graphlet with Add and Const as beta from BatchNorm + */ +class AddBetaGraphlet +{ +public: + AddBetaGraphlet() = default; + + void init(loco::Graph *g, const ShapeU32 shape, luci::FusedActFunc actf) + { + _add = g->nodes()->create<luci::CircleAdd>(); + _add_beta = g->nodes()->create<luci::CircleConst>(); + + _add->dtype(loco::DataType::FLOAT32); + _add_beta->dtype(loco::DataType::FLOAT32); + + _add->fusedActivationFunction(actf); + + assert(shape.size() > 0); + auto last_it = std::prev(shape.end(), 1); + auto channel_size = *last_it; + + _add->shape(shape); + _add_beta->shape({channel_size}); + _add_beta->size<loco::DataType::FLOAT32>(channel_size); + for (uint32_t i = 0; i < channel_size; i++) + _add_beta->at<loco::DataType::FLOAT32>(i) = i; + + _add->name("add"); + _add_beta->name("add_beta"); + } + +public: + luci::CircleAdd *add() { return _add; } + +protected: + luci::CircleAdd *_add = nullptr; + luci::CircleConst *_add_beta = nullptr; +}; + +/** + * @brief Graphlet with Mul and Const as gamma from BatchNorm + */ +class MulGammaGraphlet +{ +public: + MulGammaGraphlet() = default; + + void init(loco::Graph *g, const ShapeU32 shape, luci::FusedActFunc actf) + { + _mul = g->nodes()->create<luci::CircleMul>(); + _mul_gamma = g->nodes()->create<luci::CircleConst>(); + + _mul->dtype(loco::DataType::FLOAT32); + _mul_gamma->dtype(loco::DataType::FLOAT32); + + _mul->fusedActivationFunction(actf); + + assert(shape.size() > 0); + auto last_it = std::prev(shape.end(), 1); + auto channel_size = *last_it; + + _mul->shape(shape); + _mul_gamma->shape({channel_size}); + _mul_gamma->size<loco::DataType::FLOAT32>(channel_size); + for (uint32_t i = 0; i < channel_size; i++) + _mul_gamma->at<loco::DataType::FLOAT32>(i) = i; + + _mul->name("mul"); + _mul_gamma->name("mul_gamma"); + } + +public: + luci::CircleMul *mul(void) { return _mul; } + +protected: + luci::CircleMul *_mul = nullptr; + luci::CircleConst *_mul_gamma = nullptr; +}; + +/** + * @brief Graph of Mul-Add pattern from BatchNorm + */ +class MulAddGraph : public TestIOGraph, public AddBetaGraphlet, public MulGammaGraphlet +{ +public: + MulAddGraph() = default; + + void init(const ShapeU32 shape_in, const ShapeU32 shape_out) + { + TestIOGraph::init(shape_in, shape_out); + MulGammaGraphlet::init(g(), shape_in, luci::FusedActFunc::NONE); + AddBetaGraphlet::init(g(), shape_out, luci::FusedActFunc::RELU); + + // connect network + _mul->x(input()); + _mul->y(_mul_gamma); + _add->x(_mul); + _add->y(_add_beta); + output()->from(_add); + } +}; + +/** + * @brief Graph of Add with Const + */ +class AddGraph : public TestIOGraph, public AddBetaGraphlet +{ +public: + AddGraph() = default; + + void init(const ShapeU32 shape_in, const ShapeU32 shape_out) + { + TestIOGraph::init(shape_in, shape_out); + AddBetaGraphlet::init(g(), shape_in, luci::FusedActFunc::RELU); + + // connect network + _add->x(input()); + _add->y(_add_beta); + output()->from(_add); + } +}; + +} // namespace test +} // namespace luci + +class BatchNormPatternFinderMulAddTest : public ::testing::Test +{ +public: + BatchNormPatternFinderMulAddTest() = default; + +protected: + luci::test::MulAddGraph _mag; +}; + +class BatchNormPatternFinderAddTest : public ::testing::Test +{ +public: + BatchNormPatternFinderAddTest() = default; + +protected: + luci::test::AddGraph _ag; +}; + +TEST_F(BatchNormPatternFinderMulAddTest, is_batchnorm_add) +{ + _mag.init({1, 16, 16, 4}, {1, 16, 16, 4}); + + luci::CircleMul *mul = nullptr; + luci::CircleConst *beta = nullptr; + + auto res = luci::is_batchnorm_add(_mag.add(), mul, beta); + ASSERT_TRUE(res); + ASSERT_NE(nullptr, mul); + ASSERT_NE(nullptr, beta); +} + +TEST_F(BatchNormPatternFinderMulAddTest, is_batchnorm_add2) +{ + _mag.init({1, 16, 16, 4}, {1, 16, 16, 4}); + + auto res = luci::is_batchnorm_add(_mag.add()); + ASSERT_TRUE(res); +} + +TEST_F(BatchNormPatternFinderAddTest, is_batchnorm_add_NEG) +{ + _ag.init({1, 16, 16, 4}, {1, 16, 16, 4}); + + luci::CircleMul *mul = nullptr; + luci::CircleConst *beta = nullptr; + + auto res = luci::is_batchnorm_add(_ag.add(), mul, beta); + ASSERT_FALSE(res); +} + +TEST_F(BatchNormPatternFinderMulAddTest, is_batchnorm_mul) +{ + _mag.init({1, 16, 16, 4}, {1, 16, 16, 4}); + + luci::CircleNode *pred = nullptr; + luci::CircleConst *gamma = nullptr; + + auto res = luci::is_batchnorm_mul(_mag.mul(), pred, gamma); + ASSERT_TRUE(res); + ASSERT_NE(nullptr, pred); + ASSERT_NE(nullptr, gamma); +} diff --git a/compiler/luci/pass/src/CircleOptimizer.cpp b/compiler/luci/pass/src/CircleOptimizer.cpp index cc9fe481c..bddad34fa 100644 --- a/compiler/luci/pass/src/CircleOptimizer.cpp +++ b/compiler/luci/pass/src/CircleOptimizer.cpp @@ -16,16 +16,28 @@ #include "luci/CircleOptimizer.h" +#include "luci/Pass/ConvertNCHWToNHWCPass.h" +#include "luci/Pass/FoldAddV2Pass.h" +#include "luci/Pass/FoldCastPass.h" #include "luci/Pass/FoldDequantizePass.h" +#include "luci/Pass/FoldSparseToDensePass.h" +#include "luci/Pass/ForwardReshapeToUnaryOpPass.h" #include "luci/Pass/FuseActivationFunctionPass.h" #include "luci/Pass/FuseAddWithTConvPass.h" -#include "luci/Pass/FuseBatchNormWithTConv.h" +#include "luci/Pass/FuseBatchNormWithConvPass.h" +#include "luci/Pass/FuseBatchNormWithDwConvPass.h" +#include "luci/Pass/FuseBatchNormWithTConvPass.h" #include "luci/Pass/FuseBCQPass.h" #include "luci/Pass/FuseInstanceNormPass.h" #include "luci/Pass/FusePreActivationBatchNormPass.h" #include "luci/Pass/MakeBatchNormGammaPositivePass.h" #include "luci/Pass/PropagateQuantParamPass.h" +#include "luci/Pass/RemoveRedundantReshapePass.h" #include "luci/Pass/RemoveRedundantTransposePass.h" +#include "luci/Pass/RemoveUnnecessaryReshapePass.h" +#include "luci/Pass/RemoveUnnecessarySlicePass.h" +#include "luci/Pass/RemoveUnnecessaryStridedSlicePass.h" +#include "luci/Pass/RemoveUnnecessarySplitPass.h" #include "luci/Pass/ReplaceMulAddWithDepthwiseConvPass.h" #include "luci/Pass/ResolveCustomOpAddPass.h" #include "luci/Pass/ResolveCustomOpBatchMatMulPass.h" @@ -36,21 +48,22 @@ #include "luci/Pass/SparsifyTensorPass.h" #include "luci/Pass/ShuffleWeightTo16x1Float32Pass.h" #include "luci/Pass/SubstitutePackToReshapePass.h" +#include "luci/Pass/SubstituteSqueezeToReshapePass.h" +#include "luci/Pass/SubstituteTransposeToReshapePass.h" +#include "luci/Pass/TransformMinMaxToRelu6Pass.h" // TODO add more passes -#include "luci/Pass/ShapeInferencePass.h" -#include "luci/Pass/ShapeSignatureInferencePass.h" -#include "luci/Pass/TypeInferencePass.h" - -// Following passes will be removed after refactoring is finished -#include "luci/Pass/MigrateLegacyShapeDtypePass.h" +#include "luci/Pass/CircleShapeInferencePass.h" +#include "luci/Pass/CircleTypeInferencePass.h" // logo passes #include <logo/RemoveDeadNodeWithQueryPass.h> #include "ModulePhase.h" #include "ProgressReporter.h" -#include "CircleOptimizerUtils.h" +#include "helpers/Strings.h" + +#include "QuantizedModelVerifier.h" #include <luci/IR/CircleNodes.h> #include <logo/Phase.h> @@ -61,20 +74,6 @@ namespace { -std::vector<int> parseIntFromCommadelimitedStr(std::string str) -{ - std::vector<int> ret; - std::istringstream is(str); - for (uint32_t i; is >> i;) - { - assert(i != ','); - ret.push_back(i); - if (is.peek() == ',') - is.ignore(); - } - return ret; -} - using namespace luci; class OptimizeOptionsImpl final : public luci::CircleOptimizer::Options @@ -138,13 +137,9 @@ void CircleOptimizer::optimize(luci::Module *m) const { luci::Phase phase; - // Following passes will be deprecated after refactoring is finished. - phase.emplace_back(std::make_unique<luci::MigrateLegacyShapeDtypePass>()); - // Following passes are needed everytime when other passes create new node or modify some nodes. - phase.emplace_back(std::make_unique<luci::ShapeInferencePass>()); - phase.emplace_back(std::make_unique<luci::ShapeSignatureInferencePass>()); - phase.emplace_back(std::make_unique<luci::TypeInferencePass>()); + phase.emplace_back(std::make_unique<luci::CircleShapeInferencePass>()); + phase.emplace_back(std::make_unique<luci::CircleTypeInferencePass>()); if (_options->query(Options::Algorithm::FuseBCQ)) { @@ -164,13 +159,9 @@ void CircleOptimizer::optimize(loco::Graph *g) const /* TRANSFORM DECLARATION BEGIN */ phase.emplace_back(std::make_unique<logo::RemoveDeadNodeWithQueryPass>()); - // Following passes will be deprecated after refactoring is finished. - phase.emplace_back(std::make_unique<luci::MigrateLegacyShapeDtypePass>()); - // Following passes are needed everytime when other passes create new node or modify some nodes. - phase.emplace_back(std::make_unique<luci::TypeInferencePass>()); - phase.emplace_back(std::make_unique<luci::ShapeInferencePass>()); - phase.emplace_back(std::make_unique<luci::ShapeSignatureInferencePass>()); + phase.emplace_back(std::make_unique<luci::CircleShapeInferencePass>()); + phase.emplace_back(std::make_unique<luci::CircleTypeInferencePass>()); if (_options->query(Options::Algorithm::ResolveCustomOpAdd)) { @@ -188,6 +179,14 @@ void CircleOptimizer::optimize(loco::Graph *g) const { phase.emplace_back(std::make_unique<FuseInstanceNormPass>()); } + if (_options->query(Options::Algorithm::FuseBatchNormWithConv)) + { + phase.emplace_back(std::make_unique<FuseBatchNormWithConvPass>()); + } + if (_options->query(Options::Algorithm::FuseBatchNormWithDwConv)) + { + phase.emplace_back(std::make_unique<FuseBatchNormWithDwConvPass>()); + } if (_options->query(Options::Algorithm::FuseBatchNormWithTConv)) { phase.emplace_back(std::make_unique<FuseBatchNormWithTConvPass>()); @@ -200,10 +199,26 @@ void CircleOptimizer::optimize(loco::Graph *g) const { phase.emplace_back(std::make_unique<FuseActivationFunctionPass>()); } + if (_options->query(Options::Algorithm::FoldAddV2)) + { + phase.emplace_back(std::make_unique<luci::FoldAddV2Pass>()); + } + if (_options->query(Options::Algorithm::FoldCast)) + { + phase.emplace_back(std::make_unique<luci::FoldCastPass>()); + } if (_options->query(Options::Algorithm::FoldDequantize)) { phase.emplace_back(std::make_unique<luci::FoldDequantizePass>()); } + if (_options->query(Options::Algorithm::FoldSparseToDense)) + { + phase.emplace_back(std::make_unique<luci::FoldSparseToDensePass>()); + } + if (_options->query(Options::Algorithm::ForwardReshapeToUnaryOp)) + { + phase.emplace_back(std::make_unique<luci::ForwardReshapeToUnaryOpPass>()); + } if (_options->query(Options::Algorithm::FusePreActivationBatchNorm)) { phase.emplace_back(std::make_unique<luci::FusePreActivationBatchNormPass>()); @@ -216,6 +231,26 @@ void CircleOptimizer::optimize(loco::Graph *g) const { phase.emplace_back(std::make_unique<luci::ShuffleWeightTo16x1Float32Pass>()); } + if (_options->query(Options::Algorithm::RemoveUnnecessaryReshape)) + { + phase.emplace_back(std::make_unique<luci::RemoveUnnecessaryReshapePass>()); + } + if (_options->query(Options::Algorithm::RemoveUnnecessarySlice)) + { + phase.emplace_back(std::make_unique<luci::RemoveUnnecessarySlicePass>()); + } + if (_options->query(Options::Algorithm::RemoveUnnecessaryStridedSlice)) + { + phase.emplace_back(std::make_unique<luci::RemoveUnnecessaryStridedSlicePass>()); + } + if (_options->query(Options::Algorithm::RemoveUnnecessarySplit)) + { + phase.emplace_back(std::make_unique<luci::RemoveUnnecessarySplitPass>()); + } + if (_options->query(Options::Algorithm::RemoveRedundantReshape)) + { + phase.emplace_back(std::make_unique<luci::RemoveRedundantReshapePass>()); + } if (_options->query(Options::Algorithm::RemoveRedundantTranspose)) { phase.emplace_back(std::make_unique<luci::RemoveRedundantTransposePass>()); @@ -228,6 +263,28 @@ void CircleOptimizer::optimize(loco::Graph *g) const { phase.emplace_back(std::make_unique<luci::SubstitutePackToReshapePass>()); } + if (_options->query(Options::Algorithm::SubstituteSqueezeToReshape)) + { + phase.emplace_back(std::make_unique<luci::SubstituteSqueezeToReshapePass>()); + } + if (_options->query(Options::Algorithm::SubstituteTransposeToReshape)) + { + phase.emplace_back(std::make_unique<luci::SubstituteTransposeToReshapePass>()); + } + if (_options->query(Options::Algorithm::TransformMinMaxToRelu6Pass)) + { + phase.emplace_back(std::make_unique<luci::TransformMinMaxToRelu6Pass>()); + } + if (_options->query(Options::Algorithm::ConvertNCHWToNHWC)) + { + bool preserve_input = + _options->param(Options::AlgorithmParameters::NCHW_to_NHWC_preserve_input_shape) == "true"; + bool preserve_output = + _options->param(Options::AlgorithmParameters::NCHW_to_NHWC_preserve_output_shape) == "true"; + + phase.emplace_back( + std::make_unique<luci::ConvertNCHWToNHWCPass>(preserve_input, preserve_output)); + } /* TRANSFORM DECLARATION END */ @@ -275,7 +332,7 @@ void CircleOptimizer::quantize(loco::Graph *g) const } luci::QuantizeDequantizeWeightsPass fake_quantizer( - str_to_dtype(input_dtype), str_to_dtype(output_dtype), str_to_granularity(granularity)); + str_to_dtype(input_dtype), str_to_dtype(output_dtype), str_to_granularity(granularity)); fake_quantizer.run(g); } @@ -315,14 +372,19 @@ void CircleOptimizer::quantize(loco::Graph *g) const phase.emplace_back(std::make_unique<luci::PropagateQuantParamPass>()); - phase.emplace_back(std::make_unique<luci::ShapeInferencePass>()); - phase.emplace_back(std::make_unique<luci::TypeInferencePass>()); + phase.emplace_back(std::make_unique<luci::CircleShapeInferencePass>()); + phase.emplace_back(std::make_unique<luci::CircleTypeInferencePass>()); phase.emplace_back(std::make_unique<logo::RemoveDeadNodeWithQueryPass>()); ProgressReporter prog(g, logo::PhaseStrategy::Saturate); logo::PhaseRunner<logo::PhaseStrategy::Saturate> phase_runner{g}; phase_runner.attach(&prog); phase_runner.run(phase); + + // Verify the type/granularity of the quantized model + luci::QuantizedModelVerifier verifier(str_to_dtype(output_dtype), + str_to_granularity(granularity)); + verifier.verify(g); } // Requantize @@ -349,8 +411,8 @@ void CircleOptimizer::quantize(loco::Graph *g) const logo::Phase phase; // Do Shape/Type inference - phase.emplace_back(std::make_unique<luci::ShapeInferencePass>()); - phase.emplace_back(std::make_unique<luci::TypeInferencePass>()); + phase.emplace_back(std::make_unique<luci::CircleShapeInferencePass>()); + phase.emplace_back(std::make_unique<luci::CircleTypeInferencePass>()); ProgressReporter prog(g, logo::PhaseStrategy::Saturate); logo::PhaseRunner<logo::PhaseStrategy::Saturate> phase_runner{g}; @@ -364,13 +426,13 @@ void CircleOptimizer::sparsify(loco::Graph *g) const { std::string tensor_name = _options->param(Options::AlgorithmParameters::Sparsify_tensor_name); std::string str_tarversal_order = - _options->param(Options::AlgorithmParameters::Sparsify_traversal_order); + _options->param(Options::AlgorithmParameters::Sparsify_traversal_order); std::string str_format = _options->param(Options::AlgorithmParameters::Sparsify_format); std::string str_block_size = _options->param(Options::AlgorithmParameters::Sparsify_block_size); std::string str_block_map = _options->param(Options::AlgorithmParameters::Sparsify_block_map); // traversal order - std::vector<int32_t> traversal_order = parseIntFromCommadelimitedStr(str_tarversal_order); + std::vector<int32_t> traversal_order = csv_to_vector<int32_t>(str_tarversal_order); // format std::vector<DimensionType> format; std::istringstream is(str_format); @@ -385,9 +447,9 @@ void CircleOptimizer::sparsify(loco::Graph *g) const is.ignore(); } // block size - std::vector<int32_t> block_size = parseIntFromCommadelimitedStr(str_block_size); + std::vector<int32_t> block_size = csv_to_vector<int32_t>(str_block_size); // block map - std::vector<int32_t> block_map = parseIntFromCommadelimitedStr(str_block_map); + std::vector<int32_t> block_map = csv_to_vector<int32_t>(str_block_map); luci::SparsifyTensorPass sparsifier{tensor_name, traversal_order, format, block_size, block_map}; diff --git a/compiler/luci/pass/src/CircleOptimizer.test.cpp b/compiler/luci/pass/src/CircleOptimizer.test.cpp new file mode 100644 index 000000000..ca6dc77f3 --- /dev/null +++ b/compiler/luci/pass/src/CircleOptimizer.test.cpp @@ -0,0 +1,238 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/CircleOptimizer.h" + +#include <gtest/gtest.h> + +using namespace luci; +using Algorithms = luci::CircleOptimizer::Options::Algorithm; +using AlgorithmParameters = luci::CircleOptimizer::Options::AlgorithmParameters; + +TEST(CircleOptimizerTest, optimize_algorithms) +{ + loco::Graph g; + luci::CircleOptimizer o; + + auto options = o.options(); + + // NOTE these are added to cover the test + // TODO add more if needed + options->enable(Algorithms::FoldAddV2); + options->enable(Algorithms::FoldCast); + options->enable(Algorithms::FoldDequantize); + options->enable(Algorithms::FoldSparseToDense); + options->enable(Algorithms::FusePreActivationBatchNorm); + options->enable(Algorithms::MakeBatchNormGammaPositive); + options->enable(Algorithms::ShuffleWeightTo16x1Float32); + options->enable(Algorithms::RemoveUnnecessaryReshape); + options->enable(Algorithms::RemoveUnnecessarySlice); + options->enable(Algorithms::RemoveUnnecessarySplit); + options->enable(Algorithms::ReplaceMulAddWithDepthwiseConv); + options->enable(Algorithms::SubstituteTransposeToReshape); + options->enable(Algorithms::ConvertNCHWToNHWC); + + o.optimize(&g); + + SUCCEED(); +} + +TEST(CircleOptimizerTest, sparsify_simple) +{ + loco::Graph g; + luci::CircleOptimizer o; + + auto options = o.options(); + + options->enable(Algorithms::SparsifyTensorPass); + options->param(AlgorithmParameters::Sparsify_tensor_name, "dummy"); + options->param(AlgorithmParameters::Sparsify_traversal_order, "dummy"); + options->param(AlgorithmParameters::Sparsify_format, "ds"); + options->param(AlgorithmParameters::Sparsify_block_size, "1,1"); + options->param(AlgorithmParameters::Sparsify_block_map, "1,1"); + + o.sparsify(&g); + + SUCCEED(); +} + +TEST(CircleOptimizerTest, quantize_quantdequant_simple) +{ + loco::Graph g; + luci::CircleOptimizer o; + + auto options = o.options(); + + options->enable(Algorithms::QuantizeDequantizeWeights); + options->param(AlgorithmParameters::Quantize_input_dtype, "float32"); + options->param(AlgorithmParameters::Quantize_output_dtype, "uint8"); + options->param(AlgorithmParameters::Quantize_granularity, "layer"); + + o.quantize(&g); + + SUCCEED(); +} + +TEST(CircleOptimizerTest, quantize_quantdequant_input_NEG) +{ + loco::Graph g; + luci::CircleOptimizer o; + + auto options = o.options(); + + options->enable(Algorithms::QuantizeDequantizeWeights); + options->param(AlgorithmParameters::Quantize_input_dtype, "invalid"); + options->param(AlgorithmParameters::Quantize_output_dtype, "uint8"); + options->param(AlgorithmParameters::Quantize_granularity, "layer"); + + EXPECT_THROW(o.quantize(&g), std::runtime_error); +} + +TEST(CircleOptimizerTest, quantize_quantdequant_output_NEG) +{ + loco::Graph g; + luci::CircleOptimizer o; + + auto options = o.options(); + + options->enable(Algorithms::QuantizeDequantizeWeights); + options->param(AlgorithmParameters::Quantize_input_dtype, "float32"); + options->param(AlgorithmParameters::Quantize_output_dtype, "invalid"); + options->param(AlgorithmParameters::Quantize_granularity, "layer"); + + EXPECT_THROW(o.quantize(&g), std::runtime_error); +} + +TEST(CircleOptimizerTest, quantize_quantdequant_gran_NEG) +{ + loco::Graph g; + luci::CircleOptimizer o; + + auto options = o.options(); + + options->enable(Algorithms::QuantizeDequantizeWeights); + options->param(AlgorithmParameters::Quantize_input_dtype, "float32"); + options->param(AlgorithmParameters::Quantize_output_dtype, "uint8"); + options->param(AlgorithmParameters::Quantize_granularity, "invalid"); + + EXPECT_THROW(o.quantize(&g), std::runtime_error); +} + +TEST(CircleOptimizerTest, quantize_minmax_simple) +{ + loco::Graph g; + luci::CircleOptimizer o; + + auto options = o.options(); + + options->enable(Algorithms::QuantizeWithMinMax); + options->param(AlgorithmParameters::Quantize_input_dtype, "float32"); + options->param(AlgorithmParameters::Quantize_output_dtype, "uint8"); + options->param(AlgorithmParameters::Quantize_granularity, "layer"); + + o.quantize(&g); + + SUCCEED(); +} + +TEST(CircleOptimizerTest, quantize_minmax_input_NEG) +{ + loco::Graph g; + luci::CircleOptimizer o; + + auto options = o.options(); + + options->enable(Algorithms::QuantizeWithMinMax); + options->param(AlgorithmParameters::Quantize_input_dtype, "invalid"); + options->param(AlgorithmParameters::Quantize_output_dtype, "uint8"); + options->param(AlgorithmParameters::Quantize_granularity, "layer"); + + EXPECT_THROW(o.quantize(&g), std::runtime_error); +} + +TEST(CircleOptimizerTest, quantize_minmax_output_NEG) +{ + loco::Graph g; + luci::CircleOptimizer o; + + auto options = o.options(); + + options->enable(Algorithms::QuantizeWithMinMax); + options->param(AlgorithmParameters::Quantize_input_dtype, "float32"); + options->param(AlgorithmParameters::Quantize_output_dtype, "invalid"); + options->param(AlgorithmParameters::Quantize_granularity, "layer"); + + EXPECT_THROW(o.quantize(&g), std::runtime_error); +} + +TEST(CircleOptimizerTest, quantize_minmax_gran_NEG) +{ + loco::Graph g; + luci::CircleOptimizer o; + + auto options = o.options(); + + options->enable(Algorithms::QuantizeWithMinMax); + options->param(AlgorithmParameters::Quantize_input_dtype, "float32"); + options->param(AlgorithmParameters::Quantize_output_dtype, "uint8"); + options->param(AlgorithmParameters::Quantize_granularity, "invalid"); + + EXPECT_THROW(o.quantize(&g), std::runtime_error); +} + +TEST(CircleOptimizerTest, quantize_requant_simple) +{ + loco::Graph g; + luci::CircleOptimizer o; + + auto options = o.options(); + + options->enable(Algorithms::Requantize); + options->param(AlgorithmParameters::Quantize_input_dtype, "int8"); + options->param(AlgorithmParameters::Quantize_output_dtype, "uint8"); + + o.quantize(&g); + + SUCCEED(); +} + +TEST(CircleOptimizerTest, quantize_requant_input_NEG) +{ + loco::Graph g; + luci::CircleOptimizer o; + + auto options = o.options(); + + options->enable(Algorithms::Requantize); + options->param(AlgorithmParameters::Quantize_input_dtype, "invalid"); + options->param(AlgorithmParameters::Quantize_output_dtype, "uint8"); + + EXPECT_THROW(o.quantize(&g), std::runtime_error); +} + +TEST(CircleOptimizerTest, quantize_requant_output_NEG) +{ + loco::Graph g; + luci::CircleOptimizer o; + + auto options = o.options(); + + options->enable(Algorithms::Requantize); + options->param(AlgorithmParameters::Quantize_input_dtype, "int8"); + options->param(AlgorithmParameters::Quantize_output_dtype, "invalid"); + + EXPECT_THROW(o.quantize(&g), std::runtime_error); +} diff --git a/compiler/luci/pass/src/CircleOptimizerUtils.cpp b/compiler/luci/pass/src/CircleOptimizerUtils.cpp index ffc372392..127573db4 100644 --- a/compiler/luci/pass/src/CircleOptimizerUtils.cpp +++ b/compiler/luci/pass/src/CircleOptimizerUtils.cpp @@ -16,74 +16,18 @@ #include "CircleOptimizerUtils.h" -namespace luci -{ - -bool in_array(const std::string &str, const std::vector<std::string> &array) -{ - return std::find(array.begin(), array.end(), str) != array.end(); -} +#include <luci/IR/CircleNode.h> -std::string to_string(const std::vector<std::string> &strings) -{ - assert(!strings.empty()); - - std::string res; - for (unsigned int i = 0; i < strings.size() - 1; i++) - res += strings[i] + ", "; - - res += strings[strings.size() - 1]; - return res; -} - -std::string to_lower_case(std::string s) -{ - std::transform(s.begin(), s.end(), s.begin(), [](unsigned char c) { return std::tolower(c); }); - return s; -} - -loco::DataType str_to_dtype(const std::string &str) +namespace luci { - if (to_lower_case(str).compare("uint8") == 0) - return loco::DataType::U8; - if (to_lower_case(str).compare("uint16") == 0) - return loco::DataType::U16; - if (to_lower_case(str).compare("uint32") == 0) - return loco::DataType::U32; - if (to_lower_case(str).compare("uint64") == 0) - return loco::DataType::U64; - - if (to_lower_case(str).compare("int8") == 0) - return loco::DataType::S8; - if (to_lower_case(str).compare("int16") == 0) - return loco::DataType::S16; - if (to_lower_case(str).compare("int32") == 0) - return loco::DataType::S32; - if (to_lower_case(str).compare("int64") == 0) - return loco::DataType::S64; - - if (to_lower_case(str).compare("float16") == 0) - return loco::DataType::FLOAT16; - if (to_lower_case(str).compare("float32") == 0) - return loco::DataType::FLOAT32; - if (to_lower_case(str).compare("float64") == 0) - return loco::DataType::FLOAT64; - if (to_lower_case(str).compare("bool") == 0) - return loco::DataType::BOOL; - - return loco::DataType::Unknown; -} - -QuantizationGranularity str_to_granularity(const std::string &str) +bool has_dynamic_shape(const loco::Node *node) { - if (to_lower_case(str).compare("layer") == 0) - return QuantizationGranularity::LayerWise; - - if (to_lower_case(str).compare("channel") == 0) - return QuantizationGranularity::ChannelWise; - - throw std::runtime_error("Quantization granularity must be either 'layer' or 'channel'"); + const auto circle_node = loco::must_cast<const luci::CircleNode *>(node); + for (uint32_t i = 0; i < circle_node->rank(); ++i) + if (!circle_node->dim(i).known()) + return true; + return false; } } // namespace luci diff --git a/compiler/luci/pass/src/CircleOptimizerUtils.h b/compiler/luci/pass/src/CircleOptimizerUtils.h index 7e577a05f..e04942bfa 100644 --- a/compiler/luci/pass/src/CircleOptimizerUtils.h +++ b/compiler/luci/pass/src/CircleOptimizerUtils.h @@ -17,25 +17,12 @@ #ifndef __LUCI_CIRCLE_OPTIMIZER_UTILS_H__ #define __LUCI_CIRCLE_OPTIMIZER_UTILS_H__ -#include "luci/Pass/QuantizeDequantizeWeightsPass.h" -#include "luci/Pass/QuantizeWithMinMaxPass.h" - #include <loco.h> -#include <algorithm> - namespace luci { -bool in_array(const std::string &, const std::vector<std::string> &); - -std::string to_string(const std::vector<std::string> &); - -std::string to_lower_case(std::string); - -loco::DataType str_to_dtype(const std::string &); - -QuantizationGranularity str_to_granularity(const std::string &); +bool has_dynamic_shape(const loco::Node *node); } // namespace luci diff --git a/compiler/luci/pass/src/CircleShapeInferencePass.cpp b/compiler/luci/pass/src/CircleShapeInferencePass.cpp new file mode 100644 index 000000000..ddab22421 --- /dev/null +++ b/compiler/luci/pass/src/CircleShapeInferencePass.cpp @@ -0,0 +1,91 @@ +/* + * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "helpers/InferenceCandidates.h" + +#include "luci/Pass/CircleShapeInferencePass.h" + +#include <luci/Service/CircleShapeInference.h> + +#include <loco.h> + +namespace +{ + +bool is_same_shape(luci::CircleNode *node, loco::TensorShape shape) +{ + if (node->shape_status() != luci::ShapeStatus::VALID) + return false; + + if (node->rank() != shape.rank()) + return false; + + for (uint32_t i = 0; i < node->rank(); ++i) + { + if (node->dim(i).known() != shape.dim(i).known()) + return false; + + if (node->dim(i).value() != shape.dim(i).value()) + return false; + } + + return true; +} + +} // namespace + +namespace luci +{ + +bool CircleShapeInferencePass::run(luci::Module *m) +{ + bool changed = false; + + for (size_t g = 0; g < m->size(); ++g) + { + if (run(m->graph(g))) + changed = true; + } + + return changed; +} + +bool CircleShapeInferencePass::run(loco::Graph *g) +{ + luci::sinf::Rule shape_infer_rule; + bool changed = false; + + for (auto node : inference_candidates(g)) + { + loco::TensorShape shape; + auto circle_node = loco::must_cast<luci::CircleNode *>(node); + + if (shape_infer_rule.infer(circle_node, shape) && !is_same_shape(circle_node, shape)) + { + circle_node->rank(shape.rank()); + for (uint32_t i = 0; i < shape.rank(); ++i) + circle_node->dim(i) = shape.dim(i); + + circle_node->shape_status(luci::ShapeStatus::VALID); + + changed = true; + } + } + + return changed; +} + +} // namespace luci diff --git a/compiler/luci/pass/src/CircleShapeInferencePass.test.cpp b/compiler/luci/pass/src/CircleShapeInferencePass.test.cpp new file mode 100644 index 000000000..cb3f1fe5f --- /dev/null +++ b/compiler/luci/pass/src/CircleShapeInferencePass.test.cpp @@ -0,0 +1,364 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Pass/CircleShapeInferencePass.h" + +#include <loco.h> + +#include <luci/IR/CircleNodes.h> + +#include <gtest/gtest.h> + +TEST(CircleShapeInferencePassTest, name) +{ + luci::CircleShapeInferencePass pass; + auto const name = pass.name(); + ASSERT_NE(nullptr, name); +} + +/** + * This test is to check whether shape inference is done by topological order. + * + * When perm() of "transpose1" is changed from "old_perm" to "new_perm" + * by some of luci/Pass like below diagram, shape_status of "transpose1" is + * still VALID even the shape should be changed. + * If "transpose2" is visited first before shape of "transpose1" is updated, + * "transpose2" can reference the shape of "relu" which is not updated yet. + * Then shape of "transpose2" becomes 3x5x5x1 and it causes an error at "conv2d". + * + * <Initial graph> + * 4x1x1x3 + * [old_perm] ----------+ [filter] ----------+ + * (0,2,1,3) | | + * | [bias] ----------+ + * | | + * input ------> [transpose1] ------> [relu] ------> [conv2d] ------> output + * 1x5x5x3 1x5x5x3 1x5x5x3 1x5x5x4 + * + * + * <Right after transformation> + * 4x1x1x3 + * [new_perm] ----------+-----------------------------------+ [filter] ------+ + * (3,2,1,0) | | | + * | | [bias] ------+ + * | | | + * input ------> [transpose1] ------> [relu] ------> [transpose2] ------> [conv2d] ------> output + * 1x5x5x3 1x5x5x3 1x5x5x3 ? 1x5x5x4 + * + * + * <Expected result> + * 4x1x1x3 + * [new_perm] ----------+-----------------------------------+ [filter] ------+ + * (3,2,1,0) | | | + * | | [bias] ------+ + * | | | + * input ------> [transpose1] ------> [relu] ------> [transpose2] ------> [conv2d] ------> output + * 1x5x5x3 3x5x5x1 3x5x5x1 1x5x5x3 1x5x5x4 + * + */ +TEST(CircleShapeInferencePassTest, original_node_change) +{ + luci::CircleShapeInferencePass pass; + auto g = loco::make_graph(); + + // Have to be packed into lambda to check throw + auto shape_inference_run = [&]() { + while (pass.run(g.get()) == true) + ; + }; + + // Create nodes to make relu traversed first + auto input = g->nodes()->create<luci::CircleInput>(); + auto relu = g->nodes()->create<luci::CircleRelu>(); + auto old_perm = g->nodes()->create<luci::CircleConst>(); + auto transpose1 = g->nodes()->create<luci::CircleTranspose>(); + auto filter = g->nodes()->create<luci::CircleConst>(); + auto bias = g->nodes()->create<luci::CircleConst>(); + auto conv2d = g->nodes()->create<luci::CircleConv2D>(); + auto output = g->nodes()->create<luci::CircleOutput>(); + auto new_perm = g->nodes()->create<luci::CircleConst>(); + auto transpose2 = g->nodes()->create<luci::CircleTranspose>(); + + // Build up initial graph + auto graph_input = g->inputs()->create(); + graph_input->shape({1, 5, 5, 3}); + + input->index(graph_input->index()); + input->shape({1, 5, 5, 3}); + input->shape_status(luci::ShapeStatus::VALID); + + old_perm->dtype(loco::DataType::S32); + old_perm->size<loco::DataType::S32>(4); + old_perm->shape({4}); + old_perm->at<loco::DataType::S32>(0) = 0; + old_perm->at<loco::DataType::S32>(1) = 2; + old_perm->at<loco::DataType::S32>(2) = 1; + old_perm->at<loco::DataType::S32>(3) = 3; + old_perm->shape_status(luci::ShapeStatus::VALID); + + transpose1->a(input); + transpose1->perm(old_perm); + + relu->features(transpose1); + + filter->dtype(loco::DataType::FLOAT32); + filter->size<loco::DataType::FLOAT32>(4 * 1 * 1 * 3); + filter->shape({4, 1, 1, 3}); + filter->shape_status(luci::ShapeStatus::VALID); + + bias->dtype(loco::DataType::FLOAT32); + bias->size<loco::DataType::FLOAT32>(4); + bias->shape({4}); + bias->shape_status(luci::ShapeStatus::VALID); + + conv2d->input(relu); + conv2d->filter(filter); + conv2d->bias(bias); + conv2d->padding(luci::Padding::VALID); + conv2d->stride()->h(1); + conv2d->stride()->w(1); + conv2d->dilation()->h(1); + conv2d->dilation()->w(1); + + output->from(conv2d); + auto graph_output = g->outputs()->create(); + output->index(graph_output->index()); + graph_output->shape({1, 5, 5, 4}); + + ASSERT_NO_THROW(shape_inference_run()); + + // Transform graph + new_perm->dtype(loco::DataType::S32); + new_perm->size<loco::DataType::S32>(4); + new_perm->shape({4}); + new_perm->at<loco::DataType::S32>(0) = 3; + new_perm->at<loco::DataType::S32>(1) = 2; + new_perm->at<loco::DataType::S32>(2) = 1; + new_perm->at<loco::DataType::S32>(3) = 0; + new_perm->shape_status(luci::ShapeStatus::VALID); + + transpose1->perm(new_perm); + + transpose2->a(relu); + transpose2->perm(new_perm); + + conv2d->input(transpose2); + + ASSERT_NO_THROW(shape_inference_run()); + + // Check result of shape inference is correct + ASSERT_EQ(3, transpose1->dim(0).value()); + ASSERT_EQ(5, transpose1->dim(1).value()); + ASSERT_EQ(5, transpose1->dim(2).value()); + ASSERT_EQ(1, transpose1->dim(3).value()); + + ASSERT_EQ(3, relu->dim(0).value()); + ASSERT_EQ(5, relu->dim(1).value()); + ASSERT_EQ(5, relu->dim(2).value()); + ASSERT_EQ(1, relu->dim(3).value()); + + ASSERT_EQ(1, transpose2->dim(0).value()); + ASSERT_EQ(5, transpose2->dim(1).value()); + ASSERT_EQ(5, transpose2->dim(2).value()); + ASSERT_EQ(3, transpose2->dim(3).value()); + + ASSERT_EQ(1, conv2d->dim(0).value()); + ASSERT_EQ(5, conv2d->dim(1).value()); + ASSERT_EQ(5, conv2d->dim(2).value()); + ASSERT_EQ(4, conv2d->dim(3).value()); + + SUCCEED(); +} + +/** + * This test is for checking when imported shape is wrong. + * + * Even "concat1" has wrong shape at first, correct shape should be inferred. + * + * <Initial graph> + * + * 1x1x1x1 + * input1 ------+ 8x7x6x5 + * +-----> [concat1] ------+ + * input2 ------+ (axis=3) | 1x1x2x3 + * 1x1x1x2 +------> [concat2] ------> output + * | (axis=2) + * 1x1x1x3 | + * input3 ------------------------------+ + * + * + * <Expected result> + * + * 1x1x1x1 + * input1 ------+ 1x1x1x3 + * +-----> [concat1] ------+ + * input2 ------+ (axis=3) | 1x1x2x3 + * 1x1x1x2 +------> [concat2] ------> output + * | (axis=2) + * 1x1x1x3 | + * input3 ------------------------------+ + */ +TEST(CircleShapeInferencePassTest, wrong_imported_shape) +{ + luci::CircleShapeInferencePass pass; + auto g = loco::make_graph(); + + // Have to be packed into lambda to check throw + auto shape_inference_run = [&]() { + while (pass.run(g.get()) == true) + ; + }; + + // Create nodes to make concat2 traversed first + auto concat2 = g->nodes()->create<luci::CircleConcatenation>(2); + auto concat1 = g->nodes()->create<luci::CircleConcatenation>(2); + auto input1 = g->nodes()->create<luci::CircleInput>(); + auto input2 = g->nodes()->create<luci::CircleInput>(); + auto input3 = g->nodes()->create<luci::CircleInput>(); + + // Build up initial graph + auto graph_input1 = g->inputs()->create(); + auto graph_input2 = g->inputs()->create(); + auto graph_input3 = g->inputs()->create(); + graph_input1->shape({1, 1, 1, 1}); + graph_input2->shape({1, 1, 1, 2}); + graph_input2->shape({1, 1, 1, 3}); + + input1->index(graph_input1->index()); + input1->shape({1, 1, 1, 1}); + input1->shape_status(luci::ShapeStatus::VALID); + + input2->index(graph_input2->index()); + input2->shape({1, 1, 1, 2}); + input2->shape_status(luci::ShapeStatus::VALID); + + input3->index(graph_input3->index()); + input3->shape({1, 1, 1, 3}); + input3->shape_status(luci::ShapeStatus::VALID); + + concat1->values(0, input1); + concat1->values(1, input2); + concat1->axis(3); + concat1->shape({8, 7, 6, 5}); // Intentionally set wrong shape + concat1->shape_status(luci::ShapeStatus::VALID); + + concat2->values(0, concat1); + concat2->values(1, input3); + concat2->axis(2); + + auto output = g->nodes()->create<luci::CircleOutput>(); + output->from(concat2); + auto graph_output = g->outputs()->create(); + output->index(graph_output->index()); + graph_output->shape({1, 1, 2, 3}); + + ASSERT_NO_THROW(shape_inference_run()); + + // Check result of shape inference is correct + ASSERT_EQ(1, concat1->dim(0).value()); + ASSERT_EQ(1, concat1->dim(1).value()); + ASSERT_EQ(1, concat1->dim(2).value()); + ASSERT_EQ(3, concat1->dim(3).value()); + + ASSERT_EQ(1, concat2->dim(0).value()); + ASSERT_EQ(1, concat2->dim(1).value()); + ASSERT_EQ(2, concat2->dim(2).value()); + ASSERT_EQ(3, concat2->dim(3).value()); + + SUCCEED(); +} + +/** + * This test is for checking that virtual operations which is not used for graph output + * but shape should be exported. + * + * Although "split_out2" is not used for graph output, shape should be inferenced. + * + * <Initial graph> + * + * + * 1x6 +----> [split_out1] ----> output + * input ------> [split] -----+ + * (split_dim=1) +----> [split_out2] + * (num_split=2) + * + * + * <Expected result> + * 1x3 1x3 + * 1x6 +----> [split_out1] ----> output + * input ------> [split] -----+ + * (split_dim=1) +----> [split_out2] + * (num_split=2) 1x3 + */ +TEST(CircleShapeInferencePassTest, not_used_virtual_op) +{ + luci::CircleShapeInferencePass pass; + auto g = loco::make_graph(); + + // Have to be packed into lambda to check throw + auto shape_inference_run = [&]() { + while (pass.run(g.get()) == true) + ; + }; + + // Create nodes + auto input = g->nodes()->create<luci::CircleInput>(); + auto split = g->nodes()->create<luci::CircleSplit>(); + auto split_out1 = g->nodes()->create<luci::CircleSplitOut>(); + auto split_out2 = g->nodes()->create<luci::CircleSplitOut>(); + auto split_dim = g->nodes()->create<luci::CircleConst>(); + + // Build up initial graph + auto graph_input1 = g->inputs()->create(); + graph_input1->shape({1, 6}); + + input->index(graph_input1->index()); + input->shape({1, 6}); + input->shape_status(luci::ShapeStatus::VALID); + + split_dim->dtype(loco::DataType::S32); + split_dim->size<loco::DataType::S32>(1); + split_dim->shape({1}); + split_dim->at<loco::DataType::S32>(0) = 1; + split_dim->shape_status(luci::ShapeStatus::VALID); + + split->split_dim(split_dim); + split->input(input); + split->num_split(2); + + split_out1->input(split); + split_out1->index(0); + + split_out2->input(split); + split_out2->index(1); + + auto output = g->nodes()->create<luci::CircleOutput>(); + output->from(split_out1); + auto graph_output = g->outputs()->create(); + output->index(graph_output->index()); + graph_output->shape({1, 3}); + + ASSERT_NO_THROW(shape_inference_run()); + + // Check result of shape inference is correct + ASSERT_EQ(1, split_out1->dim(0).value()); + ASSERT_EQ(3, split_out1->dim(1).value()); + + ASSERT_EQ(1, split_out2->dim(0).value()); + ASSERT_EQ(3, split_out2->dim(1).value()); + + SUCCEED(); +} diff --git a/compiler/luci/pass/src/CircleTypeInferencePass.cpp b/compiler/luci/pass/src/CircleTypeInferencePass.cpp index 67bd253e0..fb3755ffa 100644 --- a/compiler/luci/pass/src/CircleTypeInferencePass.cpp +++ b/compiler/luci/pass/src/CircleTypeInferencePass.cpp @@ -14,6 +14,8 @@ * limitations under the License. */ +#include "helpers/InferenceCandidates.h" + #include "luci/Pass/CircleTypeInferencePass.h" #include <luci/Service/CircleTypeInference.h> @@ -41,7 +43,7 @@ bool CircleTypeInferencePass::run(loco::Graph *g) luci::tinf::Rule type_infer_rule; bool changed = false; - for (auto node : loco::postorder_traversal(loco::output_nodes(g))) + for (auto node : inference_candidates(g)) { loco::DataType dtype; auto circle_node = loco::must_cast<luci::CircleNode *>(node); diff --git a/compiler/luci/pass/src/CircleTypeInferencePass.test.cpp b/compiler/luci/pass/src/CircleTypeInferencePass.test.cpp new file mode 100644 index 000000000..415424a6f --- /dev/null +++ b/compiler/luci/pass/src/CircleTypeInferencePass.test.cpp @@ -0,0 +1,26 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Pass/CircleTypeInferencePass.h" + +#include <gtest/gtest.h> + +TEST(CircleTypeInferencePassTest, name) +{ + luci::CircleTypeInferencePass pass; + auto const name = pass.name(); + ASSERT_NE(nullptr, name); +} diff --git a/compiler/luci/pass/src/ConvertNCHWToNHWCPass.cpp b/compiler/luci/pass/src/ConvertNCHWToNHWCPass.cpp new file mode 100644 index 000000000..c9022f122 --- /dev/null +++ b/compiler/luci/pass/src/ConvertNCHWToNHWCPass.cpp @@ -0,0 +1,698 @@ +/* + * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Pass/ConvertNCHWToNHWCPass.h" +#include "CircleOptimizerUtils.h" + +#include <luci/IR/CircleNodes.h> +#include <luci/IR/CircleNodeVisitor.h> +#include <luci/Profile/CircleNodeOrigin.h> +#include <luci/Log.h> + +namespace +{ + +enum class DataFormat +{ + NCHW, + NHWC +}; + +/** + * @brief Set annotation for DataFormat (NCHW, NHWC) + * + * @note DataFormatAnnotation will live longer than this Pass (until the + * annotated loco::Node is erased). So, do not use large data in the + * annotation to avoid excessive memory usage. + */ +class DataFormatAnnotation final : public loco::NodeAnnotation +{ +public: + DataFormatAnnotation(const DataFormat &format) : _format{format} + { + // DO NOTHING + } + +public: + const DataFormat &format(void) const { return _format; } + +private: + DataFormat _format; +}; + +void set_data_format(loco::Node *node, const DataFormat &format) +{ + node->annot(std::make_unique<DataFormatAnnotation>(format)); +} + +DataFormat get_data_format(loco::Node *node) +{ + assert(node->annot<DataFormatAnnotation>() != nullptr); + return node->annot<DataFormatAnnotation>()->format(); +} + +bool has_data_format(loco::Node *node) { return node->annot<DataFormatAnnotation>() != nullptr; } + +luci::CircleTranspose *create_4d_transpose(luci::CircleNode *node, + const std::vector<int32_t> indices) +{ + assert(indices.size() == 4); + + auto name = node->name(); + assert(name.length() > 0); + + auto perm = node->graph()->nodes()->create<luci::CircleConst>(); + perm->dtype(loco::DataType::S32); + perm->size<loco::DataType::S32>(4); + perm->rank(1); + perm->dim(0) = 4; + for (uint32_t i = 0; i < 4; i++) + perm->at<loco::DataType::S32>(i) = indices[i]; + perm->shape_status(luci::ShapeStatus::VALID); + + auto make_string = [](const std::vector<int32_t> &nums) { + std::string str; + for (auto num : nums) + { + if (str.length() > 0) + str += "."; + str += std::to_string(num); + } + return str; + }; + + auto str_indices = make_string(indices); + + perm->name(name + "/Transpose_" + str_indices + "/perm"); + + auto trans = node->graph()->nodes()->create<luci::CircleTranspose>(); + trans->perm(perm); + trans->name(name + "/Transpose_" + str_indices); + luci::add_origin(trans, luci::get_origin(node)); + + return trans; +} + +int32_t nchw_axis_to_nhwc(int32_t axis) +{ + uint32_t pos_axis = axis >= 0 ? static_cast<uint32_t>(axis) : static_cast<uint32_t>(axis + 4); + static const uint32_t to_nhwc[4] = {0, 3, 1, 2}; + if (pos_axis > 3) + throw std::runtime_error("Concat axis must be in range [-4, 4)"); + return to_nhwc[pos_axis]; +} + +luci::CircleTranspose *create_post_transpose(luci::CircleNode *node) +{ + return create_4d_transpose(node, {0, 3, 1, 2}); +} + +luci::CircleTranspose *create_pre_transpose(luci::CircleNode *node) +{ + return create_4d_transpose(node, {0, 2, 3, 1}); +} + +uint32_t cal_offset(const loco::TensorShape &dimension, const uint32_t *indices) +{ + return indices[0] * dimension.dim(1).value() * dimension.dim(2).value() * + dimension.dim(3).value() + + indices[1] * dimension.dim(2).value() * dimension.dim(3).value() + + indices[2] * dimension.dim(3).value() + indices[3]; +} + +luci::CircleConst *create_NHWC_paddings(luci::CircleConst *paddings) +{ + // paddings shape is (4,2) (it was checked by is_NCHW) + assert(paddings != nullptr); + assert(paddings->rank() == 2); + assert(paddings->dim(0).value() == 4); + assert(paddings->dim(1).value() == 2); + + // paddings for idx 0~3 are 0 (checked by is_NCHW) + assert(paddings->at<loco::DataType::S32>(0) == 0); + assert(paddings->at<loco::DataType::S32>(1) == 0); + assert(paddings->at<loco::DataType::S32>(2) == 0); + assert(paddings->at<loco::DataType::S32>(3) == 0); + + auto name = paddings->name(); + assert(name.length() > 0); + + auto nhwc_paddings = paddings->graph()->nodes()->create<luci::CircleConst>(); + nhwc_paddings->dtype(loco::DataType::S32); + nhwc_paddings->shape({4, 2}); + nhwc_paddings->shape_status(luci::ShapeStatus::VALID); + nhwc_paddings->size<loco::DataType::S32>(4 * 2); + nhwc_paddings->name(name + "_NHWC"); + + for (uint32_t dim = 0; dim < 4; dim++) + { + for (uint32_t i = 0; i < 2; i++) + { + int32_t data = 0; + + if (dim == 1) + { + // get third dimension (H in NCHW) + data = paddings->at<loco::DataType::S32>(2 * 2 + i); + } + else if (dim == 2) + { + // get fourth dimension (W in NCHW) + data = paddings->at<loco::DataType::S32>(3 * 2 + i); + } + + nhwc_paddings->at<loco::DataType::S32>(dim * 2 + i) = data; + } + } + return nhwc_paddings; +} + +luci::CircleConst *create_NHWC_from_NCHW(luci::CircleConst *constant) +{ + LOGGER(l); + assert(constant->rank() == 4); + + // TODO: Support non-float types + if (constant->dtype() != loco::DataType::FLOAT32) + { + INFO(l) << "Non-float type constant: " << constant->name() << std::endl; + return nullptr; + } + + loco::TensorShape nchw_dimension{constant->dim(0), constant->dim(1), constant->dim(2), + constant->dim(3)}; + loco::TensorShape nhwc_dimension{constant->dim(0), constant->dim(2), constant->dim(3), + constant->dim(1)}; + + auto name = constant->name(); + assert(name.length() > 0); + + auto nhwc_const = constant->graph()->nodes()->create<luci::CircleConst>(); + nhwc_const->dtype(constant->dtype()); + nhwc_const->rank(4); + nhwc_const->dim(0).set(constant->dim(0).value()); + nhwc_const->dim(1).set(constant->dim(2).value()); + nhwc_const->dim(2).set(constant->dim(3).value()); + nhwc_const->dim(3).set(constant->dim(1).value()); + nhwc_const->shape_status(luci::ShapeStatus::VALID); + nhwc_const->size<loco::DataType::FLOAT32>(constant->size<loco::DataType::FLOAT32>()); + nhwc_const->name(name + "_NHWC"); + + for (uint32_t n = 0; n < nchw_dimension.dim(0).value(); n++) + { + for (uint32_t c = 0; c < nchw_dimension.dim(1).value(); c++) + { + for (uint32_t h = 0; h < nchw_dimension.dim(2).value(); h++) + { + for (uint32_t w = 0; w < nchw_dimension.dim(3).value(); w++) + { + uint32_t nchw_indices[4] = {n, c, h, w}; + uint32_t nhwc_indices[4] = {n, h, w, c}; + auto data = + constant->at<loco::DataType::FLOAT32>(cal_offset(nchw_dimension, nchw_indices)); + nhwc_const->at<loco::DataType::FLOAT32>(cal_offset(nhwc_dimension, nhwc_indices)) = data; + } + } + } + } + return nhwc_const; +} + +// NOTE Following conditions can be extended later +// +// Find PAD with an NCHW pattern described below +// - Paddings shape : [4, 2] +// - Paddings value : [[0, 0], [0, 0], [h_t, h_b], [w_t, w_b]]] +bool is_NCHW(const luci::CirclePad *node) +{ + const auto paddings = dynamic_cast<luci::CircleConst *>(node->paddings()); + // Non-const paddings is not supported + if (paddings == nullptr) + return false; + + if (paddings->rank() != 2) + return false; + + if (paddings->dim(0).value() != 4 || paddings->dim(1).value() != 2) + return false; + + // Only check the first two dimensions + for (uint32_t dim = 0; dim < 2; dim++) + { + for (uint32_t i = 0; i < 2; i++) + { + auto data = paddings->at<loco::DataType::S32>(dim * 2 + i); + if (data != 0) + return false; + } + } + + return true; +} + +// NOTE Following conditions can be extended later +// +// Find MUL with an NCHW pattern described below +// - Input (non-constant) shape : [N, C, H, W] +// - Input (constant) shape : [1, C, 1, 1] +// - Output shape : [N, C, H, W] +bool is_NCHW_with_const(const luci::CircleMul *node, luci::CircleNode *&pred_node, + luci::CircleConst *&multiplier) +{ + auto x = dynamic_cast<luci::CircleConst *>(node->x()); + auto y = dynamic_cast<luci::CircleConst *>(node->y()); + + if (x != nullptr && y == nullptr) + { + pred_node = loco::must_cast<luci::CircleNode *>(node->y()); + multiplier = x; + } + else if (x == nullptr && y != nullptr) + { + pred_node = loco::must_cast<luci::CircleNode *>(node->x()); + multiplier = y; + } + else + { + // Ignore if MUL does not have a multiplier input. + return false; + } + + if (pred_node->rank() != 4) + return false; + + const auto const_rank = multiplier->rank(); + if (const_rank != 4) + return false; + + for (uint32_t i = 0; i < const_rank; i++) + { + if (i != 1 && multiplier->dim(i).value() != 1) + return false; + } + + const auto const_cdim = multiplier->dim(1); + const auto input_cdim = pred_node->dim(1); + const auto output_cdim = node->dim(1); + + if (const_cdim == input_cdim && input_cdim == output_cdim) + return true; + else + return false; +} + +// We assume ADD with const input is NCHW if, +// Input shape: (N, C, H, W) +// Output shape: (N, C, H, W) +// 1. Const shape is (1, C, 1, 1) +// 2. Input, Output, Const have the same C. +bool is_NCHW_with_const(const luci::CircleAdd *node, luci::CircleNode *&pred_node, + luci::CircleConst *&beta) +{ + auto x = dynamic_cast<luci::CircleConst *>(node->x()); + auto y = dynamic_cast<luci::CircleConst *>(node->y()); + + if (x != nullptr && y == nullptr) + { + pred_node = loco::must_cast<luci::CircleNode *>(node->y()); + beta = x; + } + else if (x == nullptr && y != nullptr) + { + pred_node = loco::must_cast<luci::CircleNode *>(node->x()); + beta = y; + } + else + { + // Ignore if ADD does not have a constant input. + return false; + } + + if (pred_node->rank() != 4) + return false; + + const auto const_rank = beta->rank(); + if (const_rank != 4) + return false; + + // Check the shape is (1, C, 1, 1) + for (uint32_t i = 0; i < const_rank; i++) + { + if (i == 1) + continue; + + if (beta->dim(i).value() != 1) + return false; + } + + const auto const_cdim = beta->dim(1); + const auto input_cdim = pred_node->dim(1); + const auto output_cdim = node->dim(1); + + // Check Input, Output, Const have the same channel size + if (const_cdim == input_cdim && input_cdim == output_cdim) + return true; + else + return false; +} + +template <class T> bool convert_unary_features(T *node) +{ + const auto pred_node = loco::must_cast<luci::CircleNode *>(node->features()); + auto pre_trans = create_pre_transpose(node); + pre_trans->a(pred_node); + node->features(pre_trans); + + // Do shape inference for this node again. + node->shape_status(luci::ShapeStatus::UNDEFINED); + + auto post_trans = create_post_transpose(node); + loco::replace(node).with(post_trans); + + post_trans->a(node); + + return true; +} + +class ConvertNCHWToNHWC final : public luci::CircleNodeMutableVisitor<bool> +{ + // Default + bool visit(luci::CircleNode *node) + { + throw std::runtime_error(node->name() + " is an unsupported operator."); + } + + bool visit(luci::CircleInput *node) + { + const auto n = node->dim(0); + const auto c = node->dim(1); + const auto h = node->dim(2); + const auto w = node->dim(3); + + node->dim(1) = h; + node->dim(2) = w; + node->dim(3) = c; + + // Do shape inference for this node again. + node->shape_status(luci::ShapeStatus::UNDEFINED); + + // Insert post-tranpose + auto post_trans = create_post_transpose(node); + loco::replace(node).with(post_trans); + + post_trans->a(node); + + // Update graph input + auto graph_inputs = node->graph()->inputs(); + auto graph_input = graph_inputs->at(node->index()); + graph_input->shape({n, h, w, c}); + + return true; + } + + bool visit(luci::CircleOutput *node) + { + // Insert pre-transpose + auto pre_trans = create_pre_transpose(node); + pre_trans->a(node->from()); + + node->from(pre_trans); + + // Do shape inference for this node again. + node->shape_status(luci::ShapeStatus::UNDEFINED); + + // Update graph output + const auto n = node->dim(0).value(); + const auto c = node->dim(1).value(); + const auto h = node->dim(2).value(); + const auto w = node->dim(3).value(); + + auto graph_outputs = node->graph()->outputs(); + auto graph_output = graph_outputs->at(node->index()); + graph_output->shape({n, h, w, c}); + + return true; + } + + bool visit(luci::CircleAdd *node) + { + luci::CircleNode *pred_node = nullptr; + luci::CircleConst *beta = nullptr; + + if (is_NCHW_with_const(node, pred_node, beta)) + { + auto pre_trans = create_pre_transpose(node); + pre_trans->a(pred_node); + + auto nhwc_const = create_NHWC_from_NCHW(beta); + if (nhwc_const == nullptr) + return false; + + node->x(pre_trans); + node->y(nhwc_const); + } + else if (beta == nullptr) + { + // Both inputs are not constant. + // In this case, we cannot distinguish NCHW from NHWC, + // so just insert Transpose Ops. + auto pre_trans_x = create_pre_transpose(node); + pre_trans_x->a(node->x()); + node->x(pre_trans_x); + + auto pre_trans_y = create_pre_transpose(node); + pre_trans_y->a(node->y()); + node->y(pre_trans_y); + } + else + { + return false; + } + + // Do shape inference for this node again. + node->shape_status(luci::ShapeStatus::UNDEFINED); + + auto post_trans = create_post_transpose(node); + loco::replace(node).with(post_trans); + + post_trans->a(node); + return true; + } + + bool visit(luci::CircleConcatenation *node) + { + const auto num_values = node->numValues(); + for (uint32_t i = 0; i < num_values; i++) + { + auto pred_node = loco::must_cast<luci::CircleNode *>(node->values(i)); + auto pre_trans = create_pre_transpose(node); + pre_trans->a(pred_node); + node->values(i, pre_trans); + } + + // Do shape inference for this node again. + node->shape_status(luci::ShapeStatus::UNDEFINED); + + node->axis(nchw_axis_to_nhwc(node->axis())); + + auto post_trans = create_post_transpose(node); + loco::replace(node).with(post_trans); + + post_trans->a(node); + + return true; + } + + bool visit(luci::CircleLeakyRelu *node) + { + return convert_unary_features<luci::CircleLeakyRelu>(node); + } + + bool visit(luci::CircleMul *node) + { + LOGGER(l); + + luci::CircleNode *pred_node = nullptr; + luci::CircleConst *multiplier = nullptr; + + if (is_NCHW_with_const(node, pred_node, multiplier)) + { + auto pre_trans = create_pre_transpose(node); + pre_trans->a(pred_node); + node->x(pre_trans); + + auto nhwc_const = create_NHWC_from_NCHW(multiplier); + node->y(nhwc_const); + } + else if (multiplier == nullptr) + { + // TODO : Implement this case. + INFO(l) << "Not yet implemented. Both inputs of MUL are non-const." << std::endl; + return false; + } + else + { + return false; + } + + // Do shape inference for this node again. + node->shape_status(luci::ShapeStatus::UNDEFINED); + + auto post_trans = create_post_transpose(node); + loco::replace(node).with(post_trans); + + post_trans->a(node); + return true; + } + + bool visit(luci::CircleNeg *node) + { + const auto pred_node = loco::must_cast<luci::CircleNode *>(node->x()); + auto pre_trans = create_pre_transpose(node); + pre_trans->a(pred_node); + node->x(pre_trans); + + // Do shape inference for this node again. + node->shape_status(luci::ShapeStatus::UNDEFINED); + + auto post_trans = create_post_transpose(node); + loco::replace(node).with(post_trans); + + post_trans->a(node); + + return true; + } + + bool visit(luci::CirclePad *node) + { + if (!is_NCHW(node)) + return false; + + const auto pred_node = loco::must_cast<luci::CircleNode *>(node->input()); + auto pre_trans = create_pre_transpose(node); + pre_trans->a(pred_node); + node->input(pre_trans); + + auto nchw_paddings = loco::must_cast<luci::CircleConst *>(node->paddings()); + const auto nhwc_paddings = create_NHWC_paddings(nchw_paddings); + node->paddings(nhwc_paddings); + + // Do shape inference for this node again. + node->shape_status(luci::ShapeStatus::UNDEFINED); + + auto post_trans = create_post_transpose(node); + loco::replace(node).with(post_trans); + + post_trans->a(node); + + return true; + } + + bool visit(luci::CircleRelu *node) { return convert_unary_features<luci::CircleRelu>(node); } + + bool visit(luci::CircleRelu6 *node) { return convert_unary_features<luci::CircleRelu6>(node); } +}; + +} // namespace + +namespace luci +{ + +bool ConvertNCHWToNHWCPass::run(loco::Graph *g) +{ + LOGGER(l); + INFO(l) << "ConvertNCHWToNHWCPass Start" << std::endl; + + // Annotate NCHW operators + for (auto node : loco::active_nodes(loco::output_nodes(g))) + { + auto circle_node = loco::must_cast<luci::CircleNode *>(node); + switch (circle_node->opcode()) + { + // List of supported Ops + case luci::CircleOpcode::CIRCLEINPUT: + if (!_preserve_input && !has_data_format(node)) + { + set_data_format(node, DataFormat::NCHW); + } + break; + case luci::CircleOpcode::CIRCLEOUTPUT: + if (!_preserve_output && !has_data_format(node)) + { + set_data_format(node, DataFormat::NCHW); + } + break; + case luci::CircleOpcode::ADD: + case luci::CircleOpcode::CONCATENATION: + case luci::CircleOpcode::LEAKY_RELU: + case luci::CircleOpcode::MUL: + case luci::CircleOpcode::NEG: + case luci::CircleOpcode::PAD: + case luci::CircleOpcode::RELU: + case luci::CircleOpcode::RELU6: + if (!has_data_format(node)) + { + set_data_format(node, DataFormat::NCHW); + } + break; + default: + break; + } + } + + bool changed = false; + for (auto node : loco::active_nodes(loco::output_nodes(g))) + { + if (!has_data_format(node)) + { + // Unsupported Op + continue; + } + else if (get_data_format(node) == DataFormat::NHWC) + { + // Already converted to NHWC + continue; + } + else if (has_dynamic_shape(node)) + { + // This pass only works for static-shaped node + INFO(l) << "Skip the node with a dynamic shape." << std::endl; + continue; + } + else + { + ConvertNCHWToNHWC converter; + auto circle_node = loco::must_cast<luci::CircleNode *>(node); + if (circle_node->rank() != 4) + continue; + + if (circle_node->accept(&converter)) + { + set_data_format(node, DataFormat::NHWC); + changed = true; + } + else + { + continue; + } + } + } + + INFO(l) << "ConvertNCHWToNHWCPass End" << std::endl; + return changed; +} + +} // namespace luci diff --git a/compiler/luci/pass/src/ConvertNCHWToNHWCPass.test.cpp b/compiler/luci/pass/src/ConvertNCHWToNHWCPass.test.cpp new file mode 100644 index 000000000..831d5f89a --- /dev/null +++ b/compiler/luci/pass/src/ConvertNCHWToNHWCPass.test.cpp @@ -0,0 +1,636 @@ +/* + * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <logo/Phase.h> + +#include "luci/Pass/ConvertNCHWToNHWCPass.h" +#include "luci/Pass/CircleShapeInferencePass.h" + +#include <luci/IR/CircleNodes.h> + +#include <gtest/gtest.h> + +namespace +{ + +/** + * Graph with a single Op (example: Add). + * + * BEFORE + * - All Ops including Input/Output are NCHW. + * + * [Input] [beta] + * | / + * [Add] + * | + * [Output] + * + * AFTER + * - All Ops including Input/Output are NHWC. + * + * [Input] + * | + * [Transpose] + * | + * [Transpose] [beta] + * | / + * [Add] + * | + * [Transpose] + * | + * [Transpose] + * | + * [Output] + */ +class SimpleGraph +{ +public: + SimpleGraph() = default; + +public: + void init() + { + input = g.nodes()->create<luci::CircleInput>(); + output = g.nodes()->create<luci::CircleOutput>(); + input->name("input"); + output->name("output"); + + auto graph_input = g.inputs()->create(); + input->index(graph_input->index()); + auto graph_output = g.outputs()->create(); + output->index(graph_output->index()); + + graph_input->dtype(loco::DataType::FLOAT32); + input->dtype(loco::DataType::FLOAT32); + output->dtype(loco::DataType::FLOAT32); + graph_output->dtype(loco::DataType::FLOAT32); + + uint32_t channel_size = 16; + graph_input->shape({1, channel_size, 4, 4}); + input->shape({1, channel_size, 4, 4}); + output->shape({1, channel_size, 4, 4}); + graph_output->shape({1, channel_size, 4, 4}); + + auto graph_body = insertGraphBody(input); + output->from(graph_body); + } + + virtual ~SimpleGraph() = default; + +protected: + virtual loco::Node *insertGraphBody(loco::Node *input) = 0; + +public: + loco::Graph g; + luci::CircleInput *input = nullptr; + luci::CircleOutput *output = nullptr; +}; + +class AddGraph final : public SimpleGraph +{ +protected: + loco::Node *insertGraphBody(loco::Node *input) override + { + add = g.nodes()->create<luci::CircleAdd>(); + beta = g.nodes()->create<luci::CircleConst>(); + + add->dtype(loco::DataType::FLOAT32); + beta->dtype(loco::DataType::FLOAT32); + + uint32_t channel_size = 16; + add->shape({1, channel_size, 4, 4}); + beta->shape({1, channel_size, 1, 1}); + + beta->size<loco::DataType::FLOAT32>(channel_size); + for (uint32_t i = 0; i < channel_size; i++) + { + beta->at<loco::DataType::FLOAT32>(i) = i; + } + + add->x(input); + add->y(beta); + + add->name("add"); + beta->name("beta"); + + return add; + } + +public: + luci::CircleAdd *add = nullptr; + luci::CircleConst *beta = nullptr; +}; + +class ConcatenationGraph final : public SimpleGraph +{ +protected: + loco::Node *insertGraphBody(loco::Node *input) override + { + concat = g.nodes()->create<luci::CircleConcatenation>(2); + concat->values(0, input); + concat->axis(1); + + input2 = g.nodes()->create<luci::CircleConst>(); + input2->dtype(loco::DataType::FLOAT32); + input2->shape({1, 16, 4, 4}); + input2->size<loco::DataType::FLOAT32>(16 * 4 * 4); + for (uint32_t i = 0; i < 16 * 4 * 4; i++) + { + input2->at<loco::DataType::FLOAT32>(i) = i; + } + concat->values(1, input2); + + concat->name("concat"); + input2->name("input2"); + + return concat; + } + +public: + luci::CircleConcatenation *concat = nullptr; + luci::CircleConst *input2 = nullptr; +}; + +class LeakyReluGraph final : public SimpleGraph +{ +protected: + loco::Node *insertGraphBody(loco::Node *input) override + { + leakyrelu = g.nodes()->create<luci::CircleLeakyRelu>(); + leakyrelu->features(input); + leakyrelu->name("leakyrelu"); + + return leakyrelu; + } + +public: + luci::CircleLeakyRelu *leakyrelu = nullptr; +}; + +class MulGraph final : public SimpleGraph +{ +protected: + loco::Node *insertGraphBody(loco::Node *input) override + { + mul = g.nodes()->create<luci::CircleMul>(); + multiplier = g.nodes()->create<luci::CircleConst>(); + + mul->dtype(loco::DataType::FLOAT32); + multiplier->dtype(loco::DataType::FLOAT32); + + uint32_t channel_size = 16; + mul->shape({1, channel_size, 4, 4}); + multiplier->shape({1, channel_size, 1, 1}); + + multiplier->size<loco::DataType::FLOAT32>(channel_size); + for (uint32_t i = 0; i < channel_size; i++) + { + multiplier->at<loco::DataType::FLOAT32>(i) = i; + } + + mul->x(input); + mul->y(multiplier); + + mul->name("mul"); + multiplier->name("multiplier"); + + return mul; + } + +public: + luci::CircleMul *mul = nullptr; + luci::CircleConst *multiplier = nullptr; +}; + +class NegGraph final : public SimpleGraph +{ +protected: + loco::Node *insertGraphBody(loco::Node *input) override + { + neg = g.nodes()->create<luci::CircleNeg>(); + neg->x(input); + neg->name("neg"); + + return neg; + } + +public: + luci::CircleNeg *neg = nullptr; +}; + +class PadGraph final : public SimpleGraph +{ +protected: + loco::Node *insertGraphBody(loco::Node *input) override + { + pad = g.nodes()->create<luci::CirclePad>(); + paddings = g.nodes()->create<luci::CircleConst>(); + + pad->dtype(loco::DataType::FLOAT32); + paddings->dtype(loco::DataType::S32); + + uint32_t channel_size = 16; + pad->shape({1, channel_size, 4, 4}); + paddings->shape({4, 2}); + + // paddings data (NCHW) + // [[0,0], [0,0], [1,1], [2,2]] + paddings->size<loco::DataType::S32>(8); + for (uint32_t dim = 0; dim < 4; dim++) + { + for (uint32_t i = 0; i < 2; i++) + { + int32_t data = 0; + + if (dim == 2) + data = 1; + else if (dim == 3) + data = 2; + + paddings->at<loco::DataType::S32>(dim * 2 + i) = data; + } + } + + pad->input(input); + pad->paddings(paddings); + + pad->name("pad"); + paddings->name("paddings"); + + return pad; + } + +public: + luci::CirclePad *pad = nullptr; + luci::CircleConst *paddings = nullptr; +}; + +class ReluGraph final : public SimpleGraph +{ +protected: + loco::Node *insertGraphBody(loco::Node *input) override + { + relu = g.nodes()->create<luci::CircleRelu>(); + relu->features(input); + relu->name("Relu"); + + return relu; + } + +public: + luci::CircleRelu *relu = nullptr; +}; + +class Relu6Graph final : public SimpleGraph +{ +protected: + loco::Node *insertGraphBody(loco::Node *input) override + { + relu6 = g.nodes()->create<luci::CircleRelu6>(); + relu6->features(input); + relu6->name("relu6"); + + return relu6; + } + +public: + luci::CircleRelu6 *relu6 = nullptr; +}; + +void check_pre_trans(loco::Node *node) +{ + auto pre_trans = dynamic_cast<luci::CircleTranspose *>(node); + EXPECT_NE(nullptr, pre_trans); + auto pre_trans_perm = dynamic_cast<luci::CircleConst *>(pre_trans->perm()); + EXPECT_NE(nullptr, pre_trans_perm); + EXPECT_EQ(1, pre_trans_perm->rank()); + EXPECT_EQ(4, pre_trans_perm->dim(0).value()); + EXPECT_EQ(loco::DataType::S32, pre_trans_perm->dtype()); + EXPECT_EQ(0, pre_trans_perm->at<loco::DataType::S32>(0)); + EXPECT_EQ(2, pre_trans_perm->at<loco::DataType::S32>(1)); + EXPECT_EQ(3, pre_trans_perm->at<loco::DataType::S32>(2)); + EXPECT_EQ(1, pre_trans_perm->at<loco::DataType::S32>(3)); +} + +void check_post_trans(loco::Node *node) +{ + auto post_trans = dynamic_cast<luci::CircleTranspose *>(node); + EXPECT_NE(nullptr, post_trans); + auto post_trans_perm = dynamic_cast<luci::CircleConst *>(post_trans->perm()); + EXPECT_NE(nullptr, post_trans_perm); + EXPECT_EQ(1, post_trans_perm->rank()); + EXPECT_EQ(4, post_trans_perm->dim(0).value()); + EXPECT_EQ(loco::DataType::S32, post_trans_perm->dtype()); + EXPECT_EQ(0, post_trans_perm->at<loco::DataType::S32>(0)); + EXPECT_EQ(3, post_trans_perm->at<loco::DataType::S32>(1)); + EXPECT_EQ(1, post_trans_perm->at<loco::DataType::S32>(2)); + EXPECT_EQ(2, post_trans_perm->at<loco::DataType::S32>(3)); +} + +void run_phase(loco::Graph *g, bool preserve_input, bool preserve_output) +{ + logo::Phase phase; + + // Default passes. + phase.emplace_back(std::make_unique<luci::CircleShapeInferencePass>()); + + // Pass to test + phase.emplace_back( + std::make_unique<luci::ConvertNCHWToNHWCPass>(preserve_input, preserve_output)); + + logo::PhaseRunner<logo::PhaseStrategy::Restart> phase_runner{g}; + phase_runner.run(phase); +} + +} // namespace + +TEST(ConvertNCHWToNHWCPassTest, name) +{ + luci::ConvertNCHWToNHWCPass pass(false, false); + auto const name = pass.name(); + ASSERT_NE(nullptr, name); +} + +TEST(ConvertNCHWToNHWC, Add) +{ + AddGraph g; + g.init(); + + run_phase(&g.g, false, false); + + auto input_succs = loco::succs(g.input); + EXPECT_EQ(1, input_succs.size()); + check_post_trans(*input_succs.begin()); + + check_pre_trans(g.add->x()); + + auto add_succs = loco::succs(g.add); + EXPECT_EQ(1, add_succs.size()); + check_post_trans(*add_succs.begin()); + + uint32_t channel_size = 16; + auto new_beta = dynamic_cast<luci::CircleConst *>(g.add->y()); + EXPECT_NE(nullptr, new_beta); + EXPECT_EQ(4, new_beta->rank()); + EXPECT_EQ(1, new_beta->dim(0).value()); + EXPECT_EQ(1, new_beta->dim(1).value()); + EXPECT_EQ(1, new_beta->dim(2).value()); + EXPECT_EQ(channel_size, new_beta->dim(3).value()); + + check_pre_trans(g.output->from()); +} + +TEST(ConvertNCHWToNHWC, Concatenation) +{ + ConcatenationGraph g; + g.init(); + + run_phase(&g.g, true, true); + + check_pre_trans(g.concat->values(0)); + check_pre_trans(g.concat->values(1)); + + auto concat_succs = loco::succs(g.concat); + EXPECT_EQ(1, concat_succs.size()); + check_post_trans(*concat_succs.begin()); + + // Check concat shape, axis + EXPECT_EQ(1, g.concat->dim(0).value()); + EXPECT_EQ(4, g.concat->dim(1).value()); + EXPECT_EQ(4, g.concat->dim(2).value()); + EXPECT_EQ(32, g.concat->dim(3).value()); + EXPECT_EQ(3, g.concat->axis()); +} + +TEST(ConvertNCHWToNHWC, LeakyRelu) +{ + LeakyReluGraph g; + g.init(); + + run_phase(&g.g, true, true); + + check_pre_trans(g.leakyrelu->features()); + + auto leakyrelu_succs = loco::succs(g.leakyrelu); + EXPECT_EQ(1, leakyrelu_succs.size()); + check_post_trans(*leakyrelu_succs.begin()); + + // Check leakyrelu shape + EXPECT_EQ(1, g.leakyrelu->dim(0).value()); + EXPECT_EQ(4, g.leakyrelu->dim(1).value()); + EXPECT_EQ(4, g.leakyrelu->dim(2).value()); + EXPECT_EQ(16, g.leakyrelu->dim(3).value()); +} + +TEST(ConvertNCHWToNHWC, Mul) +{ + MulGraph g; + g.init(); + + run_phase(&g.g, false, false); + + auto input_succs = loco::succs(g.input); + EXPECT_EQ(1, input_succs.size()); + check_post_trans(*input_succs.begin()); + + check_pre_trans(g.mul->x()); + + auto mul_succs = loco::succs(g.mul); + EXPECT_EQ(1, mul_succs.size()); + check_post_trans(*mul_succs.begin()); + + uint32_t channel_size = 16; + auto new_multiplier = dynamic_cast<luci::CircleConst *>(g.mul->y()); + EXPECT_NE(nullptr, new_multiplier); + EXPECT_EQ(4, new_multiplier->rank()); + EXPECT_EQ(1, new_multiplier->dim(0).value()); + EXPECT_EQ(1, new_multiplier->dim(1).value()); + EXPECT_EQ(1, new_multiplier->dim(2).value()); + EXPECT_EQ(channel_size, new_multiplier->dim(3).value()); + + check_pre_trans(g.output->from()); +} + +TEST(ConvertNCHWToNHWC, Neg) +{ + NegGraph g; + g.init(); + + run_phase(&g.g, true, true); + + check_pre_trans(g.neg->x()); + + auto neg_succs = loco::succs(g.neg); + EXPECT_EQ(1, neg_succs.size()); + check_post_trans(*neg_succs.begin()); + + // Check leakyrelu shape + EXPECT_EQ(1, g.neg->dim(0).value()); + EXPECT_EQ(4, g.neg->dim(1).value()); + EXPECT_EQ(4, g.neg->dim(2).value()); + EXPECT_EQ(16, g.neg->dim(3).value()); +} + +TEST(ConvertNCHWToNHWC, Pad) +{ + PadGraph g; + g.init(); + + run_phase(&g.g, false, false); + + auto input_succs = loco::succs(g.input); + EXPECT_EQ(1, input_succs.size()); + check_post_trans(*input_succs.begin()); + + check_pre_trans(g.pad->input()); + + auto pad_succs = loco::succs(g.pad); + EXPECT_EQ(1, pad_succs.size()); + check_post_trans(*pad_succs.begin()); + + auto new_paddings = dynamic_cast<luci::CircleConst *>(g.pad->paddings()); + EXPECT_NE(nullptr, new_paddings); + EXPECT_EQ(2, new_paddings->rank()); + EXPECT_EQ(4, new_paddings->dim(0).value()); + EXPECT_EQ(2, new_paddings->dim(1).value()); + EXPECT_EQ(0, new_paddings->at<loco::DataType::S32>(0)); + EXPECT_EQ(0, new_paddings->at<loco::DataType::S32>(1)); + EXPECT_EQ(1, new_paddings->at<loco::DataType::S32>(2)); + EXPECT_EQ(1, new_paddings->at<loco::DataType::S32>(3)); + EXPECT_EQ(2, new_paddings->at<loco::DataType::S32>(4)); + EXPECT_EQ(2, new_paddings->at<loco::DataType::S32>(5)); + EXPECT_EQ(0, new_paddings->at<loco::DataType::S32>(6)); + EXPECT_EQ(0, new_paddings->at<loco::DataType::S32>(7)); + + check_pre_trans(g.output->from()); +} + +TEST(ConvertNCHWToNHWC, Unknown_Shape_NEG) +{ + AddGraph g; + g.init(); + + // Unknown shape + g.input->dim(0).unset(); + g.add->dim(0).unset(); + g.output->dim(0).unset(); + + luci::ConvertNCHWToNHWCPass pass(false, false); + EXPECT_EQ(false, pass.run(&g.g)); +} + +TEST(ConvertNCHWToNHWC, Preserve_Input_Output) +{ + // Preserve input + { + AddGraph g; + g.init(); + + run_phase(&g.g, true, false); + + // Check input shape + EXPECT_EQ(1, g.input->dim(0).value()); + EXPECT_EQ(16, g.input->dim(1).value()); + EXPECT_EQ(4, g.input->dim(2).value()); + EXPECT_EQ(4, g.input->dim(3).value()); + + // Check output shape + EXPECT_EQ(1, g.output->dim(0).value()); + EXPECT_EQ(4, g.output->dim(1).value()); + EXPECT_EQ(4, g.output->dim(2).value()); + EXPECT_EQ(16, g.output->dim(3).value()); + } + + // Preserve output + { + AddGraph g; + g.init(); + + run_phase(&g.g, false, true); + + // Check input shape + EXPECT_EQ(1, g.input->dim(0).value()); + EXPECT_EQ(4, g.input->dim(1).value()); + EXPECT_EQ(4, g.input->dim(2).value()); + EXPECT_EQ(16, g.input->dim(3).value()); + + // Check output shape + EXPECT_EQ(1, g.output->dim(0).value()); + EXPECT_EQ(16, g.output->dim(1).value()); + EXPECT_EQ(4, g.output->dim(2).value()); + EXPECT_EQ(4, g.output->dim(3).value()); + } + + // Preserve both input and output + { + AddGraph g; + g.init(); + + run_phase(&g.g, true, true); + + // Check input shape + EXPECT_EQ(1, g.input->dim(0).value()); + EXPECT_EQ(16, g.input->dim(1).value()); + EXPECT_EQ(4, g.input->dim(2).value()); + EXPECT_EQ(4, g.input->dim(3).value()); + + // Check output shape + EXPECT_EQ(1, g.output->dim(0).value()); + EXPECT_EQ(16, g.output->dim(1).value()); + EXPECT_EQ(4, g.output->dim(2).value()); + EXPECT_EQ(4, g.output->dim(3).value()); + } +} + +TEST(ConvertNCHWToNHWC, Relu) +{ + ReluGraph g; + g.init(); + + run_phase(&g.g, true, true); + + check_pre_trans(g.relu->features()); + + auto relu_succs = loco::succs(g.relu); + EXPECT_EQ(1, relu_succs.size()); + check_post_trans(*relu_succs.begin()); + + // Check relu shape + EXPECT_EQ(1, g.relu->dim(0).value()); + EXPECT_EQ(4, g.relu->dim(1).value()); + EXPECT_EQ(4, g.relu->dim(2).value()); + EXPECT_EQ(16, g.relu->dim(3).value()); +} + +TEST(ConvertNCHWToNHWC, Relu6) +{ + Relu6Graph g; + g.init(); + + run_phase(&g.g, true, true); + + check_pre_trans(g.relu6->features()); + + auto relu6_succs = loco::succs(g.relu6); + EXPECT_EQ(1, relu6_succs.size()); + check_post_trans(*relu6_succs.begin()); + + // Check relu6 shape + EXPECT_EQ(1, g.relu6->dim(0).value()); + EXPECT_EQ(4, g.relu6->dim(1).value()); + EXPECT_EQ(4, g.relu6->dim(2).value()); + EXPECT_EQ(16, g.relu6->dim(3).value()); +} diff --git a/compiler/luci/pass/src/FoldAddV2Pass.cpp b/compiler/luci/pass/src/FoldAddV2Pass.cpp new file mode 100644 index 000000000..20c1022f8 --- /dev/null +++ b/compiler/luci/pass/src/FoldAddV2Pass.cpp @@ -0,0 +1,122 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Pass/FoldAddV2Pass.h" + +#include <luci/IR/CircleNodes.h> + +#include <iostream> + +namespace +{ + +bool same_shape(const luci::CircleConst *x, const luci::CircleConst *y) +{ + if (x->rank() != y->rank()) + return false; + + for (uint32_t i = 0; i < x->rank(); i++) + { + if (!(x->dim(i) == y->dim(i))) + return false; + } + + return true; +} + +/** + * Fold AddV2 to const if both inputs are const + **/ +template <loco::DataType T> bool fold_add_v2(luci::CircleCustom *add_v2) +{ + // This should hold for AddV2 + if (add_v2->numInputs() != 2) + return false; + + // Check first input is const + auto x = dynamic_cast<luci::CircleConst *>(add_v2->inputs(0)); + if (not x) + return false; + + // Check second input is const + auto y = dynamic_cast<luci::CircleConst *>(add_v2->inputs(1)); + if (not y) + return false; + + if (x->dtype() != y->dtype()) + return false; + + if (!same_shape(x, y)) + return false; + + auto name_x = x->name(); + auto name_y = y->name(); + assert(name_x.length() > 0); + assert(name_y.length() > 0); + auto constant = add_v2->graph()->nodes()->create<luci::CircleConst>(); + constant->dtype(x->dtype()); + constant->rank(x->rank()); + for (uint32_t i = 0; i < x->rank(); i++) + constant->dim(i).set(x->dim(i).value()); + + const auto size = x->size<T>(); + constant->size<T>(size); + for (uint32_t i = 0; i < size; i++) + constant->at<T>(i) = x->at<T>(i) + y->at<T>(i); + + constant->shape_status(luci::ShapeStatus::VALID); + constant->name(name_x + ";" + name_y); + + for (auto succ : loco::succs(add_v2)) + { + auto custom_out = loco::must_cast<luci::CircleCustomOut *>(succ); + loco::replace(custom_out).with(constant); + } + + return true; +} + +} // namespace + +namespace luci +{ + +/** + * Constant Folding for AddV2 Op + **/ +bool FoldAddV2Pass::run(loco::Graph *g) +{ + bool changed = false; + for (auto node : loco::active_nodes(loco::output_nodes(g))) + { + if (auto custom = dynamic_cast<luci::CircleCustom *>(node)) + { + if (custom->custom_code() == "AddV2") + { + // TODO: Support more data types + if (custom->dtype() == loco::DataType::S64) + { + if (fold_add_v2<loco::DataType::S64>(custom)) + changed = true; + } + } + } + } + + return changed; +} + +} // namespace luci diff --git a/compiler/luci/pass/src/FoldAddV2Pass.test.cpp b/compiler/luci/pass/src/FoldAddV2Pass.test.cpp new file mode 100644 index 000000000..438d7f077 --- /dev/null +++ b/compiler/luci/pass/src/FoldAddV2Pass.test.cpp @@ -0,0 +1,137 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Pass/FoldAddV2Pass.h" +#include "PassTestGraphs.h" + +#include <luci/IR/CircleNodes.h> + +#include <gtest/gtest.h> + +namespace +{ + +/** + * Graph has an AddV2 Op with constant inputs + * + * BEFORE + * + * [CircleConst] [CircleConst] + * | | + * [CircleCustom (AddV2)] + * | + * [CircleCustomOut] + * + * AFTER + * + * [CircleConst] + */ +template <loco::DataType T> class FoldAddV2Test : public luci::ConstantFoldingAddTestGraph +{ +public: + FoldAddV2Test(std::initializer_list<uint32_t> shape) : luci::ConstantFoldingAddTestGraph(shape, T) + { + _addV2 = _g.nodes()->create<luci::CircleCustom>(2, 1); + _x = _g.nodes()->create<luci::CircleConst>(); + _y = _g.nodes()->create<luci::CircleConst>(); + _addV2_out = _g.nodes()->create<luci::CircleCustomOut>(); + + _addV2->dtype(T); + _x->dtype(T); + _y->dtype(T); + _addV2_out->dtype(T); + + _addV2->shape(shape); + _x->shape(shape); + _y->shape(shape); + _addV2_out->shape(shape); + + uint32_t num_elems = 1; + for (auto dim = shape.begin(); dim != shape.end(); dim++) + num_elems *= *dim; + + _x->size<T>(num_elems); + _y->size<T>(num_elems); + + for (uint32_t i = 0; i < num_elems; i++) + { + _x->at<T>(i) = i + 1; + _y->at<T>(i) = i + 1; + } + + _addV2->custom_code("AddV2"); + _addV2->inputs(0, _x); + _addV2->inputs(1, _y); + _addV2_out->input(_addV2); + + _addV2->name("addV2"); + _x->name("x"); + _y->name("y"); + } + + loco::Node *createFoldedPattern() override { return _addV2_out; } + + virtual ~FoldAddV2Test() = default; + +protected: + luci::CircleCustom *_addV2 = nullptr; + luci::CircleCustomOut *_addV2_out = nullptr; + luci::CircleConst *_x = nullptr; + luci::CircleConst *_y = nullptr; +}; + +class FoldS64AddV2Test : public FoldAddV2Test<loco::DataType::S64>, public ::testing::Test +{ +public: + FoldS64AddV2Test() : FoldAddV2Test<loco::DataType::S64>({3}) {} + + virtual void SetUp() { init(); } +}; + +} // namespace + +TEST(FoldAddV2PassTest, name) +{ + luci::FoldAddV2Pass pass; + auto const name = pass.name(); + ASSERT_NE(nullptr, name); +} + +TEST_F(FoldS64AddV2Test, fold_addV2) +{ + luci::FoldAddV2Pass pass; + while (pass.run(graph())) + ; + + auto folded_const = getFoldedPattern(); + EXPECT_NE(nullptr, folded_const); + + // Check type, shape, values of folded const + EXPECT_EQ(loco::DataType::S64, folded_const->dtype()); + EXPECT_EQ(1, folded_const->rank()); + EXPECT_EQ(3, folded_const->dim(0).value()); + EXPECT_EQ(2, folded_const->at<loco::DataType::S64>(0)); + EXPECT_EQ(4, folded_const->at<loco::DataType::S64>(1)); + EXPECT_EQ(6, folded_const->at<loco::DataType::S64>(2)); +} + +TEST_F(FoldS64AddV2Test, input_type_mismatch_NEG) +{ + _x->dtype(loco::DataType::S32); + + luci::FoldAddV2Pass pass; + EXPECT_FALSE(pass.run(graph())); +} diff --git a/compiler/luci/pass/src/FoldCastPass.cpp b/compiler/luci/pass/src/FoldCastPass.cpp new file mode 100644 index 000000000..00b86fe48 --- /dev/null +++ b/compiler/luci/pass/src/FoldCastPass.cpp @@ -0,0 +1,107 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Pass/FoldCastPass.h" + +#include <luci/IR/CircleNodes.h> + +namespace +{ + +luci::CircleConst *cast_const(luci::CircleConst *node, loco::DataType from_dtype, + loco::DataType to_dtype) +{ + assert(node->dtype() == from_dtype); + + auto name = node->name(); + assert(name.length() > 0); + auto constant = node->graph()->nodes()->create<luci::CircleConst>(); + constant->dtype(to_dtype); + constant->rank(node->rank()); + uint32_t num_elems = 1; + for (uint32_t i = 0; i < node->rank(); i++) + { + constant->dim(i).set(node->dim(i).value()); + num_elems *= node->dim(i).value(); + } + + constant->shape_status(luci::ShapeStatus::VALID); + + // TODO: Support more data types + if (from_dtype == loco::DataType::S64) + { + if (to_dtype == loco::DataType::S32) + { + constant->size<loco::DataType::S32>(num_elems); + for (uint32_t i = 0; i < num_elems; i++) + constant->at<loco::DataType::S32>(i) = + static_cast<int32_t>(node->at<loco::DataType::S64>(i)); + + constant->name(name + "_S32"); + return constant; + } + return nullptr; + } + + return nullptr; +} + +/** + * Fold Cast to const if it has const input + **/ +bool fold_cast(luci::CircleCast *cast) +{ + // Check cast has const input + auto const_x = dynamic_cast<luci::CircleConst *>(cast->x()); + if (not const_x) + return false; + + const auto in_dtype = const_x->dtype(); + const auto out_dtype = cast->dtype(); + + auto casted_const = cast_const(const_x, in_dtype, out_dtype); + if (not casted_const) + return false; + + loco::replace(cast).with(casted_const); + + return true; +} + +} // namespace + +namespace luci +{ + +/** + * Constant Folding for Cast Op + **/ +bool FoldCastPass::run(loco::Graph *g) +{ + bool changed = false; + for (auto node : loco::active_nodes(loco::output_nodes(g))) + { + if (auto cast = dynamic_cast<luci::CircleCast *>(node)) + { + if (fold_cast(cast)) + changed = true; + } + } + + return changed; +} + +} // namespace luci diff --git a/compiler/luci/pass/src/FoldCastPass.test.cpp b/compiler/luci/pass/src/FoldCastPass.test.cpp new file mode 100644 index 000000000..5911adf11 --- /dev/null +++ b/compiler/luci/pass/src/FoldCastPass.test.cpp @@ -0,0 +1,112 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Pass/FoldCastPass.h" +#include "PassTestGraphs.h" + +#include <luci/IR/CircleNodes.h> + +#include <gtest/gtest.h> + +namespace +{ + +template <loco::DataType FromT, loco::DataType ToT> +class FoldCastTest : public luci::ConstantFoldingAddTestGraph +{ +public: + FoldCastTest(std::initializer_list<uint32_t> shape) + : luci::ConstantFoldingAddTestGraph(shape, ToT) + { + _cast = _g.nodes()->create<luci::CircleCast>(); + _x = _g.nodes()->create<luci::CircleConst>(); + + _cast->dtype(ToT); + _x->dtype(FromT); + + _cast->shape(shape); + _x->shape(shape); + + uint32_t num_elems = 1; + for (auto dim = shape.begin(); dim != shape.end(); dim++) + num_elems *= *dim; + + _x->size<FromT>(num_elems); + for (uint32_t i = 0; i < num_elems; i++) + _x->at<FromT>(i) = i + 1; + + _cast->x(_x); + + _cast->name("cast"); + _x->name("x"); + } + + loco::Node *createFoldedPattern() override { return _cast; } + +protected: + luci::CircleCast *_cast = nullptr; + luci::CircleConst *_x = nullptr; +}; + +/** + * Graph that has a Cast Op with constant input + * + * BEFORE + * + * [CircleConst] + * | + * [Cast] + * + * AFTER + * + * [CircleConst] + * + */ +class FoldS64ToS32CastTest : public FoldCastTest<loco::DataType::S64, loco::DataType::S32>, + public ::testing::Test +{ +public: + FoldS64ToS32CastTest() : FoldCastTest<loco::DataType::S64, loco::DataType::S32>({3}) {} + + virtual void SetUp() { init(); } +}; + +} // namespace + +TEST(FoldCastPassTest, name) +{ + luci::FoldCastPass pass; + auto const name = pass.name(); + ASSERT_NE(nullptr, name); +} + +TEST_F(FoldS64ToS32CastTest, fold_cast_s64_to_s32) +{ + luci::FoldCastPass pass; + while (pass.run(graph())) + ; + + auto folded_const = getFoldedPattern(); + EXPECT_NE(nullptr, folded_const); + + // Check type, shape, values of folded const + EXPECT_EQ(loco::DataType::S32, folded_const->dtype()); + EXPECT_EQ(1, folded_const->rank()); + EXPECT_EQ(3, folded_const->dim(0).value()); + EXPECT_EQ(1, folded_const->at<loco::DataType::S32>(0)); + EXPECT_EQ(2, folded_const->at<loco::DataType::S32>(1)); + EXPECT_EQ(3, folded_const->at<loco::DataType::S32>(2)); +} diff --git a/compiler/luci/pass/src/FoldDequantizePass.cpp b/compiler/luci/pass/src/FoldDequantizePass.cpp index 01c04f478..3dd4f8cea 100644 --- a/compiler/luci/pass/src/FoldDequantizePass.cpp +++ b/compiler/luci/pass/src/FoldDequantizePass.cpp @@ -17,8 +17,7 @@ #include "luci/Pass/FoldDequantizePass.h" #include <luci/IR/CircleNodes.h> - -#include <loco/Service/TypeInference.h> +#include <luci/Profile/CircleNodeOrigin.h> namespace { @@ -51,6 +50,8 @@ luci::CircleConst *dequantized_const_node(luci::CircleConst *const_node) throw std::runtime_error("Given constant node has no quantization parameter"); } + auto name = const_node->name(); + assert(name.length() > 0); auto g = const_node->graph(); auto new_const_node = g->nodes()->create<luci::CircleConst>(); @@ -64,6 +65,7 @@ luci::CircleConst *dequantized_const_node(luci::CircleConst *const_node) } new_const_node->size<loco::DataType::FLOAT32>(dim_size); new_const_node->shape_status(luci::ShapeStatus::VALID); + new_const_node->name(name + "_DQ"); const int32_t q_dim = const_node->quantparam()->quantized_dimension; const int32_t q_dim_value = const_node->dim(q_dim).value(); @@ -81,8 +83,8 @@ luci::CircleConst *dequantized_const_node(luci::CircleConst *const_node) qd = 0; new_const_node->at<loco::DataType::FLOAT32>(i) = - (float)(const_node->at<loco::DataType::S8>(i) - const_node->quantparam()->zerop.at(qd)) * - const_node->quantparam()->scale.at(qd); + (float)(const_node->at<loco::DataType::S8>(i) - const_node->quantparam()->zerop.at(qd)) * + const_node->quantparam()->scale.at(qd); } } else @@ -94,9 +96,9 @@ luci::CircleConst *dequantized_const_node(luci::CircleConst *const_node) qd = 0; new_const_node->at<loco::DataType::FLOAT32>(i) = - (float)((int)const_node->at<loco::DataType::U8>(i) - - const_node->quantparam()->zerop.at(qd)) * - const_node->quantparam()->scale.at(qd); + (float)((int)const_node->at<loco::DataType::U8>(i) - + const_node->quantparam()->zerop.at(qd)) * + const_node->quantparam()->scale.at(qd); } } @@ -192,6 +194,8 @@ bool FoldDequantizePass::run(loco::Graph *g) if (replace_const_node(const_node_user, const_node)) { loco::replace(dequant).with(const_node_user); + luci::add_origin(loco::must_cast<luci::CircleNode *>(const_node_user), + luci::get_origin(dequant)); changed = true; } } diff --git a/compiler/luci/service/src/Nodes/CircleOutput.cpp b/compiler/luci/pass/src/FoldDequantizePass.test.cpp index d4c8da2d8..d82a7bc87 100644 --- a/compiler/luci/service/src/Nodes/CircleOutput.cpp +++ b/compiler/luci/pass/src/FoldDequantizePass.test.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,14 +14,13 @@ * limitations under the License. */ -#include <luci/Service/CircleShapeSignatureInference.h> +#include "luci/Pass/FoldDequantizePass.h" -namespace luci -{ +#include <gtest/gtest.h> -ShapeSignature ssinf::Algorithm::visit(const luci::CircleOutput *node) +TEST(FoldDequantizePassTest, name) { - return input_arg_signature(node, 0); + luci::FoldDequantizePass pass; + auto const name = pass.name(); + ASSERT_NE(nullptr, name); } - -} // namespace luci diff --git a/compiler/luci/pass/src/FoldSparseToDensePass.cpp b/compiler/luci/pass/src/FoldSparseToDensePass.cpp new file mode 100644 index 000000000..0c6fc43ed --- /dev/null +++ b/compiler/luci/pass/src/FoldSparseToDensePass.cpp @@ -0,0 +1,140 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Pass/FoldSparseToDensePass.h" +#include "CircleOptimizerUtils.h" + +#include <luci/IR/CircleNodes.h> + +namespace +{ + +/** + * Fold to const if + * + * 1. indices has 0-sized static shape such as [0] + * (i.e., output is filled with default value) + * 2. default_value: const scalar + * 3. output_shape: const + * + * TODO: Support more general patterns + **/ +template <loco::DataType IndexT, loco::DataType ValueT> +bool fold_sparse_to_dense(luci::CircleSparseToDense *stod) +{ + const auto indices = loco::must_cast<luci::CircleNode *>(stod->indices()); + const auto default_value = loco::must_cast<luci::CircleConst *>(stod->default_value()); + const auto output_shape = loco::must_cast<luci::CircleConst *>(stod->output_shape()); + + bool has_zero = false; + for (uint32_t i = 0; i < indices->rank(); i++) + { + if (indices->dim(i).known() && indices->dim(i).value() == 0) + has_zero = true; + } + if (!has_zero) + return false; + + if (default_value->rank() != 0 || default_value->size<ValueT>() != 1) + return false; + + auto rank = output_shape->size<IndexT>(); + std::vector<uint32_t> shape; + for (uint32_t i = 0; i < rank; i++) + { + auto dim = output_shape->at<IndexT>(i); + assert(dim >= 0 && dim <= std::numeric_limits<uint32_t>::max()); + if (!(dim >= 0 && dim <= std::numeric_limits<uint32_t>::max())) + return false; + + shape.push_back(dim); + } + + auto name = stod->name(); + assert(name.length() > 0); + auto constant = stod->graph()->nodes()->create<luci::CircleConst>(); + constant->dtype(default_value->dtype()); + constant->rank(rank); + uint32_t dim_size = 1; + for (uint32_t i = 0; i < rank; i++) + { + constant->dim(i).set(shape[i]); + dim_size *= shape[i]; + } + + constant->size<ValueT>(dim_size); + const auto value = default_value->scalar<ValueT>(); + for (uint32_t i = 0; i < dim_size; i++) + constant->at<ValueT>(i) = value; + + constant->shape_status(luci::ShapeStatus::VALID); + constant->name(name + "_D"); + + loco::replace(stod).with(constant); + + return true; +} + +bool fold_sparse_to_dense(luci::CircleSparseToDense *stod) +{ + auto indices = loco::must_cast<luci::CircleNode *>(stod->indices()); + auto default_value = dynamic_cast<luci::CircleConst *>(stod->default_value()); + if (not default_value) + return false; + + auto output_shape = dynamic_cast<luci::CircleConst *>(stod->output_shape()); + if (not output_shape) + return false; + + // Illegal input check + if (indices->dtype() != output_shape->dtype()) + throw std::runtime_error("indices and output_shape of SparseToDense must have the same dtype"); + + // TODO: Support more data types + if (indices->dtype() == loco::DataType::S64) + { + if (default_value->dtype() == loco::DataType::S64) + { + return fold_sparse_to_dense<loco::DataType::S64, loco::DataType::S64>(stod); + } + } + return false; +} + +} // namespace + +namespace luci +{ + +/** + * Constant Folding for SparseToDense Op + **/ +bool FoldSparseToDensePass::run(loco::Graph *g) +{ + bool changed = false; + for (auto node : loco::active_nodes(loco::output_nodes(g))) + { + if (auto stod = dynamic_cast<luci::CircleSparseToDense *>(node)) + { + if (fold_sparse_to_dense(stod)) + changed = true; + } + } + + return changed; +} + +} // namespace luci diff --git a/compiler/luci/pass/src/FoldSparseToDensePass.test.cpp b/compiler/luci/pass/src/FoldSparseToDensePass.test.cpp new file mode 100644 index 000000000..7c6dcb033 --- /dev/null +++ b/compiler/luci/pass/src/FoldSparseToDensePass.test.cpp @@ -0,0 +1,133 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Pass/FoldSparseToDensePass.h" +#include "PassTestGraphs.h" + +#include <luci/IR/CircleNodes.h> + +#include <gtest/gtest.h> + +namespace +{ + +/** + * Graph that has a SparseToDense Op with zero-sized indices + * + * BEFORE + * - shape of indices: [0,1] + * - output_shape: [3] + * - default_value: scalar 2 + * + * [indices] [output_shape] [values] [default_value] + * | | | | + * +------[SparseToDense]------+ + * + * AFTER + * + * [Const] (shape: [3], values: [2, 2, 2]) + * + */ +class S64SparseToDenseZeroIndicesTest : public luci::ConstantFoldingAddTestGraph, + public ::testing::Test +{ +public: + S64SparseToDenseZeroIndicesTest() : luci::ConstantFoldingAddTestGraph({3}, loco::DataType::S64) {} + + virtual void SetUp() { init(); } + + loco::Node *createFoldedPattern() override + { + _stod = _g.nodes()->create<luci::CircleSparseToDense>(); + _indices = _g.nodes()->create<luci::CircleConst>(); + _output_shape = _g.nodes()->create<luci::CircleConst>(); + _values = _g.nodes()->create<luci::CircleConst>(); + _default_value = _g.nodes()->create<luci::CircleConst>(); + + _stod->dtype(loco::DataType::S64); + _indices->dtype(loco::DataType::S64); + _output_shape->dtype(loco::DataType::S64); + _values->dtype(loco::DataType::S64); + _default_value->dtype(loco::DataType::S64); + + _indices->shape({0, 1}); + _output_shape->shape({1}); + _values->shape({0}); + _default_value->rank(0); + + _indices->size<loco::DataType::S64>(0); + _output_shape->size<loco::DataType::S64>(1); + _output_shape->at<loco::DataType::S64>(0) = 3; + _values->size<loco::DataType::S64>(0); + _default_value->size<loco::DataType::S64>(1); + _default_value->at<loco::DataType::S64>(0) = 2; + + _stod->indices(_indices); + _stod->output_shape(_output_shape); + _stod->values(_values); + _stod->default_value(_default_value); + + _stod->name("stod"); + _indices->name("indices"); + _output_shape->name("output_shape"); + _values->name("values"); + _default_value->name("default_value"); + + return _stod; + } + +protected: + luci::CircleSparseToDense *_stod = nullptr; + luci::CircleConst *_indices = nullptr; + luci::CircleConst *_output_shape = nullptr; + luci::CircleConst *_values = nullptr; + luci::CircleConst *_default_value = nullptr; +}; + +} // namespace + +TEST(FoldSparseToDensePassTest, name) +{ + luci::FoldSparseToDensePass pass; + auto const name = pass.name(); + ASSERT_NE(nullptr, name); +} + +TEST_F(S64SparseToDenseZeroIndicesTest, fold_stod_with_zero_indices) +{ + luci::FoldSparseToDensePass pass; + while (pass.run(graph())) + ; + + auto folded_const = getFoldedPattern(); + EXPECT_NE(nullptr, folded_const); + + // Chec type, shape, values of folded const + EXPECT_EQ(loco::DataType::S64, folded_const->dtype()); + EXPECT_EQ(1, folded_const->rank()); + EXPECT_EQ(3, folded_const->dim(0).value()); + EXPECT_EQ(2, folded_const->at<loco::DataType::S64>(0)); + EXPECT_EQ(2, folded_const->at<loco::DataType::S64>(1)); + EXPECT_EQ(2, folded_const->at<loco::DataType::S64>(2)); +} + +TEST_F(S64SparseToDenseZeroIndicesTest, illegal_input_NEG) +{ + _indices->dtype(loco::DataType::S32); + + luci::FoldSparseToDensePass pass; + EXPECT_ANY_THROW(pass.run(graph())); +} diff --git a/compiler/luci/pass/src/ForwardReshapeToUnaryOpPass.cpp b/compiler/luci/pass/src/ForwardReshapeToUnaryOpPass.cpp new file mode 100644 index 000000000..2c990f0a5 --- /dev/null +++ b/compiler/luci/pass/src/ForwardReshapeToUnaryOpPass.cpp @@ -0,0 +1,154 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Pass/ForwardReshapeToUnaryOpPass.h" + +#include <luci/IR/CircleNodes.h> +#include <luci/IR/CircleNodeVisitor.h> +#include <luci/Log.h> +#include <luci/Profile/CircleNodeOrigin.h> +#include <luci/Service/CircleShapeInference.h> +#include <luci/Service/Nodes/CircleConst.h> + +namespace +{ + +luci::CircleReshape *as_reshape(loco::Node *node) +{ + return dynamic_cast<luci::CircleReshape *>(node); +} + +luci::CircleConst *clone_shape(luci::CircleReshape *reshape) +{ + const auto shape = dynamic_cast<luci::CircleConst *>(reshape->shape()); + // only support CircleConst for now + if (shape == nullptr) + return nullptr; + + // NOTE tflite and circle only supports S32 + // TODO just check with assert() after import handles this + auto dtype = shape->dtype(); + if (dtype != loco::DataType::S32) + return nullptr; + + return luci::clone(shape); +} + +void copy_shape(luci::CircleReshape *reshape, luci::CircleReshape *new_reshape) +{ + auto ns_rank = reshape->newShape()->rank(); + new_reshape->newShape()->rank(ns_rank); + for (uint32_t r = 0; r < ns_rank; ++r) + new_reshape->newShape()->dim(r) = reshape->newShape()->dim(r); +} + +bool forward_reshape(luci::CircleReshape *reshape, luci::CircleNeg *neg) +{ + assert(reshape != nullptr); + assert(neg != nullptr); + + luci::CircleConst *cloned_shape = clone_shape(reshape); + if (cloned_shape == nullptr) + return false; + + auto name = reshape->name(); + assert(name.length() > 0); + loco::Graph *graph = neg->graph(); + // create reshape placed after neg + luci::CircleReshape *new_reshape = graph->nodes()->create<luci::CircleReshape>(); + copy_shape(reshape, new_reshape); + new_reshape->shape(cloned_shape); + new_reshape->name(name + "_C"); + luci::add_origin(new_reshape, luci::get_origin(reshape)); + + // reconnect network + loco::replace(neg).with(new_reshape); + neg->x(reshape->tensor()); + new_reshape->tensor(neg); + + // Do shape inference for this node again. + neg->shape_status(luci::ShapeStatus::UNDEFINED); + + return true; +} + +class ForwardReshape final : public luci::CircleNodeMutableVisitor<bool> +{ +protected: + bool visit(luci::CircleNode *node) + { + LOGGER(l); + INFO(l) << "ForwardReshape: Unsupported operator: " << node->name() << std::endl; + return false; + } + + bool visit(luci::CircleNeg *node) + { + auto reshape = as_reshape(node->x()); + if (reshape == nullptr) + return false; + return forward_reshape(reshape, node); + } + + // TODO add more unary operators +}; + +} // namespace + +namespace luci +{ + +/** + * BEFORE + * | + * [CircleNode] [CircleConst] + * | / + * [CircleReshape] + * / | + * [CircleNode] [(UnaryOp)] + * | | \ + * | | [CircleNode] + * | | | + * + * UnaryOp: CircleNeg, ... + * + * AFTER + * | + * [CircleConst] [CircleNode] + * | / | + * [CircleReshape] [(UnaryOp)] [CircleConst] + * | | / + * [CircleNode] [CircleReshape] + * | | \ + * | | [CircleNode] + * | | | + * + * Note: new [CircleReshape] after [(UnaryOp)] added + */ +bool ForwardReshapeToUnaryOpPass::run(loco::Graph *g) +{ + bool changed = false; + ForwardReshape forward; + for (auto node : loco::active_nodes(loco::output_nodes(g))) + { + auto circle_node = loco::must_cast<luci::CircleNode *>(node); + if (circle_node->accept(&forward)) + changed = true; + } + return changed; +} + +} // namespace luci diff --git a/compiler/luci/pass/src/ForwardReshapeToUnaryOpPass.test.cpp b/compiler/luci/pass/src/ForwardReshapeToUnaryOpPass.test.cpp new file mode 100644 index 000000000..2593a014c --- /dev/null +++ b/compiler/luci/pass/src/ForwardReshapeToUnaryOpPass.test.cpp @@ -0,0 +1,125 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Pass/ForwardReshapeToUnaryOpPass.h" +#include "luci/Pass/CircleShapeInferencePass.h" + +#include <luci/IR/CircleNodes.h> + +#include <luci/test/TestIOGraph.h> + +#include <gtest/gtest.h> + +#include <vector> + +namespace +{ + +using namespace luci::test; + +class ReshapeNegGraphlet +{ +public: + ReshapeNegGraphlet() = default; + +public: + void init(loco::Graph *g, const ShapeU32 shape_in, const ShapeU32 shape_out) + { + std::vector<uint32_t> shape_out_v = shape_out; + + _reshape_shape = g->nodes()->create<luci::CircleConst>(); + _reshape = g->nodes()->create<luci::CircleReshape>(); + _neg = g->nodes()->create<luci::CircleNeg>(); + + _reshape_shape->dtype(loco::DataType::S32); + _reshape_shape->rank(1); + _reshape_shape->dim(0).set(shape_out_v.size()); + _reshape_shape->shape_status(luci::ShapeStatus::VALID); + // values + const auto size = shape_out_v.size(); + _reshape_shape->size<loco::DataType::S32>(size); + for (uint32_t i = 0; i < size; i++) + _reshape_shape->at<loco::DataType::S32>(i) = shape_out_v[i]; + + _reshape_shape->name("reshape_shape"); + _reshape->name("reshape"); + _neg->name("neg"); + } + +protected: + luci::CircleReshape *_reshape = nullptr; + luci::CircleNeg *_neg = nullptr; + luci::CircleConst *_reshape_shape = nullptr; +}; + +class ForwardReshapeToNegGraph : public TestIOGraph, public ReshapeNegGraphlet +{ +public: + ForwardReshapeToNegGraph() = default; + +public: + void init(const ShapeU32 shape_in, const ShapeU32 shape_out) + { + TestIOGraph::init(shape_in, shape_out); + ReshapeNegGraphlet::init(g(), shape_in, shape_out); + + // connect network + _reshape->tensor(input()); + _reshape->shape(_reshape_shape); + _neg->x(_reshape); + + output()->from(_neg); + } +}; + +class ForwardReshapeToNegGraphTest : public ::testing::Test +{ +public: + ForwardReshapeToNegGraphTest() = default; + + void run_pass(void) + { + while (_pass.run(_graph.g())) + ; + } + +protected: + ForwardReshapeToNegGraph _graph; + luci::ForwardReshapeToUnaryOpPass _pass; +}; + +} // namespace + +TEST(ForwardReshapeToUnaryOpPassTest, name) +{ + luci::ForwardReshapeToUnaryOpPass pass; + auto const name = pass.name(); + ASSERT_NE(nullptr, name); +} + +TEST_F(ForwardReshapeToNegGraphTest, simple_forward) +{ + _graph.init({2, 2, 2}, {2, 4}); + + run_pass(); + + auto reshape = dynamic_cast<luci::CircleReshape *>(_graph.output()->from()); + auto neg = dynamic_cast<luci::CircleNeg *>(_graph.output()->from()); + ASSERT_NE(nullptr, reshape); + ASSERT_EQ(nullptr, neg); + neg = dynamic_cast<luci::CircleNeg *>(reshape->tensor()); + ASSERT_NE(nullptr, neg); +} diff --git a/compiler/luci/pass/src/FuseActivationFunctionPass.cpp b/compiler/luci/pass/src/FuseActivationFunctionPass.cpp index 844541d2d..66e341518 100644 --- a/compiler/luci/pass/src/FuseActivationFunctionPass.cpp +++ b/compiler/luci/pass/src/FuseActivationFunctionPass.cpp @@ -17,7 +17,9 @@ #include "luci/Pass/FuseActivationFunctionPass.h" #include <luci/IR/CircleNodes.h> +#include <luci/IR/CircleNodeMixins.h> #include <luci/IR/CircleOpcode.h> +#include <luci/Profile/CircleNodeOrigin.h> namespace luci { @@ -32,10 +34,15 @@ bool fuse_activation_function(luci::CircleNode *node) return false; auto node_with_fused_act = - dynamic_cast<luci::LuciNodeMixin<luci::LuciNodeTrait::FusedActFunc> *>(pred_node); + dynamic_cast<luci::CircleNodeMixin<luci::CircleNodeTrait::FusedActFunc> *>(pred_node); if (node_with_fused_act == nullptr) return false; + // TODO remove this work-around + // This will skip fuse for concat as luci-interpreter doesn't support this yet + if (dynamic_cast<luci::CircleConcatenation *>(pred_node) != nullptr) + return false; + auto fused_act = node_with_fused_act->fusedActivationFunction(); luci::FusedActFunc target_func = luci::FusedActFunc::UNDEFINED; @@ -76,6 +83,7 @@ bool fuse_activation_function(luci::CircleNode *node) return false; node_with_fused_act->fusedActivationFunction(target_func); + luci::add_origin(pred_node, luci::get_origin(node)); loco::replace(node).with(pred_node); node->drop(); diff --git a/compiler/luci/pass/src/FuseActivationFunctionPass.test.cpp b/compiler/luci/pass/src/FuseActivationFunctionPass.test.cpp index 226a303a1..56b414143 100644 --- a/compiler/luci/pass/src/FuseActivationFunctionPass.test.cpp +++ b/compiler/luci/pass/src/FuseActivationFunctionPass.test.cpp @@ -14,15 +14,19 @@ * limitations under the License. */ -#include "FuseActivationFunctionPassInternal.h" +#include "luci/Pass/FuseActivationFunctionPass.h" #include <luci/IR/CircleNodes.h> +#include <luci/test/TestIOGraph.h> + #include <gtest/gtest.h> namespace { +using namespace luci::test; + /** * Simple graph for test * @@ -41,60 +45,148 @@ namespace * [Conv2] * */ -class SimpleGraph +class ConvReluConvGraphlet +{ +public: + ConvReluConvGraphlet() = default; + + void init(loco::Graph *g) + { + _conv1 = g->nodes()->create<luci::CircleConv2D>(); + _conv2 = g->nodes()->create<luci::CircleConv2D>(); + _relu = g->nodes()->create<luci::CircleRelu>(); + _conv1_f = g->nodes()->create<luci::CircleConst>(); + _conv1_b = g->nodes()->create<luci::CircleConst>(); + _conv2_f = g->nodes()->create<luci::CircleConst>(); + _conv2_b = g->nodes()->create<luci::CircleConst>(); + + _conv1->fusedActivationFunction(luci::FusedActFunc::NONE); + + _conv1->name("conv1"); + _conv2->name("conv2"); + _relu->name("relu"); + _conv1_f->name("conv1f"); + _conv1_b->name("conv1b"); + _conv2_f->name("conv2f"); + _conv2_b->name("conv2b"); + } + +public: + luci::CircleRelu *relu() { return _relu; } + luci::CircleConv2D *conv1() { return _conv1; } + luci::CircleConv2D *conv2() { return _conv2; } + +protected: + luci::CircleConv2D *_conv1 = nullptr; + luci::CircleConv2D *_conv2 = nullptr; + luci::CircleRelu *_relu = nullptr; + luci::CircleConst *_conv1_f = nullptr; + luci::CircleConst *_conv1_b = nullptr; + luci::CircleConst *_conv2_f = nullptr; + luci::CircleConst *_conv2_b = nullptr; +}; + +class FuseActTestGraph : public TestIOGraph, public ConvReluConvGraphlet { public: - SimpleGraph() + FuseActTestGraph() = default; + + void init(void) { - conv1 = g.nodes()->create<luci::CircleConv2D>(); - conv2 = g.nodes()->create<luci::CircleConv2D>(); - relu = g.nodes()->create<luci::CircleRelu>(); + TestIOGraph::init({1}, {1}); + ConvReluConvGraphlet::init(g()); - conv1->fusedActivationFunction(luci::FusedActFunc::NONE); + _conv1->input(input()); + _conv1->filter(_conv1_f); + _conv1->bias(_conv1_b); - relu->features(conv1); - conv2->input(relu); + _relu->features(_conv1); + + _conv2->input(_relu); + _conv2->filter(_conv2_f); + _conv2->bias(_conv2_b); + + output()->from(_conv2); } +}; +class ConvHasMultiSuccGraph : public TestIOGraph, public ConvReluConvGraphlet +{ public: - loco::Graph g; - luci::CircleConv2D *conv1; - luci::CircleConv2D *conv2; - luci::CircleRelu *relu; + ConvHasMultiSuccGraph() = default; + + void init(void) + { + TestIOGraph::init({1}, {1}); + ConvReluConvGraphlet::init(g()); + + _conv1->input(input()); + _conv1->filter(_conv1_f); + _conv1->bias(_conv1_b); + + _relu->features(_conv1); + + _conv2->input(_conv1); + _conv2->filter(_conv2_f); + _conv2->bias(_conv2_b); + + output()->from(_relu); // We need to check from relu + } }; +// TODO use ::testing::Test + } // namespace +TEST(FuseActivationFunctionPassTest, name) +{ + luci::FuseActivationFunctionPass pass; + auto const name = pass.name(); + ASSERT_NE(nullptr, name); +} + TEST(FusePreActivationBatchNorm, fuse_activation_function) { - SimpleGraph g; + FuseActTestGraph g; + luci::FuseActivationFunctionPass pass; - EXPECT_TRUE(luci::fuse_activation_function(g.relu)); + g.init(); - EXPECT_EQ(g.conv1, g.conv2->input()); + EXPECT_TRUE(pass.run(g.g())); + EXPECT_EQ(g.conv1(), g.conv2()->input()); } TEST(FusePreActivationBatchNorm, fuse_activation_function_dup_relu) { - SimpleGraph g; - g.conv1->fusedActivationFunction(luci::FusedActFunc::RELU); + FuseActTestGraph g; + luci::FuseActivationFunctionPass pass; - EXPECT_TRUE(luci::fuse_activation_function(g.relu)); + g.init(); + g.conv1()->fusedActivationFunction(luci::FusedActFunc::RELU); - EXPECT_EQ(g.conv1, g.conv2->input()); + EXPECT_TRUE(pass.run(g.g())); + EXPECT_EQ(g.conv1(), g.conv2()->input()); } -TEST(FusePreActivationBatchNorm, fuse_activation_function_NEG) +TEST(FusePreActivationBatchNorm, fuse_activation_function_mulsucc_NEG) { - SimpleGraph g; - g.conv2->input(g.conv1); + ConvHasMultiSuccGraph g; + luci::FuseActivationFunctionPass pass; + + g.init(); - // Conv1 has multiple successors - EXPECT_FALSE(luci::fuse_activation_function(g.relu)); + // Relu input Conv2D has multiple successors + EXPECT_FALSE(pass.run(g.g())); +} + +TEST(FusePreActivationBatchNorm, fuse_activation_function_tanh_NEG) +{ + FuseActTestGraph g; + luci::FuseActivationFunctionPass pass; - g.conv2->input(g.relu); - g.conv1->fusedActivationFunction(luci::FusedActFunc::TANH); + g.init(); + g.conv1()->fusedActivationFunction(luci::FusedActFunc::TANH); - // Conv1 already has activation function - EXPECT_FALSE(luci::fuse_activation_function(g.relu)); + // Relu input Conv2D already has activation function + EXPECT_FALSE(pass.run(g.g())); } diff --git a/compiler/luci/pass/src/FuseAddWithTConvPass.cpp b/compiler/luci/pass/src/FuseAddWithTConvPass.cpp index bd7805f6a..2bca57014 100644 --- a/compiler/luci/pass/src/FuseAddWithTConvPass.cpp +++ b/compiler/luci/pass/src/FuseAddWithTConvPass.cpp @@ -17,20 +17,30 @@ #include "luci/Pass/FuseAddWithTConvPass.h" #include <luci/IR/CircleNodes.h> +#include <luci/Profile/CircleNodeOrigin.h> namespace { /** - * Fuse add to TCONV if possible + * Fuse Add to TransposeConv if possible * * BEFORE - * - * [CircleTransposeConv] + * | + * [CircleConst] [CircleTransposeConv] + * \ | + * [CircleAdd] * | - * [add] + * * AFTER + * | + * [CircleConst] | + * \ | + * [CircleTransposeConv] [CircleAdd] + * | + * ([CircleRelu6]) + * | * - * [CircleTransposeConv] + * Note: CircleRelu6 is inserted if Add activation is ReLU6 */ bool fuse_add_with_tconv(luci::CircleTransposeConv *tconv) { @@ -81,9 +91,13 @@ bool fuse_add_with_tconv(luci::CircleTransposeConv *tconv) if (add->fusedActivationFunction() == luci::FusedActFunc::RELU6) { + auto name = addition->name(); + assert(name.length() > 0); // separate relu op from add op auto relu = add->graph()->nodes()->create<luci::CircleRelu6>(); relu->features(tconv); + relu->name(name + "/Relu6"); + luci::add_origin(relu, luci::get_origin(add)); // remove add node replace(add).with(relu); @@ -93,6 +107,9 @@ bool fuse_add_with_tconv(luci::CircleTransposeConv *tconv) replace(add).with(tconv); } + // set origin + luci::add_origin(tconv, luci::get_origin(add)); + return true; } diff --git a/compiler/luci/pass/src/FuseAddWithTConvPass.test.cpp b/compiler/luci/pass/src/FuseAddWithTConvPass.test.cpp new file mode 100644 index 000000000..8748d73ef --- /dev/null +++ b/compiler/luci/pass/src/FuseAddWithTConvPass.test.cpp @@ -0,0 +1,26 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Pass/FuseAddWithTConvPass.h" + +#include <gtest/gtest.h> + +TEST(FuseAddWithTConvPassTest, name) +{ + luci::FuseAddWithTConvPass pass; + auto const name = pass.name(); + ASSERT_NE(nullptr, name); +} diff --git a/compiler/luci/pass/src/FuseBCQPass.cpp b/compiler/luci/pass/src/FuseBCQPass.cpp index c0583d848..09180d8c1 100644 --- a/compiler/luci/pass/src/FuseBCQPass.cpp +++ b/compiler/luci/pass/src/FuseBCQPass.cpp @@ -17,6 +17,7 @@ #include "luci/Pass/FuseBCQPass.h" #include <luci/IR/CircleNodes.h> +#include <luci/Profile/CircleNodeOrigin.h> #include <luci/Log.h> #include <cassert> @@ -111,7 +112,7 @@ template <> class BCQFuser<1> { public: BCQFuser<1>(int32_t original_output_cnt, int32_t bundle_cnt) - : _original_output_cnt{original_output_cnt}, _bundle_cnt{bundle_cnt} + : _original_output_cnt{original_output_cnt}, _bundle_cnt{bundle_cnt} { // Do nothing } @@ -133,7 +134,7 @@ public: { const auto prefix = (output_node->index() - (_original_output_cnt + 1)) / (_bundle_cnt); const MetadataType metadata_type = static_cast<MetadataType>( - (output_node->index() - (_original_output_cnt + 1)) % (_bundle_cnt)); + (output_node->index() - (_original_output_cnt + 1)) % (_bundle_cnt)); const auto circle_node = loco::must_cast<luci::CircleNode *>(output_node->from()); add_BCQ_info_node(prefix, metadata_type, circle_node); } @@ -156,13 +157,18 @@ public: if (prefix == -1 || !is_valid_prefix(prefix)) continue; + auto name = gather->name(); + assert(name.length() > 0); + auto bcq_gather = g->nodes()->create<luci::CircleBCQGather>(); + luci::add_origin(bcq_gather, luci::get_origin(gather)); bcq_gather->op_version(1); bcq_gather->input_scales(alpha(g, prefix)); bcq_gather->input_binary(packed_binary_code(g, prefix)); bcq_gather->indices(gather->indices()); bcq_gather->input_clusters(packed_clusters(g, prefix)); + bcq_gather->name(name + "/BCQGather"); if (_do_w_x[prefix]->at<loco::DataType::BOOL>(0)) { @@ -177,7 +183,7 @@ public: bcq_gather->axis(axis_transpose); const auto indices_rank = - loco::must_cast<luci::CircleNode *>(gather->indices())->rank(); + loco::must_cast<luci::CircleNode *>(gather->indices())->rank(); auto perm = g->nodes()->create<luci::CircleConst>(); perm->dtype(loco::DataType::S32); @@ -188,10 +194,13 @@ public: perm->at<loco::DataType::S32>(idx) = idx + 1; perm->at<loco::DataType::S32>(indices_rank) = 0; perm->shape_status(luci::ShapeStatus::VALID); + perm->name(name + "/Transpose/perm"); auto output_transpose = g->nodes()->create<luci::CircleTranspose>(); + luci::add_origin(output_transpose, luci::get_origin(gather)); output_transpose->a(bcq_gather); output_transpose->perm(perm); + output_transpose->name(name + "/Transpose"); loco::replace(gather).with(output_transpose); } @@ -209,7 +218,11 @@ public: if (prefix == -1 || !is_valid_prefix(prefix)) continue; + auto name = fully_connected->name(); + assert(name.length() > 0); + auto bcq_fc = g->nodes()->create<luci::CircleBCQFullyConnected>(); + luci::add_origin(bcq_fc, luci::get_origin(fully_connected)); bcq_fc->op_version(1); bcq_fc->weights_scales(alpha(g, prefix)); @@ -217,6 +230,7 @@ public: bcq_fc->bias(fully_connected->bias()); bcq_fc->weights_clusters(packed_clusters(g, prefix)); bcq_fc->fusedActivationFunction(fully_connected->fusedActivationFunction()); + bcq_fc->name(name + "/BCQFullyConnected"); loco::Node *bcq_input = fully_connected->input(); @@ -231,18 +245,16 @@ public: new_shape->rank(1); new_shape->dim(0) = 2; - auto batch_size = 1; - for (uint32_t i = 0; i < original_input->rank() - 1; ++i) - batch_size *= original_input->dim(i).value(); - - new_shape->at<loco::DataType::S32>(0) = batch_size; - new_shape->at<loco::DataType::S32>(1) = - original_input->dim(original_input->rank() - 1).value(); + new_shape->at<loco::DataType::S32>(0) = -1; + new_shape->at<loco::DataType::S32>(1) = weights->dim(1).value(); new_shape->shape_status(luci::ShapeStatus::VALID); + new_shape->name(name + "/Reshape/shape"); auto reshape = g->nodes()->create<luci::CircleReshape>(); + luci::add_origin(reshape, luci::get_origin(fully_connected)); reshape->tensor(original_input); reshape->shape(new_shape); + reshape->name(name + "/Reshape"); bcq_input = reshape; } @@ -258,23 +270,28 @@ public: perm->at<loco::DataType::S32>(0) = 1; perm->at<loco::DataType::S32>(1) = 0; perm->shape_status(luci::ShapeStatus::VALID); + perm->name(name + "/Transpose/perm"); auto input_transpose = g->nodes()->create<luci::CircleTranspose>(); + luci::add_origin(input_transpose, luci::get_origin(fully_connected)); input_transpose->a(bcq_input); input_transpose->perm(perm); + input_transpose->name(name + "_input/Transpose"); bcq_fc->input(input_transpose); auto output_transpose = g->nodes()->create<luci::CircleTranspose>(); + luci::add_origin(output_transpose, luci::get_origin(fully_connected)); output_transpose->a(bcq_fc); output_transpose->perm(perm); + output_transpose->name(name + "_output/Transpose"); loco::replace(fully_connected).with(output_transpose); return true; } else if (auto weights_as_input = - dynamic_cast<luci::CircleConst *>(fully_connected->input())) + dynamic_cast<luci::CircleConst *>(fully_connected->input())) { auto prefix = get_prefix_of_const(weights_as_input); if (prefix == -1 || !is_valid_prefix(prefix)) @@ -282,6 +299,9 @@ public: assert(_do_w_x[prefix]->at<loco::DataType::BOOL>(0) == true); + auto name = weights_as_input->name(); + assert(name.length() > 0); + auto perm = g->nodes()->create<luci::CircleConst>(); perm->dtype(loco::DataType::S32); perm->size<loco::DataType::S32>(2); @@ -290,12 +310,16 @@ public: perm->at<loco::DataType::S32>(0) = 1; perm->at<loco::DataType::S32>(1) = 0; perm->shape_status(luci::ShapeStatus::VALID); + perm->name(name + "/Transpose/perm"); auto input_transpose = g->nodes()->create<luci::CircleTranspose>(); + luci::add_origin(input_transpose, luci::get_origin(fully_connected)); input_transpose->a(fully_connected->weights()); input_transpose->perm(perm); + input_transpose->name(name + "/Transpose"); auto bcq_fc = g->nodes()->create<luci::CircleBCQFullyConnected>(); + luci::add_origin(bcq_fc, luci::get_origin(fully_connected)); assert(dynamic_cast<luci::CircleOutputExclude *>(fully_connected->bias()) != nullptr); @@ -308,6 +332,8 @@ public: bcq_fc->weights_hidden_size(weights_as_input->dim(1).value()); bcq_fc->input(input_transpose); + bcq_fc->name(name + "/BCQFullyConnected"); + loco::replace(fully_connected).with(bcq_fc); return true; @@ -533,7 +559,7 @@ private: new_beta->dim(1) = _packed_binary_code[prefix]->dim(1); for (uint32_t i = 0; i < _packed_binary_code[prefix]->size<loco::DataType::S32>(); ++i) new_beta->at<loco::DataType::S32>(i) = - _packed_binary_code[prefix]->at<loco::DataType::S32>(i); + _packed_binary_code[prefix]->at<loco::DataType::S32>(i); new_beta->shape_status(luci::ShapeStatus::VALID); return new_beta; @@ -556,9 +582,9 @@ private: for (int i = 0; i < number_of_clusters; ++i) { packed_clusters->at<loco::DataType::S32>(i * 2) = - qbits_of_clusters->at<loco::DataType::S32>(i); + qbits_of_clusters->at<loco::DataType::S32>(i); packed_clusters->at<loco::DataType::S32>(i * 2 + 1) = - size_of_clusters->at<loco::DataType::S32>(i); + size_of_clusters->at<loco::DataType::S32>(i); } return packed_clusters; diff --git a/compiler/luci/pass/src/FuseBCQPass.test.cpp b/compiler/luci/pass/src/FuseBCQPass.test.cpp new file mode 100644 index 000000000..73677affd --- /dev/null +++ b/compiler/luci/pass/src/FuseBCQPass.test.cpp @@ -0,0 +1,26 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Pass/FuseBCQPass.h" + +#include <gtest/gtest.h> + +TEST(FuseBCQPassTest, name) +{ + luci::FuseBCQPass pass; + auto const name = pass.name(); + ASSERT_NE(nullptr, name); +} diff --git a/compiler/luci/pass/src/FuseBatchNormWithConvPass.cpp b/compiler/luci/pass/src/FuseBatchNormWithConvPass.cpp new file mode 100644 index 000000000..062da7058 --- /dev/null +++ b/compiler/luci/pass/src/FuseBatchNormWithConvPass.cpp @@ -0,0 +1,232 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Pass/FuseBatchNormWithConvPass.h" + +#include <luci/IR/CircleNodes.h> +#include <luci/Profile/CircleNodeOrigin.h> + +namespace +{ +/** + * Fuse Mul-Add to Conv2D if possible. + * + * NOTE TF's BatchNormalization is converted to Mul and Add. + * + * BEFORE + * | [CircleConst] + * | / [CircleConst] + * | / / + * [CircleConv2D] [CircleConst] + * | / + * [CircleMul] [CircleConst] + * | / + * [CircleAdd] + * | + * + * AFTER + * | [CircleConst] + * +--------------+ / [CircleConst] + * | | / / + * | [CircleConv2D] [CircleConst] + * [CircleConst] | | / + * [CircleConst] \ | [CircleMul] [CircleConst] + * \ \ | | / + * [CircleConv2D] [CircleAdd] + * | + */ +bool fused_batch_norm_with_conv(luci::CircleAdd *add) +{ + luci::CircleMul *mul = nullptr; + luci::CircleConst *shift = nullptr; + if (auto add_lhs = dynamic_cast<luci::CircleMul *>(add->x())) + { + mul = add_lhs; + shift = dynamic_cast<luci::CircleConst *>(add->y()); + } + else if (auto add_rhs = dynamic_cast<luci::CircleMul *>(add->y())) + { + mul = add_rhs; + shift = dynamic_cast<luci::CircleConst *>(add->x()); + } + + // If CircleMul is not found or constant operand of CircleAdd is not found, + // this pass cannot be applied. + if (mul == nullptr || shift == nullptr) + return false; + + // If FusedActivationFunction of mul is not none, this pass cannot be applied. + if (mul->fusedActivationFunction() != luci::FusedActFunc::NONE) + return false; + + // To apply this pass, shape of shift should be [1, 1, 1, out_channel]. + if (shift->rank() != 4) + return false; + for (uint32_t i = 0; i < 3; ++i) + if (shift->dim(i).value() != 1) + return false; + + luci::CircleConv2D *conv = nullptr; + luci::CircleConst *scale = nullptr; + if (auto mul_lhs = dynamic_cast<luci::CircleConv2D *>(mul->x())) + { + conv = mul_lhs; + scale = dynamic_cast<luci::CircleConst *>(mul->y()); + } + else if (auto mul_rhs = dynamic_cast<luci::CircleConv2D *>(mul->y())) + { + conv = mul_rhs; + scale = dynamic_cast<luci::CircleConst *>(mul->x()); + } + + // If CircleConv2D is not found or constant operand of CircleMul is not found, + // this pass cannot be applied. + if (conv == nullptr || scale == nullptr) + return false; + + // To apply this pass, shape of scale should be [1, 1, 1, out_channel]. + if (scale->rank() != 4) + return false; + for (uint32_t i = 0; i < 3; ++i) + if (scale->dim(i).value() != 1) + return false; + + // If FusedActivationFunction of conv is not none, this pass cannot be applied. + if (conv->fusedActivationFunction() != luci::FusedActFunc::NONE) + return false; + + luci::CircleConst *filter = dynamic_cast<luci::CircleConst *>(conv->filter()); + luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(conv->bias()); + + // If filter or bias of conv is not const, this pass cannot be applied. + if (filter == nullptr || bias == nullptr) + return false; + + // If dtype of filter is different with scale and shift, multiplication may be impossible. + if (filter->dtype() != scale->dtype()) + return false; + if (filter->dtype() != shift->dtype()) + return false; + + // TODO Support more data type + if (filter->dtype() != loco::DataType::FLOAT32) + return false; + + // Output channel dimension should be same. If not, this pass cannot be applied. + if (filter->dim(0).value() != scale->dim(3).value()) + return false; + if (filter->dim(0).value() != shift->dim(3).value()) + return false; + + auto name = add->name(); + assert(name.length() > 0); + + luci::CircleConv2D *fused_conv = add->graph()->nodes()->create<luci::CircleConv2D>(); + luci::CircleConst *fused_filter = add->graph()->nodes()->create<luci::CircleConst>(); + luci::CircleConst *fused_bias = add->graph()->nodes()->create<luci::CircleConst>(); + + uint32_t filter_out_channel = filter->dim(0).value(); + uint32_t filter_height = filter->dim(1).value(); + uint32_t filter_width = filter->dim(2).value(); + uint32_t filter_in_channel = filter->dim(3).value(); + + // Copy filter + fused_filter->dtype(filter->dtype()); + fused_filter->size<loco::DataType::FLOAT32>(filter->size<loco::DataType::FLOAT32>()); + fused_filter->rank(4); + fused_filter->dim(0).set(filter_out_channel); + fused_filter->dim(1).set(filter_height); + fused_filter->dim(2).set(filter_width); + fused_filter->dim(3).set(filter_in_channel); + fused_filter->shape_status(luci::ShapeStatus::VALID); + fused_filter->name(name + "/Conv2D/filter"); + + // Fuse scale to new filter + for (uint32_t c = 0; c < filter_out_channel; c++) + { + for (uint32_t h = 0; h < filter_height; h++) + { + for (uint32_t w = 0; w < filter_width; w++) + { + for (uint32_t b = 0; b < filter_in_channel; b++) + { + uint32_t offset = c * filter_height * filter_width * filter_in_channel + + h * filter_width * filter_in_channel + w * filter_in_channel + b; + fused_filter->at<loco::DataType::FLOAT32>(offset) = + filter->at<loco::DataType::FLOAT32>(offset) * scale->at<loco::DataType::FLOAT32>(c); + } + } + } + } + + // Copy bias + assert(bias->rank() == 1); + assert(bias->dim(0).value() == filter_out_channel); + fused_bias->dtype(bias->dtype()); + fused_bias->size<loco::DataType::FLOAT32>(bias->size<loco::DataType::FLOAT32>()); + fused_bias->rank(1); + fused_bias->dim(0).set(filter_out_channel); + fused_bias->shape_status(luci::ShapeStatus::VALID); + fused_bias->name(name + "/Conv2D/bias"); + + // Fuse scale and shift to bias + for (uint32_t b = 0; b < filter_out_channel; ++b) + { + fused_bias->at<loco::DataType::FLOAT32>(b) = + bias->at<loco::DataType::FLOAT32>(b) * scale->at<loco::DataType::FLOAT32>(b) + + shift->at<loco::DataType::FLOAT32>(b); + } + + // Set attributes of fused_conv + fused_conv->input(conv->input()); + fused_conv->filter(fused_filter); + fused_conv->bias(fused_bias); + fused_conv->fusedActivationFunction(add->fusedActivationFunction()); + fused_conv->padding(conv->padding()); + fused_conv->stride()->h(conv->stride()->h()); + fused_conv->stride()->w(conv->stride()->w()); + fused_conv->dilation()->h(conv->dilation()->h()); + fused_conv->dilation()->w(conv->dilation()->w()); + fused_conv->name(name + "/Conv2D"); + luci::add_origin(fused_conv, luci::composite_origin({luci::get_origin(add), luci::get_origin(mul), + luci::get_origin(conv)})); + + replace(add).with(fused_conv); + + return true; +} + +} // namespace + +namespace luci +{ + +bool FuseBatchNormWithConvPass::run(loco::Graph *g) +{ + bool changed = false; + for (auto node : loco::active_nodes(loco::output_nodes(g))) + { + if (auto add = dynamic_cast<luci::CircleAdd *>(node)) + { + if (fused_batch_norm_with_conv(add)) + changed = true; + } + } + + return changed; +} + +} // namespace luci diff --git a/compiler/luci/pass/src/FuseBatchNormWithConvPass.test.cpp b/compiler/luci/pass/src/FuseBatchNormWithConvPass.test.cpp new file mode 100644 index 000000000..96bc2bd35 --- /dev/null +++ b/compiler/luci/pass/src/FuseBatchNormWithConvPass.test.cpp @@ -0,0 +1,26 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Pass/FuseBatchNormWithConvPass.h" + +#include <gtest/gtest.h> + +TEST(FuseBatchNormWithConvPassTest, name) +{ + luci::FuseBatchNormWithConvPass pass; + auto const name = pass.name(); + ASSERT_NE(nullptr, name); +} diff --git a/compiler/luci/pass/src/FuseBatchNormWithDwConvPass.cpp b/compiler/luci/pass/src/FuseBatchNormWithDwConvPass.cpp new file mode 100644 index 000000000..8b2286f43 --- /dev/null +++ b/compiler/luci/pass/src/FuseBatchNormWithDwConvPass.cpp @@ -0,0 +1,237 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Pass/FuseBatchNormWithDwConvPass.h" + +#include "helpers/NodeFiller.h" + +#include <luci/IR/CircleNodes.h> +#include <luci/Profile/CircleNodeOrigin.h> + +namespace +{ +/** + * Fuse Mul-Add to DepthwiseConv2D if possible. + * + * NOTE TF's BatchNormalization is converted to Mul and Add. + * + * BEFORE + * | [CircleConst] + * | / [CircleConst] + * | / / + * [CircleDepthwiseConv2D] [CircleConst] + * | / + * [CircleMul] [CircleConst] + * | / + * [CircleAdd] + * | + * + * AFTER + * | [CircleConst] + * +-------------------------------------+ / [CircleConst] + * | | / / + * | [CircleDepthwiseConv2D] [CircleConst] + * | [CircleConst] | / + * | / [CircleConst] [CircleMul] [CircleConst] + * | / / | / + * [CircleDepthwiseConv2D] [CircleAdd] + * | + * + */ + +/** + * @brief Check shape is [x] or [1, 1, 1, x] + */ +bool is_scale_shift_shape(luci::CircleConst *node) +{ + auto rank = node->rank(); + if (rank != 1 && rank != 4) + return false; + for (uint32_t r = 0; r < rank - 1; ++r) + { + if (node->dim(r).value() != 1) + return false; + } + return true; +} + +bool fused_batch_norm_with_dwconv(luci::CircleAdd *add) +{ + assert(add != nullptr); + + // Find the pattern of CircleDepthwiseConv2D - CircleMul - CircleAdd + luci::CircleConst *scale = nullptr; + luci::CircleConst *shift = nullptr; + luci::CircleDepthwiseConv2D *dwconv = nullptr; + luci::CircleMul *mul = nullptr; + if (not luci::fill(&shift, &mul).with_commutative_args_of(add)) + return false; + if (not luci::fill(&scale, &dwconv).with_commutative_args_of(mul)) + return false; + + // check scale and shift constant attributes + // scale and shift can be [x] or [1, 1, 1, x] + if (not is_scale_shift_shape(scale)) + return false; + if (not is_scale_shift_shape(shift)) + return false; + + // check mul, add attributes + if (mul->dtype() != loco::DataType::FLOAT32) + return false; + if (mul->fusedActivationFunction() != luci::FusedActFunc::NONE) + return false; + if (add->dtype() != loco::DataType::FLOAT32) + return false; + // TODO support more Activations + if (add->fusedActivationFunction() != luci::FusedActFunc::NONE && + add->fusedActivationFunction() != luci::FusedActFunc::RELU6) + return false; + + // get weight of dwconv + auto filter = dynamic_cast<luci::CircleConst *>(dwconv->filter()); + if (not filter) + return false; + if (filter->dtype() != loco::DataType::FLOAT32) + return false; + if (filter->rank() != 4) + return false; + + // check attributes of dwconv + if (dwconv->fusedActivationFunction() != luci::FusedActFunc::NONE) + return false; + if (dwconv->depthMultiplier() < 0) // can this happen? + return false; + + // get bias of dwconv + auto bias = dynamic_cast<luci::CircleConst *>(dwconv->bias()); + if (not bias) + return false; + if (bias->dtype() != loco::DataType::FLOAT32) + return false; + if (bias->rank() != 1) + return false; + + // filter represents as [1, H, W, C*M] where M is multiplier. + auto filter_out_chn = filter->dim(3).value(); + auto multiplier = static_cast<uint32_t>(dwconv->depthMultiplier()); + auto srank = scale->rank(); // as rank can be 1 or 4 + if (filter_out_chn != scale->dim(srank - 1).value() * multiplier) + return false; + srank = shift->rank(); + if (filter_out_chn != shift->dim(srank - 1).value() * multiplier) + return false; + auto channel = filter_out_chn / multiplier; + + auto name = add->name(); + assert(name.length() > 0); + + loco::Graph *graph = add->graph(); + luci::CircleDepthwiseConv2D *fused_dwconv = graph->nodes()->create<luci::CircleDepthwiseConv2D>(); + luci::CircleConst *fused_filter = graph->nodes()->create<luci::CircleConst>(); + luci::CircleConst *fused_bias = graph->nodes()->create<luci::CircleConst>(); + + auto filter_in_chn = filter->dim(0).value(); + auto filter_height = filter->dim(1).value(); + auto filter_width = filter->dim(2).value(); + assert(filter_in_chn == 1); + + // Copy filter shape + fused_filter->dtype(filter->dtype()); + fused_filter->size<loco::DataType::FLOAT32>(filter->size<loco::DataType::FLOAT32>()); + fused_filter->rank(4); + fused_filter->dim(0).set(filter_in_chn); + fused_filter->dim(1).set(filter_height); + fused_filter->dim(2).set(filter_width); + fused_filter->dim(3).set(filter_out_chn); + fused_filter->shape_status(luci::ShapeStatus::VALID); + fused_filter->name(name + "/DepthwiseConv2D/filter"); + + // fused filter weight = filter weight * mul(scale) + add(shift) + for (uint32_t b = 0; b < filter_in_chn; b++) + { + for (uint32_t h = 0; h < filter_height; h++) + { + for (uint32_t w = 0; w < filter_width; w++) + { + for (uint32_t c = 0; c < filter_out_chn; c++) + { + uint32_t offset = b * filter_height * filter_width * filter_out_chn + + h * filter_width * filter_out_chn + w * filter_out_chn + c; + uint32_t chn = c / multiplier; + fused_filter->at<loco::DataType::FLOAT32>(offset) = + filter->at<loco::DataType::FLOAT32>(offset) * scale->at<loco::DataType::FLOAT32>(chn); + } + } + } + } + + // Fuse bias with scale and shift + fused_bias->dtype(shift->dtype()); + fused_bias->size<loco::DataType::FLOAT32>(shift->size<loco::DataType::FLOAT32>()); + fused_bias->rank(1); + fused_bias->dim(0).set(channel); + fused_bias->shape_status(luci::ShapeStatus::VALID); + for (uint32_t c = 0; c < channel; ++c) + { + fused_bias->at<loco::DataType::FLOAT32>(c) = + bias->at<loco::DataType::FLOAT32>(c) * scale->at<loco::DataType::FLOAT32>(c) + + shift->at<loco::DataType::FLOAT32>(c); + } + fused_bias->name(name + "/DepthwiseConv2D/bias"); + + // set new tconv properties + fused_dwconv->input(dwconv->input()); + fused_dwconv->filter(fused_filter); + fused_dwconv->bias(fused_bias); + fused_dwconv->fusedActivationFunction(add->fusedActivationFunction()); + fused_dwconv->padding(dwconv->padding()); + fused_dwconv->stride()->h(dwconv->stride()->h()); + fused_dwconv->stride()->w(dwconv->stride()->w()); + fused_dwconv->depthMultiplier(dwconv->depthMultiplier()); + fused_dwconv->dilation()->h(dwconv->dilation()->h()); + fused_dwconv->dilation()->w(dwconv->dilation()->w()); + fused_dwconv->name(name + "/DepthwiseConv2D"); + luci::add_origin(fused_dwconv, + luci::composite_origin( + {luci::get_origin(add), luci::get_origin(mul), luci::get_origin(dwconv)})); + + replace(add).with(fused_dwconv); + + return true; +} + +} // namespace + +namespace luci +{ + +bool FuseBatchNormWithDwConvPass::run(loco::Graph *g) +{ + bool changed = false; + for (auto node : loco::active_nodes(loco::output_nodes(g))) + { + if (auto add = dynamic_cast<luci::CircleAdd *>(node)) + { + if (fused_batch_norm_with_dwconv(add)) + changed = true; + } + } + + return changed; +} + +} // namespace luci diff --git a/compiler/luci/pass/src/FuseBatchNormWithDwConvPass.test.cpp b/compiler/luci/pass/src/FuseBatchNormWithDwConvPass.test.cpp new file mode 100644 index 000000000..3030a7306 --- /dev/null +++ b/compiler/luci/pass/src/FuseBatchNormWithDwConvPass.test.cpp @@ -0,0 +1,26 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Pass/FuseBatchNormWithDwConvPass.h" + +#include <gtest/gtest.h> + +TEST(FuseBatchNormWithDwConvPassTest, name) +{ + luci::FuseBatchNormWithDwConvPass pass; + auto const name = pass.name(); + ASSERT_NE(nullptr, name); +} diff --git a/compiler/luci/pass/src/FuseBatchNormWithTConv.cpp b/compiler/luci/pass/src/FuseBatchNormWithTConv.cpp deleted file mode 100644 index 95ccd8176..000000000 --- a/compiler/luci/pass/src/FuseBatchNormWithTConv.cpp +++ /dev/null @@ -1,159 +0,0 @@ -/* - * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "luci/Pass/FuseBatchNormWithTConv.h" - -#include <luci/IR/CircleNodes.h> - -namespace -{ -/** - * NOTE TF's fusedBatchNorm is converted to mul and add of Circle. - * - * BEFORE - * - * [CircleTransposeConv] - * | - * [mul] - * | - * [add] - * AFTER - * - * [CircleTransposeConv] - */ -bool fused_batch_norm_with_tconv(luci::CircleTransposeConv *tconv) -{ - // check whether it has bias or not. This optimization works only if it doesn't. - auto bias = dynamic_cast<luci::CircleOutputExclude *>(tconv->bias()); - if (not bias) - return false; - - // get weight of tconv - auto filter = dynamic_cast<luci::CircleConst *>(tconv->filter()); - if (not filter) - return false; - if (filter->dtype() != loco::DataType::FLOAT32) - return false; - - // get mul node - auto tconv_output = loco::succs(tconv); - assert(tconv_output.size() == 1); - auto mul = dynamic_cast<luci::CircleMul *>(*tconv_output.begin()); - if (not mul) - return false; - if (mul->dtype() != loco::DataType::FLOAT32) - return false; - - // get add node - auto mul_output = loco::succs(mul); - assert(mul_output.size() == 1); - auto add = dynamic_cast<luci::CircleAdd *>(*mul_output.begin()); - if (not add) - return false; - if (add->dtype() != loco::DataType::FLOAT32) - return false; - if (add->fusedActivationFunction() != luci::FusedActFunc::NONE && - add->fusedActivationFunction() != luci::FusedActFunc::RELU6) - return false; - - // get scale of batchnorm - auto scale = dynamic_cast<luci::CircleConst *>(mul->y()); - if (not scale) - return false; - - // scale dim(0) == tconv filter channel dim - if (filter->rank() != 4) - return false; - auto filter_out_dim = filter->dim(0).value(); - if (scale->rank() != 1) - return false; - auto scale_dim = scale->dim(0).value(); - if (filter_out_dim != scale_dim) - return false; - - // get shift of batchnorm - auto shift = dynamic_cast<luci::CircleConst *>(add->y()); - if (not shift) - return false; - - // shift dim(0) == tconv filter channel dim - if (shift->rank() != 1) - return false; - auto shift_dim = shift->dim(0).value(); - if (filter_out_dim != shift_dim) - return false; - - // filter weight = filter weight * mul(scale) + add(shift) - uint32_t filter_height_dim = filter->dim(1).value(); - uint32_t filter_width_dim = filter->dim(2).value(); - uint32_t filter_in_dim = filter->dim(3).value(); - for (uint32_t c = 0; c < filter_out_dim; c++) - { - for (uint32_t h = 0; h < filter_height_dim; h++) - { - for (uint32_t w = 0; w < filter_width_dim; w++) - { - for (uint32_t b = 0; b < filter_in_dim; b++) - { - uint32_t offset = c * filter_height_dim * filter_width_dim * filter_in_dim + - h * filter_width_dim * filter_in_dim + w * filter_in_dim + b; - filter->at<loco::DataType::FLOAT32>(offset) *= scale->at<loco::DataType::FLOAT32>(c); - } - } - } - } - - // fuse shift with transposed conv - tconv->bias(shift); - - if (add->fusedActivationFunction() == luci::FusedActFunc::RELU6) - { - // separate relu op from add op - auto relu = add->graph()->nodes()->create<luci::CircleRelu6>(); - relu->features(tconv); - - // remove mul node - replace(add).with(relu); - } - else - { - replace(add).with(tconv); - } - - return true; -} - -} // namespace - -namespace luci -{ - -bool FuseBatchNormWithTConvPass::run(loco::Graph *g) -{ - bool changed = false; - for (auto node : loco::active_nodes(loco::output_nodes(g))) - { - auto tconv = dynamic_cast<luci::CircleTransposeConv *>(node); - if (not tconv) - continue; - - changed |= fused_batch_norm_with_tconv(tconv); - } - - return changed; -} - -} // namespace luci diff --git a/compiler/luci/pass/src/FuseBatchNormWithTConvPass.cpp b/compiler/luci/pass/src/FuseBatchNormWithTConvPass.cpp new file mode 100644 index 000000000..337954960 --- /dev/null +++ b/compiler/luci/pass/src/FuseBatchNormWithTConvPass.cpp @@ -0,0 +1,208 @@ +/* + * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Pass/FuseBatchNormWithTConvPass.h" + +#include "helpers/NodeFiller.h" + +#include <luci/IR/CircleNodes.h> +#include <luci/Profile/CircleNodeOrigin.h> + +namespace +{ +/** + * Fuse Mul-Add to TransposeConv if possible. + * + * NOTE TF's BatchNormalization is converted to Mul and Add. + * + * BEFORE + * | [CircleOutputExclude] + * | / [CircleConst] + * | / / + * [CircleTransposeConv] [CircleConst] + * | / + * [CircleMul] [CircleConst] + * | / + * [CircleAdd] + * | + * + * AFTER + * | [CircleOutputExclude] + * +-------------------------------------+ / [CircleConst] + * | | / / + * | [CircleTransposeConv] [CircleConst] + * | [CircleConst] | / + * | / [CircleConst] [CircleMul] [CircleConst] + * | / / | / + * [CircleTransposeConv] [CircleAdd] + * | + * ([CircleRelu6]) + * | + * + * Note: CircleRelu6 is inserted if Add activation is ReLU6 + */ +bool fused_batch_norm_with_tconv(luci::CircleAdd *add) +{ + assert(add != nullptr); + + // Find the pattern of CircleTransposeConv - CircleMul - CircleAdd + luci::CircleConst *scale = nullptr; + luci::CircleConst *shift = nullptr; + luci::CircleTransposeConv *tconv = nullptr; + luci::CircleMul *mul = nullptr; + if (not luci::fill(&shift, &mul).with_commutative_args_of(add)) + return false; + if (not luci::fill(&scale, &tconv).with_commutative_args_of(mul)) + return false; + + // check scale and shift constant attributes + if (scale->rank() != 1) + return false; + if (shift->rank() != 1) + return false; + // check mul, add attributes + if (mul->dtype() != loco::DataType::FLOAT32) + return false; + if (add->dtype() != loco::DataType::FLOAT32) + return false; + if (add->fusedActivationFunction() != luci::FusedActFunc::NONE && + add->fusedActivationFunction() != luci::FusedActFunc::RELU6) + return false; + + // tconv bias should be not set + if (not dynamic_cast<luci::CircleOutputExclude *>(tconv->bias())) + return false; + + // get weight of tconv + auto filter = dynamic_cast<luci::CircleConst *>(tconv->filter()); + if (not filter) + return false; + if (filter->dtype() != loco::DataType::FLOAT32) + return false; + if (filter->rank() != 4) + return false; + + auto filter_out_chn = filter->dim(0).value(); + if (filter_out_chn != scale->dim(0).value()) + return false; + if (filter_out_chn != shift->dim(0).value()) + return false; + + auto name = add->name(); + assert(name.length() > 0); + + loco::Graph *graph = add->graph(); + luci::CircleTransposeConv *fused_tconv = graph->nodes()->create<luci::CircleTransposeConv>(); + luci::CircleConst *fused_filter = graph->nodes()->create<luci::CircleConst>(); + luci::CircleConst *fused_bias = graph->nodes()->create<luci::CircleConst>(); + + auto filter_height = filter->dim(1).value(); + auto filter_width = filter->dim(2).value(); + auto filter_in_chn = filter->dim(3).value(); + + // Copy filter shape + fused_filter->dtype(filter->dtype()); + fused_filter->size<loco::DataType::FLOAT32>(filter->size<loco::DataType::FLOAT32>()); + fused_filter->rank(4); + fused_filter->dim(0).set(filter_out_chn); + fused_filter->dim(1).set(filter_height); + fused_filter->dim(2).set(filter_width); + fused_filter->dim(3).set(filter_in_chn); + fused_filter->shape_status(luci::ShapeStatus::VALID); + fused_filter->name(name + "/TransposeConv/filter"); + + // fused filter weight = filter weight * mul(scale) + add(shift) + for (uint32_t c = 0; c < filter_out_chn; c++) + { + for (uint32_t h = 0; h < filter_height; h++) + { + for (uint32_t w = 0; w < filter_width; w++) + { + for (uint32_t b = 0; b < filter_in_chn; b++) + { + uint32_t offset = c * filter_height * filter_width * filter_in_chn + + h * filter_width * filter_in_chn + w * filter_in_chn + b; + fused_filter->at<loco::DataType::FLOAT32>(offset) = + filter->at<loco::DataType::FLOAT32>(offset) * scale->at<loco::DataType::FLOAT32>(c); + } + } + } + } + + // Copy fused_bias from shift + fused_bias->dtype(shift->dtype()); + fused_bias->size<loco::DataType::FLOAT32>(shift->size<loco::DataType::FLOAT32>()); + fused_bias->rank(1); + fused_bias->dim(0).set(filter_out_chn); + fused_bias->shape_status(luci::ShapeStatus::VALID); + for (uint32_t c = 0; c < filter_out_chn; ++c) + { + fused_bias->at<loco::DataType::FLOAT32>(c) = shift->at<loco::DataType::FLOAT32>(c); + } + fused_bias->name(name + "/TransposeConv/bias"); + + // set new tconv properties + fused_tconv->inputSizes(tconv->inputSizes()); + fused_tconv->filter(fused_filter); + fused_tconv->outBackprop(tconv->outBackprop()); + fused_tconv->bias(fused_bias); + fused_tconv->padding(tconv->padding()); + fused_tconv->stride()->h(tconv->stride()->h()); + fused_tconv->stride()->w(tconv->stride()->w()); + fused_tconv->name(name + "/TransposeConv"); + luci::add_origin(fused_tconv, + luci::composite_origin( + {luci::get_origin(add), luci::get_origin(mul), luci::get_origin(tconv)})); + + if (add->fusedActivationFunction() == luci::FusedActFunc::RELU6) + { + // separate relu op from add op + auto relu = add->graph()->nodes()->create<luci::CircleRelu6>(); + relu->features(fused_tconv); + relu->name(name + "/Relu6"); + luci::add_origin(relu, luci::get_origin(add)); + + replace(add).with(relu); + } + else + { + replace(add).with(fused_tconv); + } + + return true; +} + +} // namespace + +namespace luci +{ + +bool FuseBatchNormWithTConvPass::run(loco::Graph *g) +{ + bool changed = false; + for (auto node : loco::active_nodes(loco::output_nodes(g))) + { + if (auto add = dynamic_cast<luci::CircleAdd *>(node)) + { + if (fused_batch_norm_with_tconv(add)) + changed = true; + } + } + + return changed; +} + +} // namespace luci diff --git a/compiler/luci/pass/src/FuseBatchNormWithTConvPass.test.cpp b/compiler/luci/pass/src/FuseBatchNormWithTConvPass.test.cpp new file mode 100644 index 000000000..051100dc9 --- /dev/null +++ b/compiler/luci/pass/src/FuseBatchNormWithTConvPass.test.cpp @@ -0,0 +1,26 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Pass/FuseBatchNormWithTConvPass.h" + +#include <gtest/gtest.h> + +TEST(FuseBatchNormWithTConvPassTest, name) +{ + luci::FuseBatchNormWithTConvPass pass; + auto const name = pass.name(); + ASSERT_NE(nullptr, name); +} diff --git a/compiler/luci/pass/src/FuseInstanceNormPass.cpp b/compiler/luci/pass/src/FuseInstanceNormPass.cpp index 237152f98..ab7baa1fa 100644 --- a/compiler/luci/pass/src/FuseInstanceNormPass.cpp +++ b/compiler/luci/pass/src/FuseInstanceNormPass.cpp @@ -15,105 +15,16 @@ */ #include "luci/Pass/FuseInstanceNormPass.h" +#include "helpers/NodeFiller.h" #include "FuseInstanceNormPassInternal.h" #include <luci/IR/CircleNodes.h> -#include <loco/Service/ShapeInference.h> +#include <luci/Profile/CircleNodeOrigin.h> #include <cassert> #include <set> -// Helper to find commutative node's arguments -namespace -{ - -/** - * INTRODUCTION - * Binary operation f(x,y) is 'commutative' when - * f(x,y) == f(y,x) holds for all x, y. - * For examples, ADD, MUL and SQUARED_DIFFERENCE are commutative. - * These helpers make it easy to find commutative arguemnts of commtative node. - * - * HOW TO USE - * COMM_NODE *node; - * ARG_TYPE_1 *arg1; - * ARG_TYPE_2 *arg2; - * - * bool ok = fill(&arg1, &arg2).with_commutative_args_of(node); - * - * Result - * If 'node's commutative argument types are actually {ARG_TYPE_1, ARG_TYPE_2} - * (as a set), 'arg1' and 'arg2' set as actual 'node's arguemnts with matching - * type, and return value 'ok' is true. - * Otherwise, 'arg1' and 'arg2' not changed, 'ok' is false. - */ - -template <class ARG_TYPE_1, class ARG_TYPE_2> class NodeFiller final -{ -public: - NodeFiller(ARG_TYPE_1 **arg_1, ARG_TYPE_2 **arg_2) : _arg_1(arg_1), _arg_2(arg_2) - { - // DO NOTHING - } - - /** - * @return true When 'node's argument types are 'ARG_TYPE_1' and 'ARG_TYPE_2' - * In such case, it assign '_arg_1' and '_arg_2' to actual arguments - * - * @return false When 'node's argument types are NOT matched with 'ARG_TYPE_*' - * In such case, it does not amend '_arg_1' and '_arg_2' - * - * @require COMM_NODE has member x() and y() - */ - template <class COMM_NODE> bool with_commutative_args_of(const COMM_NODE *node); - -private: - ARG_TYPE_1 **_arg_1; - ARG_TYPE_2 **_arg_2; -}; - -template <class ARG_TYPE_1, class ARG_TYPE_2> -inline NodeFiller<ARG_TYPE_1, ARG_TYPE_2> fill(ARG_TYPE_1 **arg_1, ARG_TYPE_2 **arg_2) -{ - return NodeFiller<ARG_TYPE_1, ARG_TYPE_2>{arg_1, arg_2}; -} - -template <class ARG_TYPE_1, class ARG_TYPE_2> -template <class COMM_NODE> -bool NodeFiller<ARG_TYPE_1, ARG_TYPE_2>::with_commutative_args_of(const COMM_NODE *node) -{ - // Case 1) X == ARG_TYPE_1 / Y == ARG_TYPE_2 - { - auto x = dynamic_cast<ARG_TYPE_1 *>(node->x()); - auto y = dynamic_cast<ARG_TYPE_2 *>(node->y()); - - if (x && y) - { - *_arg_1 = x; - *_arg_2 = y; - return true; - } - } - - // Case 2) X == ARG_TYPE_2 / Y == ARG_TYPE_1 - { - auto x = dynamic_cast<ARG_TYPE_2 *>(node->x()); - auto y = dynamic_cast<ARG_TYPE_1 *>(node->y()); - - if (x && y) - { - *_arg_1 = y; - *_arg_2 = x; - return true; - } - } - - return false; -} - -} // namespace - // Helper to check detail /// @return true When node has shape of '1 x .. x 1 x depth' @@ -150,11 +61,10 @@ bool is_instance_mean_v0(luci::CircleMean *mean) // // CHECK 1) input is rank 4 // - auto input = mean->input(); - if (not loco::shape_known(input)) + auto input = loco::must_cast<luci::CircleNode *>(mean->input()); + if (input->shape_status() != luci::ShapeStatus::VALID) return false; - auto input_shape = loco::shape_get(input).as<loco::TensorShape>(); - if (input_shape.rank() != 4) + if (input->rank() != 4) return false; // @@ -195,11 +105,10 @@ bool is_instance_mean_v1(luci::CircleMean *mean) // // CHECK 1) input is rank 5 (NHWCX) // - auto input = mean->input(); - if (not loco::shape_known(input)) + auto input = loco::must_cast<luci::CircleNode *>(mean->input()); + if (input->shape_status() != luci::ShapeStatus::VALID) return false; - auto input_shape = loco::shape_get(input).as<loco::TensorShape>(); - if (input_shape.rank() != 5) + if (input->rank() != 5) return false; // @@ -445,8 +354,9 @@ bool InstanceNormPattern::matched() // So it is handled in the separate if statement if (_pv == PatternVersion::Version_2) { - CHECK_OR_FALSE(fill(&mul_gamma, &const_as_beta).with_commutative_args_of(add_as_terminal)); - CHECK_OR_FALSE(fill(&div, &const_as_gamma).with_commutative_args_of(mul_gamma)); + CHECK_OR_FALSE( + luci::fill(&mul_gamma, &const_as_beta).with_commutative_args_of(add_as_terminal)); + CHECK_OR_FALSE(luci::fill(&div, &const_as_gamma).with_commutative_args_of(mul_gamma)); sub = dynamic_cast<luci::CircleSub *>(div->x()); CHECK_OR_FALSE(sub); @@ -456,6 +366,7 @@ bool InstanceNormPattern::matched() luci::CircleNode *ifm_node = loco::must_cast<luci::CircleNode *>(ifm); CHECK_OR_FALSE(ifm_node->rank() == 4); + CHECK_OR_FALSE(ifm_node->dim(3).known()); uint32_t ifm_channel_depth = ifm_node->dim(3).value(); mean_of_ifm = dynamic_cast<luci::CircleMean *>(sub->y()); @@ -477,7 +388,7 @@ bool InstanceNormPattern::matched() CHECK_OR_FALSE(zero_point_five->at<loco::DataType::FLOAT32>(0) == 0.5); CHECK_OR_FALSE( - fill(&mean_as_variance, &const_as_epsilon).with_commutative_args_of(add_as_variance)); + luci::fill(&mean_as_variance, &const_as_epsilon).with_commutative_args_of(add_as_variance)); CHECK_OR_FALSE(const_as_epsilon->dtype() == loco::DataType::FLOAT32); // TODO Support regarding broadcast CHECK_OR_FALSE(const_as_epsilon->size<loco::DataType::FLOAT32>() == 1); @@ -489,7 +400,8 @@ bool InstanceNormPattern::matched() loco::Node *ifm_should_be = nullptr; luci::CircleMean *mean_of_ifm_should_be = nullptr; - CHECK_OR_FALSE(fill(&ifm_should_be, &mean_of_ifm_should_be).with_commutative_args_of(sqdiff)); + CHECK_OR_FALSE( + luci::fill(&ifm_should_be, &mean_of_ifm_should_be).with_commutative_args_of(sqdiff)); CHECK_OR_FALSE(ifm == ifm_should_be); CHECK_OR_FALSE(mean_of_ifm == mean_of_ifm_should_be); @@ -503,25 +415,25 @@ bool InstanceNormPattern::matched() if (_pv == PatternVersion::Version_0) { - CHECK_OR_FALSE(fill(&mul_as_scaled_ifm, &sub).with_commutative_args_of(add_as_terminal)); - CHECK_OR_FALSE(fill(&ifm, &mul_gamma).with_commutative_args_of(mul_as_scaled_ifm)); + CHECK_OR_FALSE(luci::fill(&mul_as_scaled_ifm, &sub).with_commutative_args_of(add_as_terminal)); + CHECK_OR_FALSE(luci::fill(&ifm, &mul_gamma).with_commutative_args_of(mul_as_scaled_ifm)); } if (_pv == PatternVersion::Version_1) { - CHECK_OR_FALSE(fill(&mul_as_scaled_reshape, &sub).with_commutative_args_of(add_as_terminal)); CHECK_OR_FALSE( - fill(&reshape_of_ifm, &mul_gamma).with_commutative_args_of(mul_as_scaled_reshape)); + luci::fill(&mul_as_scaled_reshape, &sub).with_commutative_args_of(add_as_terminal)); + CHECK_OR_FALSE( + luci::fill(&reshape_of_ifm, &mul_gamma).with_commutative_args_of(mul_as_scaled_reshape)); ifm = reshape_of_ifm->tensor(); } - CHECK_OR_FALSE(loco::shape_known(ifm)); - auto ifm_shape = loco::shape_get(ifm); - CHECK_OR_FALSE(ifm_shape.domain() == loco::Domain::Tensor); - auto ifm_tensor_shape = ifm_shape.as<loco::TensorShape>(); - CHECK_OR_FALSE(ifm_tensor_shape.rank() == 4); - uint32_t ifm_channel_depth = ifm_tensor_shape.dim(3).value(); + auto ifm_circle = loco::must_cast<luci::CircleNode *>(ifm); + CHECK_OR_FALSE(ifm_circle->shape_status() == luci::ShapeStatus::VALID); + CHECK_OR_FALSE(ifm_circle->rank() == 4); + CHECK_OR_FALSE(ifm_circle->dim(3).known()); + uint32_t ifm_channel_depth = ifm_circle->dim(3).value(); - CHECK_OR_FALSE(fill(&rsqrt, &const_as_gamma).with_commutative_args_of(mul_gamma)); + CHECK_OR_FALSE(luci::fill(&rsqrt, &const_as_gamma).with_commutative_args_of(mul_gamma)); if (_pv == PatternVersion::Version_0) { @@ -536,7 +448,7 @@ bool InstanceNormPattern::matched() CHECK_OR_FALSE(add_as_variance); CHECK_OR_FALSE( - fill(&mean_as_variance, &const_as_epsilon).with_commutative_args_of(add_as_variance)); + luci::fill(&mean_as_variance, &const_as_epsilon).with_commutative_args_of(add_as_variance)); CHECK_OR_FALSE(const_as_epsilon->dtype() == loco::DataType::FLOAT32); // TODO Support regarding broadcast @@ -557,7 +469,7 @@ bool InstanceNormPattern::matched() if (_pv == PatternVersion::Version_0) { loco::Node *ifm_should_be = nullptr; - CHECK_OR_FALSE(fill(&ifm_should_be, &mean_of_ifm).with_commutative_args_of(sqdiff)); + CHECK_OR_FALSE(luci::fill(&ifm_should_be, &mean_of_ifm).with_commutative_args_of(sqdiff)); CHECK_OR_FALSE(ifm == ifm_should_be); CHECK_OR_FALSE(is_instance_mean_v0(mean_of_ifm)); CHECK_OR_FALSE(ifm == mean_of_ifm->input()); @@ -565,7 +477,8 @@ bool InstanceNormPattern::matched() if (_pv == PatternVersion::Version_1) { loco::Node *reshape_should_be = nullptr; - CHECK_OR_FALSE(fill(&reshape_should_be, &mean_of_reshape).with_commutative_args_of(sqdiff)); + CHECK_OR_FALSE( + luci::fill(&reshape_should_be, &mean_of_reshape).with_commutative_args_of(sqdiff)); CHECK_OR_FALSE(reshape_of_ifm == reshape_should_be); CHECK_OR_FALSE(is_instance_mean_v1(mean_of_reshape)); CHECK_OR_FALSE(reshape_of_ifm == mean_of_reshape->input()); @@ -592,15 +505,15 @@ bool InstanceNormPattern::matched() if (_pv == PatternVersion::Version_0) { - CHECK_OR_FALSE(fill(&mul_gamma_should_be, &mean_of_ifm_should_be) - .with_commutative_args_of(mul_as_scaled_mean)); + CHECK_OR_FALSE(luci::fill(&mul_gamma_should_be, &mean_of_ifm_should_be) + .with_commutative_args_of(mul_as_scaled_mean)); CHECK_OR_FALSE(mul_gamma == mul_gamma_should_be); CHECK_OR_FALSE(mean_of_ifm == mean_of_ifm_should_be); } if (_pv == PatternVersion::Version_1) { - CHECK_OR_FALSE(fill(&mul_gamma_should_be, &mean_of_reshape_should_be) - .with_commutative_args_of(mul_as_scaled_mean)); + CHECK_OR_FALSE(luci::fill(&mul_gamma_should_be, &mean_of_reshape_should_be) + .with_commutative_args_of(mul_as_scaled_mean)); CHECK_OR_FALSE(mul_gamma == mul_gamma_should_be); CHECK_OR_FALSE(mean_of_reshape == mean_of_reshape_should_be); } @@ -631,47 +544,59 @@ void fuse_instance_norm(const InstanceNormPattern &p) auto graph = p.add_as_terminal->graph(); - // Special case for version 2 (no need to reshape) - if (p.version() == InstanceNormPattern::Version_2) + // Version 0 and 1 need to reshape + if (p.version() != InstanceNormPattern::Version_2) { - // Make Instance Norm to replace - auto instance_norm = graph->nodes()->create<luci::CircleInstanceNorm>(); - instance_norm->input(p.ifm); - instance_norm->gamma(p.const_as_gamma); - instance_norm->beta(p.const_as_beta); - float epsilon = p.const_as_epsilon->at<loco::DataType::FLOAT32>(0); - instance_norm->epsilon(epsilon); - instance_norm->fusedActivationFunction(p.add_as_terminal->fusedActivationFunction()); - - replace(p.add_as_terminal).with(instance_norm); - - return; - } - - // Make reshape for gamma & beta - auto reshape_gamma = graph->nodes()->create<luci::CircleReshape>(); - auto reshape_beta = graph->nodes()->create<luci::CircleReshape>(); - { - auto ifm_shape = loco::shape_get(p.ifm).as<loco::TensorShape>(); - uint32_t ifm_channel_depth = ifm_shape.dim(3).value(); - - int32_t new_shape[1] = {static_cast<int32_t>(ifm_channel_depth)}; - - reshape_gamma->tensor(p.const_as_gamma); - reshape_beta->tensor(p.const_as_beta); + p.const_as_gamma->rank(1); + p.const_as_gamma->dim(0).set(p.const_as_gamma->size<loco::DataType::FLOAT32>()); + p.const_as_beta->rank(1); + p.const_as_beta->dim(0).set(p.const_as_beta->size<loco::DataType::FLOAT32>()); - luci::set_new_shape(reshape_gamma, new_shape, 1); - luci::set_new_shape(reshape_beta, new_shape, 1); + p.const_as_gamma->shape_status(luci::ShapeStatus::UNDEFINED); + p.const_as_beta->shape_status(luci::ShapeStatus::UNDEFINED); } // Make Instance Norm to replace auto instance_norm = graph->nodes()->create<luci::CircleInstanceNorm>(); instance_norm->input(p.ifm); - instance_norm->gamma(reshape_gamma); - instance_norm->beta(reshape_beta); + instance_norm->gamma(p.const_as_gamma); + instance_norm->beta(p.const_as_beta); float epsilon = p.const_as_epsilon->at<loco::DataType::FLOAT32>(0); instance_norm->epsilon(epsilon); instance_norm->fusedActivationFunction(p.add_as_terminal->fusedActivationFunction()); + // NOTE unique name should be assigned in export + instance_norm->name("InstanceNorm"); + + // set origin + std::vector<std::shared_ptr<luci::CircleNodeOrigin>> origin_vec{ + luci::get_origin(p.sqdiff), + luci::get_origin(p.mean_as_variance), + luci::get_origin(p.add_as_variance), + luci::get_origin(p.mul_gamma), + luci::get_origin(p.sub), + luci::get_origin(p.add_as_terminal)}; + if (p.version() == InstanceNormPattern::PatternVersion::Version_0) + { + origin_vec.push_back(luci::get_origin(p.mean_of_ifm)); + origin_vec.push_back(luci::get_origin(p.rsqrt)); + origin_vec.push_back(luci::get_origin(p.mul_as_scaled_ifm)); + origin_vec.push_back(luci::get_origin(p.mul_as_scaled_mean)); + } + if (p.version() == InstanceNormPattern::PatternVersion::Version_1) + { + origin_vec.push_back(luci::get_origin(p.reshape_of_ifm)); + origin_vec.push_back(luci::get_origin(p.mean_of_reshape)); + origin_vec.push_back(luci::get_origin(p.rsqrt)); + origin_vec.push_back(luci::get_origin(p.mul_as_scaled_mean)); + origin_vec.push_back(luci::get_origin(p.mul_as_scaled_reshape)); + } + if (p.version() == InstanceNormPattern::PatternVersion::Version_2) + { + origin_vec.push_back(luci::get_origin(p.mean_of_ifm)); + origin_vec.push_back(luci::get_origin(p.pow)); + origin_vec.push_back(luci::get_origin(p.div)); + } + luci::add_origin(instance_norm, luci::composite_origin(origin_vec)); replace(p.add_as_terminal).with(instance_norm); } diff --git a/compiler/luci/pass/src/FuseInstanceNormPass.test.cpp b/compiler/luci/pass/src/FuseInstanceNormPass.test.cpp index 3037f3def..b83ccca50 100644 --- a/compiler/luci/pass/src/FuseInstanceNormPass.test.cpp +++ b/compiler/luci/pass/src/FuseInstanceNormPass.test.cpp @@ -16,6 +16,8 @@ #include "FuseInstanceNormPassInternal.h" +#include "luci/Pass/FuseInstanceNormPass.h" + #include <vector> #include <gtest/gtest.h> @@ -34,6 +36,13 @@ void setShape(luci::CircleNode &node, const std::vector<int> &v) } // namespace +TEST(FuseInstanceNormPassTest, name) +{ + luci::FuseInstanceNormPass pass; + auto const name = pass.name(); + ASSERT_NE(nullptr, name); +} + TEST(FuseInstanceNormPass, is_quasi_1D_with_dummy_dim) { luci::CircleConst const_node; diff --git a/compiler/luci/pass/src/FusePreActivationBatchNormPass.cpp b/compiler/luci/pass/src/FusePreActivationBatchNormPass.cpp index bcde5fac4..469fcddbb 100644 --- a/compiler/luci/pass/src/FusePreActivationBatchNormPass.cpp +++ b/compiler/luci/pass/src/FusePreActivationBatchNormPass.cpp @@ -16,9 +16,11 @@ #include "luci/Pass/FusePreActivationBatchNormPass.h" #include "FusePreActivationBatchNormPassInternal.h" +#include "BatchNormPatternFinder.h" #include <luci/IR/CircleNodes.h> #include <luci/Log.h> +#include <luci/Profile/CircleNodeOrigin.h> namespace { @@ -37,83 +39,6 @@ bool is_non_negative(const luci::CircleConst *node) return true; } -// Check if mul is batchnorm mul -bool is_batchnorm_mul(const luci::CircleMul *mul, luci::CircleNode *&pred_node, - luci::CircleConst *&gamma) -{ - auto x = dynamic_cast<luci::CircleConst *>(mul->x()); - auto y = dynamic_cast<luci::CircleConst *>(mul->y()); - - luci::CircleNode *pred = nullptr; - luci::CircleConst *constant = nullptr; - - if (x != nullptr && y == nullptr) - { - pred = loco::must_cast<luci::CircleNode *>(mul->y()); - constant = x; - } - else if (x == nullptr && y != nullptr) - { - pred = loco::must_cast<luci::CircleNode *>(mul->x()); - constant = y; - } - else - { - return false; - } - - if (constant->rank() != 1) - return false; - - auto channel_dim = constant->dim(0); - if (!(channel_dim == mul->dim(mul->rank() - 1))) - return false; - - pred_node = pred; - gamma = constant; - return true; -} - -// Check if add is batchnorm add -bool is_batchnorm_add(const luci::CircleAdd *add, luci::CircleMul *&mul, luci::CircleConst *&beta) -{ - auto x = loco::must_cast<luci::CircleNode *>(add->x()); - auto y = loco::must_cast<luci::CircleNode *>(add->y()); - - luci::CircleMul *pred = nullptr; - luci::CircleConst *constant = nullptr; - - if (add->fusedActivationFunction() != luci::FusedActFunc::RELU) - return false; - - if (x->opcode() == luci::CircleOpcode::CIRCLECONST && y->opcode() == luci::CircleOpcode::MUL) - { - pred = loco::must_cast<luci::CircleMul *>(y); - constant = loco::must_cast<luci::CircleConst *>(x); - } - else if (x->opcode() == luci::CircleOpcode::MUL && y->opcode() == luci::CircleOpcode::CIRCLECONST) - { - pred = loco::must_cast<luci::CircleMul *>(x); - constant = loco::must_cast<luci::CircleConst *>(y); - } - else - { - return false; - } - - if (constant->rank() != 1) - return false; - - auto channel_dim = constant->dim(0); - // Assumption: Layout is channel-last - if (!(channel_dim == add->dim(add->rank() - 1))) - return false; - - mul = pred; - beta = constant; - return true; -} - const luci::CircleConv2D *get_forward_conv2d(const luci::CircleNode *node, uint32_t channel_size) { auto opcode = node->opcode(); @@ -249,6 +174,9 @@ bool update_conv_bias_with_beta(luci::CircleConv2D *conv, const luci::CircleCons auto size = beta->dim(0).value(); auto bias = dynamic_cast<luci::CircleConst *>(conv->bias()); + auto name = conv->name(); + assert(name.length() > 0); + if (bias == nullptr) { bias = conv->graph()->nodes()->create<luci::CircleConst>(); @@ -256,6 +184,7 @@ bool update_conv_bias_with_beta(luci::CircleConv2D *conv, const luci::CircleCons bias->rank(1); bias->dim(0).set(size); bias->size<loco::DataType::FLOAT32>(size); + bias->name(name + "/bias"); conv->bias(bias); } else @@ -282,14 +211,12 @@ bool update_conv_bias_with_beta(luci::CircleConv2D *conv, const luci::CircleCons luci::CircleSub *insert_sub(luci::CircleNode *pred, luci::CircleConst *beta) { + auto name = pred->name(); + assert(name.length() > 0); + auto sub = pred->graph()->nodes()->create<luci::CircleSub>(); - sub->dtype(loco::DataType::FLOAT32); - sub->rank(pred->rank()); - for (uint32_t i = 0; i < sub->rank(); i++) - { - sub->dim(i).set(pred->dim(i).value()); - } sub->fusedActivationFunction(luci::FusedActFunc::NONE); + sub->name(name + "/Sub"); loco::replace(pred).with(sub); @@ -366,6 +293,8 @@ bool fuse_sub_with_conv(luci::CircleSub *sub) if (!update_conv_bias_with_beta(conv, beta, false)) return false; + luci::add_origin(conv, luci::get_origin(sub)); + auto pred = sub->x(); loco::replace(sub).with(pred); @@ -442,6 +371,7 @@ bool fuse_add_with_conv(luci::CircleAdd *add, std::vector<luci::CircleSub *> &su if (!update_conv_bias_with_beta(conv, beta, true)) return false; + luci::add_origin(conv, luci::get_origin(add)); loco::replace(add).with(pred); add->drop(); @@ -462,6 +392,8 @@ bool fuse_add_with_conv(luci::CircleAdd *add, std::vector<luci::CircleSub *> &su if (!update_conv_bias_with_beta(conv, beta, true)) return false; + luci::add_origin(conv, luci::get_origin(add)); + auto relu = *loco::succs(add).begin(); auto relu_node = loco::must_cast<luci::CircleRelu *>(relu); assert(relu_node != nullptr); @@ -471,6 +403,7 @@ bool fuse_add_with_conv(luci::CircleAdd *add, std::vector<luci::CircleSub *> &su add->drop(); sub_list.push_back(insert_sub(pred, beta)); + luci::add_origin(sub_list.back(), luci::get_origin(add)); relu_node->features(pred); @@ -530,6 +463,11 @@ bool fuse_mul_with_conv(luci::CircleMul *mul) // Update CONV weights update_conv_weights_with_gamma(conv, gamma); + + // Update origin + // TODO need to remove const + luci::add_origin(const_cast<luci::CircleConv2D *>(conv), + luci::get_origin(loco::must_cast<luci::CircleNode *>(mul))); } loco::replace(mul).with(pred_node); @@ -568,6 +506,8 @@ bool swap_mul_add(luci::CircleAdd *add, std::vector<luci::CircleMul *> &mul_list if (!is_batchnorm_add(add, mul, beta)) return false; + if (add->fusedActivationFunction() != luci::FusedActFunc::RELU) + return false; if (loco::succs(mul).size() != 1) return false; @@ -582,8 +522,13 @@ bool swap_mul_add(luci::CircleAdd *add, std::vector<luci::CircleMul *> &mul_list return false; // Insert Relu at the bottom + auto name = add->name(); + assert(name.length() > 0); + auto relu = add->graph()->nodes()->create<luci::CircleRelu>(); relu->features(mul); + relu->name(name + "/Relu"); + luci::add_origin(relu, luci::get_origin(add)); loco::replace(add).with(relu); // Replace beta <- beta / gamma diff --git a/compiler/luci/pass/src/FusePreActivationBatchNormPass.test.cpp b/compiler/luci/pass/src/FusePreActivationBatchNormPass.test.cpp index a79b5bd5d..3d5791c9e 100644 --- a/compiler/luci/pass/src/FusePreActivationBatchNormPass.test.cpp +++ b/compiler/luci/pass/src/FusePreActivationBatchNormPass.test.cpp @@ -16,6 +16,8 @@ #include "FusePreActivationBatchNormPassInternal.h" +#include "luci/Pass/FusePreActivationBatchNormPass.h" + #include <luci/IR/CircleNodes.h> #include <math.h> @@ -148,6 +150,22 @@ public: conv_filter->at<loco::DataType::FLOAT32>(i * out_size + j) = i * out_size + j; } } + + pred_conv->name("pred_conv"); + pred_conv_filter->name("pred_conv_filter"); + pred_conv_bias->name("pred_conv_bias"); + pred_conv2->name("pred_conv2"); + pred_conv2_filter->name("pred_conv2_filter"); + pred_conv2_bias->name("pred_conv2_bias"); + pred_add->name("pred_add"); + mul->name("mul"); + mul_gamma->name("mul_gamma"); + add->name("add"); + add_beta->name("add_beta"); + conv->name("conv"); + conv_filter->name("conv_filter"); + conv_bias->name("conv_bias"); + succ_add->name("succ_add"); } public: @@ -171,6 +189,13 @@ public: } // namespace +TEST(FusePreActivationBatchNormPassTest, name) +{ + luci::FusePreActivationBatchNormPass pass; + auto const name = pass.name(); + ASSERT_NE(nullptr, name); +} + TEST(FusePreActivationBatchNorm, swap_mul_add) { SimpleGraph g; diff --git a/compiler/luci/pass/src/MakeBatchNormGammaPositivePass.cpp b/compiler/luci/pass/src/MakeBatchNormGammaPositivePass.cpp index 281d1b081..96776dc92 100644 --- a/compiler/luci/pass/src/MakeBatchNormGammaPositivePass.cpp +++ b/compiler/luci/pass/src/MakeBatchNormGammaPositivePass.cpp @@ -16,6 +16,8 @@ #include "luci/Pass/MakeBatchNormGammaPositivePass.h" +#include "BatchNormPatternFinder.h" + #include <luci/IR/CircleNodes.h> namespace @@ -39,71 +41,27 @@ bool negative_gamma_to_positive(luci::CircleConst *gamma) return changed; } -// Check if add is batchnorm add -bool is_batchnorm_add(const luci::CircleAdd *add) +bool make_positive_gamma(luci::CircleAdd *add) { - auto x = dynamic_cast<luci::CircleConst *>(add->x()); - auto y = dynamic_cast<luci::CircleConst *>(add->y()); - - luci::CircleConst *constant = nullptr; + luci::CircleMul *mul = nullptr; + luci::CircleConst *beta = nullptr; + luci::CircleConst *gamma = nullptr; + luci::CircleNode *pred = nullptr; - if (x != nullptr && y == nullptr) - constant = x; - else if (x == nullptr && y != nullptr) - constant = y; - else + if (!is_batchnorm_add(add, mul, beta)) return false; - if (constant->rank() != 1) + if (loco::succs(mul).size() != 1) return false; + if (!is_batchnorm_mul(mul, pred, gamma)) + return false; + assert(pred == add); // Only support Relu if (add->fusedActivationFunction() != luci::FusedActFunc::RELU) return false; - auto channel_dim = constant->dim(0); - if (!(channel_dim == add->dim(add->rank() - 1))) - return false; - - return true; -} - -// Check if mul is batchnorm mul -bool is_batchnorm_mul(const luci::CircleMul *mul, luci::CircleConst *&gamma) -{ - auto x = dynamic_cast<luci::CircleConst *>(mul->x()); - auto y = dynamic_cast<luci::CircleConst *>(mul->y()); - - luci::CircleConst *constant = nullptr; - - if (x != nullptr && y == nullptr) - constant = x; - else if (x == nullptr && y != nullptr) - constant = y; - else - return false; - - if (constant->rank() != 1) - return false; - - auto channel_dim = constant->dim(0); - if (!(channel_dim == mul->dim(mul->rank() - 1))) - return false; - - // Check successor is batchnorm add - auto succs = loco::succs(mul); - if (succs.size() != 1) - return false; - - auto add = dynamic_cast<luci::CircleAdd *>(*succs.begin()); - if (add == nullptr) - return false; - - if (!is_batchnorm_add(add)) - return false; - - gamma = constant; - return true; + return negative_gamma_to_positive(gamma); } } // namespace @@ -111,18 +69,29 @@ bool is_batchnorm_mul(const luci::CircleMul *mul, luci::CircleConst *&gamma) namespace luci { +/** + * Make negative gamma values of Mul-Add (as BatchNorm) to a small positive value (1e-10) + * + * PATTERN: + * | + * [CircleNode] [CircleConst](as gamma) + * | | + * [CircleMul] [CircleConst] + * | | + * [CircleAdd] + * | + */ bool MakeBatchNormGammaPositivePass::run(loco::Graph *g) { bool changed = false; for (auto node : loco::active_nodes(loco::output_nodes(g))) { - auto mul = dynamic_cast<luci::CircleMul *>(node); - if (mul == nullptr) + auto add = dynamic_cast<luci::CircleAdd *>(node); + if (add == nullptr) continue; - luci::CircleConst *gamma; - if (is_batchnorm_mul(mul, gamma)) - changed = negative_gamma_to_positive(gamma); + if (make_positive_gamma(add)) + changed = true; } return changed; } diff --git a/compiler/luci/pass/src/MakeBatchNormGammaPositivePass.test.cpp b/compiler/luci/pass/src/MakeBatchNormGammaPositivePass.test.cpp new file mode 100644 index 000000000..83093edc8 --- /dev/null +++ b/compiler/luci/pass/src/MakeBatchNormGammaPositivePass.test.cpp @@ -0,0 +1,26 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Pass/MakeBatchNormGammaPositivePass.h" + +#include <gtest/gtest.h> + +TEST(MakeBatchNormGammaPositivePassTest, name) +{ + luci::MakeBatchNormGammaPositivePass pass; + auto const name = pass.name(); + ASSERT_NE(nullptr, name); +} diff --git a/compiler/luci/pass/src/MigrateLegacyShapeDtypePass.cpp b/compiler/luci/pass/src/MigrateLegacyShapeDtypePass.cpp deleted file mode 100644 index beb962a05..000000000 --- a/compiler/luci/pass/src/MigrateLegacyShapeDtypePass.cpp +++ /dev/null @@ -1,112 +0,0 @@ -/* - * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "luci/Pass/MigrateLegacyShapeDtypePass.h" - -#include <loco/Service/ShapeInference.h> -#include <loco/Service/TypeInference.h> - -#include <luci/IR/CircleNodes.h> - -#include <loco.h> - -namespace -{ - -bool has_same_shape(luci::CircleNode *node, loco::TensorShape shape) -{ - if (node->rank() != shape.rank()) - return false; - - for (uint32_t i = 0; i < shape.rank(); ++i) - if (!(node->dim(i) == shape.dim(i))) - return false; - - return true; -} - -} // namespace - -namespace luci -{ - -bool MigrateLegacyShapeDtypePass::run(luci::Module *m) -{ - bool changed = false; - - for (size_t g = 0; g < m->size(); ++g) - { - if (run(m->graph(g))) - changed = true; - } - - return changed; -} - -bool MigrateLegacyShapeDtypePass::run(loco::Graph *g) -{ - bool changed = false; - - for (auto node : loco::all_nodes(g)) - { - auto circle_node = loco::must_cast<luci::CircleNode *>(node); - if (loco::shape_known(node)) - { - auto loco_shape = loco::shape_get(node).as<loco::TensorShape>(); - - assert(circle_node->shape_signature().rank() == 0 || - circle_node->shape_signature().rank() == loco_shape.rank()); - - // When shape of loco is copied to circle node, ShapeSignature should be applied. - loco::TensorShape new_shape; - new_shape.rank(loco_shape.rank()); - for (uint32_t i = 0; i < loco_shape.rank(); ++i) - { - if (circle_node->shape_signature().rank() > 0 && - circle_node->shape_signature().dim(i) == -1) - new_shape.dim(i) = 1; - else - new_shape.dim(i) = loco_shape.dim(i); - } - - if (circle_node->shape_status() == luci::ShapeStatus::UNDEFINED || - !has_same_shape(circle_node, new_shape)) - { - circle_node->rank(new_shape.rank()); - for (uint32_t i = 0; i < new_shape.rank(); ++i) - circle_node->dim(i) = new_shape.dim(i); - - if (circle_node->shape_status() == luci::ShapeStatus::UNDEFINED) - circle_node->shape_status(luci::ShapeStatus::VALID); - - changed = true; - } - } - - if (loco::dtype_known(node)) - { - if (loco::dtype_get(node) != circle_node->dtype()) - { - circle_node->dtype(loco::dtype_get(node)); - changed = true; - } - } - } - - return changed; -} - -} // namespace luci diff --git a/compiler/luci/pass/src/ModulePhase.test.cpp b/compiler/luci/pass/src/ModulePhase.test.cpp new file mode 100644 index 000000000..5d92c59f4 --- /dev/null +++ b/compiler/luci/pass/src/ModulePhase.test.cpp @@ -0,0 +1,57 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ModulePhase.h" + +#include "luci/Pass/CircleShapeInferencePass.h" + +#include <loco.h> + +#include <gtest/gtest.h> + +TEST(ModulePhaseTest, saturate) +{ + auto m = luci::make_module(); + auto g = loco::make_graph(); + m->add(std::move(g)); + + luci::Phase phase; + + // Any Pass will do for testing + phase.emplace_back(std::make_unique<luci::CircleShapeInferencePass>()); + + luci::PhaseRunner<logo::PhaseStrategy::Saturate> phase_runner{m.get()}; + phase_runner.run(phase); + + SUCCEED(); +} + +TEST(ModulePhaseTest, restart) +{ + auto m = luci::make_module(); + auto g = loco::make_graph(); + m->add(std::move(g)); + + luci::Phase phase; + + // Any Pass will do for testing + phase.emplace_back(std::make_unique<luci::CircleShapeInferencePass>()); + + luci::PhaseRunner<logo::PhaseStrategy::Restart> phase_runner{m.get()}; + phase_runner.run(phase); + + SUCCEED(); +} diff --git a/compiler/luci/pass/src/PassTestGraphs.h b/compiler/luci/pass/src/PassTestGraphs.h new file mode 100644 index 000000000..f5ae24f0b --- /dev/null +++ b/compiler/luci/pass/src/PassTestGraphs.h @@ -0,0 +1,142 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_PASS_TEST_GRAPHS_H__ +#define __LUCI_PASS_TEST_GRAPHS_H__ + +#include <loco.h> +#include <luci/IR/CircleNodes.h> + +namespace luci +{ + +/** + * ConstantFoldingTestGraph is a base class for testing + * constant folding passes. It creates Input and Output + * in the below graph. Child classes must implement Connector + * and Folded pattern. + * + * [Input] [Folded pattern] (Implemented by child class) + * \ / + * [Connector] (Implemented by child class) + * | + * [Output] + * + * Connector should satisfy the below conditions + * - Input type == Output type == Folded pattern type + * - Input shape == Output shape == Folded pattern shape + * + * For example, Add, Mul, Sub, .. can be a Connector + */ +class ConstantFoldingTestGraph +{ +public: + ConstantFoldingTestGraph(std::vector<uint32_t> input_shape, loco::DataType input_dtype) + { + _input = _g.nodes()->create<luci::CircleInput>(); + _output = _g.nodes()->create<luci::CircleOutput>(); + + auto graph_input = _g.inputs()->create(); + _input->index(graph_input->index()); + auto graph_output = _g.outputs()->create(); + _output->index(graph_output->index()); + + graph_input->dtype(input_dtype); + graph_output->dtype(input_dtype); + _input->dtype(input_dtype); + _output->dtype(input_dtype); + + auto input_tensor_shape = std::make_unique<loco::TensorShape>(); + input_tensor_shape->rank(input_shape.size()); + for (int i = 0; i < input_shape.size(); i++) + input_tensor_shape->dim(i).set(input_shape[i]); + graph_input->shape(std::move(input_tensor_shape)); + + auto output_tensor_shape = std::make_unique<loco::TensorShape>(); + output_tensor_shape->rank(input_shape.size()); + for (int i = 0; i < input_shape.size(); i++) + output_tensor_shape->dim(i).set(input_shape[i]); + graph_output->shape(std::move(output_tensor_shape)); + + _input->rank(input_shape.size()); + for (int i = 0; i < input_shape.size(); i++) + _input->dim(i).set(input_shape[i]); + + _output->rank(input_shape.size()); + for (int i = 0; i < input_shape.size(); i++) + _output->dim(i).set(input_shape[i]); + + _input->name("input"); + _output->name("output"); + } + + virtual void init() = 0; + + virtual ~ConstantFoldingTestGraph() = default; + + virtual loco::Node *createFoldedPattern() = 0; + + virtual luci::CircleConst *getFoldedPattern() = 0; + + loco::Graph *graph() { return &_g; } + + // NOTE: we're not adding _ prefix as these class members are public +protected: + loco::Graph _g; + luci::CircleInput *_input = nullptr; + luci::CircleOutput *_output = nullptr; +}; + +/** + * ConstantFoldingTestAddGraph is ConstantFoldingTestGraph + * whose Connector is Add. + */ +class ConstantFoldingAddTestGraph : public ConstantFoldingTestGraph +{ +protected: + ConstantFoldingAddTestGraph(std::vector<uint32_t> input_shape, loco::DataType input_dtype) + : ConstantFoldingTestGraph(input_shape, input_dtype) + { + _add = _g.nodes()->create<luci::CircleAdd>(); + _add->dtype(input_dtype); + + _add->rank(input_shape.size()); + for (int i = 0; i < input_shape.size(); i++) + _add->dim(i).set(input_shape[i]); + + _add->x(_input); + + _output->from(_add); + + _add->name("add"); + } + +protected: + void init() override { _add->y(createFoldedPattern()); } + +protected: + luci::CircleConst *getFoldedPattern() override + { + return dynamic_cast<luci::CircleConst *>(_add->y()); + } + +protected: + luci::CircleAdd *_add = nullptr; +}; + +} // namespace luci + +#endif // __LUCI_PASS_TEST_GRAPHS_H__ diff --git a/compiler/luci/pass/src/ProgressReporter.h b/compiler/luci/pass/src/ProgressReporter.h index cf30da735..8c6c95e65 100644 --- a/compiler/luci/pass/src/ProgressReporter.h +++ b/compiler/luci/pass/src/ProgressReporter.h @@ -30,7 +30,7 @@ class ProgressReporter : public logo::PhaseEventListener { public: ProgressReporter(loco::Graph *graph, logo::PhaseStrategy strategy) - : _graph{graph}, _strategy{strategy} + : _graph{graph}, _strategy{strategy} { // DO NOTHING } @@ -54,7 +54,7 @@ class ModuleProgressReporter : public logo::PhaseEventListener { public: ModuleProgressReporter(luci::Module *module, logo::PhaseStrategy strategy) - : _module{module}, _strategy{strategy} + : _module{module}, _strategy{strategy} { // DO NOTHING } diff --git a/compiler/luci/pass/src/PropagateConcatenationQparam.test.cpp b/compiler/luci/pass/src/PropagateConcatenationQparam.test.cpp index 0f8d562e9..de973a431 100644 --- a/compiler/luci/pass/src/PropagateConcatenationQparam.test.cpp +++ b/compiler/luci/pass/src/PropagateConcatenationQparam.test.cpp @@ -136,30 +136,34 @@ class ConstInputConcatGraph public: ConstInputConcatGraph(loco::DataType quant_type) { - concat_node.dtype(quant_type); - concat_node.fusedActivationFunction(luci::FusedActFunc::NONE); - input_1.dtype(loco::DataType::FLOAT32); - input_1.size<loco::DataType::FLOAT32>(5); + concat_node = g.nodes()->create<luci::CircleConcatenation>(2); + input_1 = g.nodes()->create<luci::CircleConst>(); + input_2 = g.nodes()->create<luci::CircleConv2D>(); + + concat_node->dtype(quant_type); + concat_node->fusedActivationFunction(luci::FusedActFunc::NONE); + input_1->dtype(loco::DataType::FLOAT32); + input_1->size<loco::DataType::FLOAT32>(5); for (int i = 0; i < 5; i++) { // Set data {-2, -1, 0, 1, 2} - input_1.at<loco::DataType::FLOAT32>(i) = i - 2.0; + input_1->at<loco::DataType::FLOAT32>(i) = i - 2.0; } - input_2.dtype(quant_type); + input_2->dtype(quant_type); - concat_node.values(0, &input_1); - concat_node.values(1, &input_2); + concat_node->values(0, input_1); + concat_node->values(1, input_2); if (quant_type == loco::DataType::U8) { - addQuantParam(concat_node, {0.1}, {10}); - addQuantParam(input_2, {2.0}, {2}); + addQuantParam(*concat_node, {0.1}, {10}); + addQuantParam(*input_2, {2.0}, {2}); } else if (quant_type == loco::DataType::S16) { - addQuantParam(concat_node, {0.1}, {0}); - addQuantParam(input_2, {2.0}, {0}); + addQuantParam(*concat_node, {0.1}, {0}); + addQuantParam(*input_2, {2.0}, {0}); } else { @@ -167,16 +171,11 @@ public: } } - ~ConstInputConcatGraph() - { - concat_node.values(0, nullptr); - concat_node.values(1, nullptr); - } - public: - luci::CircleConcatenation concat_node{2}; - luci::CircleConst input_1; - luci::CircleConv2D input_2; + loco::Graph g; + luci::CircleConcatenation *concat_node = nullptr; + luci::CircleConst *input_1 = nullptr; + luci::CircleConv2D *input_2 = nullptr; }; } // namespace @@ -223,19 +222,20 @@ TEST(PropagateConcatenationQparam, propagate_concat_quantparam_u8) // input_1 is const. const values are quantized with the qparam of concat ConstInputConcatGraph cg(loco::DataType::U8); - luci::propagate_concat_quantparam(&cg.concat_node, loco::DataType::U8); - EXPECT_FLOAT_EQ(0.1, cg.concat_node.quantparam()->scale[0]); - EXPECT_EQ(10, cg.concat_node.quantparam()->zerop[0]); - EXPECT_FLOAT_EQ(0.1, cg.input_1.quantparam()->scale[0]); - EXPECT_EQ(10, cg.input_1.quantparam()->zerop[0]); - EXPECT_FLOAT_EQ(0.1, cg.input_2.quantparam()->scale[0]); - EXPECT_EQ(10, cg.input_2.quantparam()->zerop[0]); - EXPECT_EQ(loco::DataType::U8, cg.input_1.dtype()); - EXPECT_EQ(0, cg.input_1.at<loco::DataType::U8>(0)); - EXPECT_EQ(0, cg.input_1.at<loco::DataType::U8>(1)); - EXPECT_EQ(10, cg.input_1.at<loco::DataType::U8>(2)); - EXPECT_EQ(20, cg.input_1.at<loco::DataType::U8>(3)); - EXPECT_EQ(30, cg.input_1.at<loco::DataType::U8>(4)); + luci::propagate_concat_quantparam(cg.concat_node, loco::DataType::U8); + EXPECT_FLOAT_EQ(0.1, cg.concat_node->quantparam()->scale[0]); + EXPECT_EQ(10, cg.concat_node->quantparam()->zerop[0]); + const auto cg_input_1 = loco::must_cast<luci::CircleConst *>(cg.concat_node->values(0)); + EXPECT_FLOAT_EQ(0.1, cg_input_1->quantparam()->scale[0]); + EXPECT_EQ(10, cg_input_1->quantparam()->zerop[0]); + EXPECT_FLOAT_EQ(0.1, cg.input_2->quantparam()->scale[0]); + EXPECT_EQ(10, cg.input_2->quantparam()->zerop[0]); + EXPECT_EQ(loco::DataType::U8, cg_input_1->dtype()); + EXPECT_EQ(0, cg_input_1->at<loco::DataType::U8>(0)); + EXPECT_EQ(0, cg_input_1->at<loco::DataType::U8>(1)); + EXPECT_EQ(10, cg_input_1->at<loco::DataType::U8>(2)); + EXPECT_EQ(20, cg_input_1->at<loco::DataType::U8>(3)); + EXPECT_EQ(30, cg_input_1->at<loco::DataType::U8>(4)); } TEST(PropagateConcatenationQparam, propagate_concat_quantparam_u8_NEG) @@ -260,20 +260,21 @@ TEST(PropagateConcatenationQparam, propagate_concat_quantparam_u8_NEG) // concat has fused activation function and input_1 is const. // const values are quantized using its min/max ConstInputConcatGraph cg(loco::DataType::U8); - cg.concat_node.fusedActivationFunction(luci::FusedActFunc::RELU); - luci::propagate_concat_quantparam(&cg.concat_node, loco::DataType::U8); - EXPECT_FLOAT_EQ(0.1, cg.concat_node.quantparam()->scale[0]); - EXPECT_EQ(10, cg.concat_node.quantparam()->zerop[0]); - EXPECT_FLOAT_EQ(0.015686275, cg.input_1.quantparam()->scale[0]); - EXPECT_EQ(128, cg.input_1.quantparam()->zerop[0]); - EXPECT_FLOAT_EQ(2.0, cg.input_2.quantparam()->scale[0]); - EXPECT_EQ(2, cg.input_2.quantparam()->zerop[0]); - EXPECT_EQ(loco::DataType::U8, cg.input_1.dtype()); - EXPECT_EQ(quantize(-2, cg.input_1.quantparam()), cg.input_1.at<loco::DataType::U8>(0)); - EXPECT_EQ(quantize(-1, cg.input_1.quantparam()), cg.input_1.at<loco::DataType::U8>(1)); - EXPECT_EQ(quantize(0, cg.input_1.quantparam()), cg.input_1.at<loco::DataType::U8>(2)); - EXPECT_EQ(quantize(1, cg.input_1.quantparam()), cg.input_1.at<loco::DataType::U8>(3)); - EXPECT_EQ(quantize(2, cg.input_1.quantparam()), cg.input_1.at<loco::DataType::U8>(4)); + cg.concat_node->fusedActivationFunction(luci::FusedActFunc::RELU); + luci::propagate_concat_quantparam(cg.concat_node, loco::DataType::U8); + EXPECT_FLOAT_EQ(0.1, cg.concat_node->quantparam()->scale[0]); + EXPECT_EQ(10, cg.concat_node->quantparam()->zerop[0]); + const auto cg_input_1 = loco::must_cast<luci::CircleConst *>(cg.concat_node->values(0)); + EXPECT_FLOAT_EQ(0.015686275, cg_input_1->quantparam()->scale[0]); + EXPECT_EQ(128, cg_input_1->quantparam()->zerop[0]); + EXPECT_FLOAT_EQ(2.0, cg.input_2->quantparam()->scale[0]); + EXPECT_EQ(2, cg.input_2->quantparam()->zerop[0]); + EXPECT_EQ(loco::DataType::U8, cg_input_1->dtype()); + EXPECT_EQ(quantize(-2, cg_input_1->quantparam()), cg_input_1->at<loco::DataType::U8>(0)); + EXPECT_EQ(quantize(-1, cg_input_1->quantparam()), cg_input_1->at<loco::DataType::U8>(1)); + EXPECT_EQ(quantize(0, cg_input_1->quantparam()), cg_input_1->at<loco::DataType::U8>(2)); + EXPECT_EQ(quantize(1, cg_input_1->quantparam()), cg_input_1->at<loco::DataType::U8>(3)); + EXPECT_EQ(quantize(2, cg_input_1->quantparam()), cg_input_1->at<loco::DataType::U8>(4)); } TEST(PropagateConcatenationQparam, propagate_concat_quantparam_i16) @@ -318,19 +319,20 @@ TEST(PropagateConcatenationQparam, propagate_concat_quantparam_i16) // input_1 is const. const values are quantized with the qparam of concat ConstInputConcatGraph cg(loco::DataType::S16); - luci::propagate_concat_quantparam(&cg.concat_node, loco::DataType::S16); - EXPECT_FLOAT_EQ(0.1, cg.concat_node.quantparam()->scale[0]); - EXPECT_EQ(0, cg.concat_node.quantparam()->zerop[0]); - EXPECT_FLOAT_EQ(0.1, cg.input_1.quantparam()->scale[0]); - EXPECT_EQ(0, cg.input_1.quantparam()->zerop[0]); - EXPECT_FLOAT_EQ(0.1, cg.input_2.quantparam()->scale[0]); - EXPECT_EQ(0, cg.input_2.quantparam()->zerop[0]); - EXPECT_EQ(loco::DataType::S16, cg.input_1.dtype()); - EXPECT_EQ(-20, cg.input_1.at<loco::DataType::S16>(0)); - EXPECT_EQ(-10, cg.input_1.at<loco::DataType::S16>(1)); - EXPECT_EQ(0, cg.input_1.at<loco::DataType::S16>(2)); - EXPECT_EQ(10, cg.input_1.at<loco::DataType::S16>(3)); - EXPECT_EQ(20, cg.input_1.at<loco::DataType::S16>(4)); + luci::propagate_concat_quantparam(cg.concat_node, loco::DataType::S16); + EXPECT_FLOAT_EQ(0.1, cg.concat_node->quantparam()->scale[0]); + EXPECT_EQ(0, cg.concat_node->quantparam()->zerop[0]); + const auto cg_input_1 = loco::must_cast<luci::CircleConst *>(cg.concat_node->values(0)); + EXPECT_FLOAT_EQ(0.1, cg_input_1->quantparam()->scale[0]); + EXPECT_EQ(0, cg_input_1->quantparam()->zerop[0]); + EXPECT_FLOAT_EQ(0.1, cg.input_2->quantparam()->scale[0]); + EXPECT_EQ(0, cg.input_2->quantparam()->zerop[0]); + EXPECT_EQ(loco::DataType::S16, cg_input_1->dtype()); + EXPECT_EQ(-20, cg_input_1->at<loco::DataType::S16>(0)); + EXPECT_EQ(-10, cg_input_1->at<loco::DataType::S16>(1)); + EXPECT_EQ(0, cg_input_1->at<loco::DataType::S16>(2)); + EXPECT_EQ(10, cg_input_1->at<loco::DataType::S16>(3)); + EXPECT_EQ(20, cg_input_1->at<loco::DataType::S16>(4)); } TEST(PropagateConcatenationQparam, propagate_concat_quantparam_i16_NEG) @@ -355,18 +357,19 @@ TEST(PropagateConcatenationQparam, propagate_concat_quantparam_i16_NEG) // concat has fused activation function and input_1 is const. // const values are quantized using its min/max ConstInputConcatGraph cg(loco::DataType::S16); - cg.concat_node.fusedActivationFunction(luci::FusedActFunc::RELU); - luci::propagate_concat_quantparam(&cg.concat_node, loco::DataType::S16); - EXPECT_FLOAT_EQ(0.1, cg.concat_node.quantparam()->scale[0]); - EXPECT_EQ(0, cg.concat_node.quantparam()->zerop[0]); - EXPECT_FLOAT_EQ(0.000061037, cg.input_1.quantparam()->scale[0]); - EXPECT_EQ(0, cg.input_1.quantparam()->zerop[0]); - EXPECT_FLOAT_EQ(2.0, cg.input_2.quantparam()->scale[0]); - EXPECT_EQ(0, cg.input_2.quantparam()->zerop[0]); - EXPECT_EQ(loco::DataType::S16, cg.input_1.dtype()); - EXPECT_EQ(quantize(-2, cg.input_1.quantparam()), cg.input_1.at<loco::DataType::S16>(0)); - EXPECT_EQ(quantize(-1, cg.input_1.quantparam()), cg.input_1.at<loco::DataType::S16>(1)); - EXPECT_EQ(quantize(0, cg.input_1.quantparam()), cg.input_1.at<loco::DataType::S16>(2)); - EXPECT_EQ(quantize(1, cg.input_1.quantparam()), cg.input_1.at<loco::DataType::S16>(3)); - EXPECT_EQ(quantize(2, cg.input_1.quantparam()), cg.input_1.at<loco::DataType::S16>(4)); + cg.concat_node->fusedActivationFunction(luci::FusedActFunc::RELU); + luci::propagate_concat_quantparam(cg.concat_node, loco::DataType::S16); + EXPECT_FLOAT_EQ(0.1, cg.concat_node->quantparam()->scale[0]); + EXPECT_EQ(0, cg.concat_node->quantparam()->zerop[0]); + const auto cg_input_1 = loco::must_cast<luci::CircleConst *>(cg.concat_node->values(0)); + EXPECT_FLOAT_EQ(0.000061037, cg_input_1->quantparam()->scale[0]); + EXPECT_EQ(0, cg_input_1->quantparam()->zerop[0]); + EXPECT_FLOAT_EQ(2.0, cg.input_2->quantparam()->scale[0]); + EXPECT_EQ(0, cg.input_2->quantparam()->zerop[0]); + EXPECT_EQ(loco::DataType::S16, cg_input_1->dtype()); + EXPECT_EQ(quantize(-2, cg_input_1->quantparam()), cg_input_1->at<loco::DataType::S16>(0)); + EXPECT_EQ(quantize(-1, cg_input_1->quantparam()), cg_input_1->at<loco::DataType::S16>(1)); + EXPECT_EQ(quantize(0, cg_input_1->quantparam()), cg_input_1->at<loco::DataType::S16>(2)); + EXPECT_EQ(quantize(1, cg_input_1->quantparam()), cg_input_1->at<loco::DataType::S16>(3)); + EXPECT_EQ(quantize(2, cg_input_1->quantparam()), cg_input_1->at<loco::DataType::S16>(4)); } diff --git a/compiler/luci/pass/src/PropagateQuantParamPass.cpp b/compiler/luci/pass/src/PropagateQuantParamPass.cpp index af83cd83b..26282086b 100644 --- a/compiler/luci/pass/src/PropagateQuantParamPass.cpp +++ b/compiler/luci/pass/src/PropagateQuantParamPass.cpp @@ -91,9 +91,8 @@ bool PropagateQuantParamPass::run(loco::Graph *g) INFO(l) << "PropagateQuantParamPass visit node: " << circle_node->name() << std::endl; PropagateQuantParam pqp; - changed = circle_node->accept(&pqp); - if (changed) - break; + if (circle_node->accept(&pqp)) + changed = true; } return changed; diff --git a/compiler/luci/pass/src/PropagateQuantParamPass.test.cpp b/compiler/luci/pass/src/PropagateQuantParamPass.test.cpp index 15adbfc01..ed1f96828 100644 --- a/compiler/luci/pass/src/PropagateQuantParamPass.test.cpp +++ b/compiler/luci/pass/src/PropagateQuantParamPass.test.cpp @@ -83,6 +83,13 @@ public: } // namespace +TEST(PropagateQuantParamPassTest, name) +{ + luci::PropagateQuantParamPass pass; + auto const name = pass.name(); + ASSERT_NE(nullptr, name); +} + TEST(PropagateQuantParam, simple) { SimpleGraph g; diff --git a/compiler/luci/pass/src/QuantizationUtils.cpp b/compiler/luci/pass/src/QuantizationUtils.cpp index fa0141114..85d600e47 100644 --- a/compiler/luci/pass/src/QuantizationUtils.cpp +++ b/compiler/luci/pass/src/QuantizationUtils.cpp @@ -96,7 +96,7 @@ void asymmetric_wquant_with_minmax_per_layer(CircleConst *node, float min, float data = data < nudged_min ? nudged_min : data; data = data > nudged_max ? nudged_max : data; quantized_values[i] = - static_cast<int32_t>(std::round((data - nudged_min) * scaling_factor_inv)); + static_cast<int32_t>(std::round((data - nudged_min) * scaling_factor_inv)); } node->dtype(loco::DataType::U8); // change the type of tensor @@ -133,14 +133,14 @@ void symmetric_wquant_with_minmax_per_layer(CircleConst *node, float min, float for (uint32_t i = 0; i < size; ++i) { node->at<loco::DataType::S16>(i) = - std::min(kMaxScale, std::max(kMinScale, quantized_values[i])); + std::min(kMaxScale, std::max(kMinScale, quantized_values[i])); } } void compute_sym_scale_zp(float min, float max, float &scaling_factor, int64_t &zp, float &nudged_min, float &nudged_max) { - assert(min != max); + assert(min <= max); const int32_t kMaxScale = std::numeric_limits<int16_t>::max(); const int32_t kMinScale = -kMaxScale; @@ -158,8 +158,8 @@ void compute_sym_scale_zp(float min, float max, float &scaling_factor, int64_t & scale_factor_from_max_side = rmax / qmax_double; scaling_factor = scale_factor_from_min_side > scale_factor_from_max_side - ? scale_factor_from_min_side - : scale_factor_from_max_side; + ? scale_factor_from_min_side + : scale_factor_from_max_side; zp = 0; nudged_min = static_cast<float>(qmin_double * scaling_factor); nudged_max = static_cast<float>(qmax_double * scaling_factor); @@ -226,7 +226,8 @@ void compute_asym_scale_zp(float min, float max, float &scaling_factor, int64_t zp = nudged_zero_point; } -bool get_channel_dim_index(CircleConst *node, loco::TensorShape &dimension, int &channel_dim_index) +bool get_channel_dim_index(CircleConst *node, loco::TensorShape &dimension, + int32_t &channel_dim_index) { auto succs = loco::succs(node); @@ -304,7 +305,7 @@ bool get_channel_dim_index(CircleConst *node, loco::TensorShape &dimension, int uint32_t cal_offset(loco::TensorShape &dimension, uint32_t *indices) { return indices[0] * dimension.dim(1).value() * dimension.dim(2).value() * - dimension.dim(3).value() + + dimension.dim(3).value() + indices[1] * dimension.dim(2).value() * dimension.dim(3).value() + indices[2] * dimension.dim(3).value() + indices[3]; } diff --git a/compiler/luci/pass/src/QuantizationUtils.h b/compiler/luci/pass/src/QuantizationUtils.h index 22a5cf1ee..c8c558d3c 100644 --- a/compiler/luci/pass/src/QuantizationUtils.h +++ b/compiler/luci/pass/src/QuantizationUtils.h @@ -37,7 +37,8 @@ void symmetric_wquant_with_minmax_per_layer(CircleConst *node, float min, float float &scaling_factor, int64_t &zp, float &nudged_min, float &nudged_max); -bool get_channel_dim_index(CircleConst *node, loco::TensorShape &dimension, int &channel_dim_index); +bool get_channel_dim_index(CircleConst *node, loco::TensorShape &dimension, + int32_t &channel_dim_index); uint32_t cal_offset(loco::TensorShape &dimension, uint32_t *indices); diff --git a/compiler/luci/pass/src/QuantizeDequantizeWeightsPass.cpp b/compiler/luci/pass/src/QuantizeDequantizeWeightsPass.cpp index e10c4bb4d..e99c7b389 100644 --- a/compiler/luci/pass/src/QuantizeDequantizeWeightsPass.cpp +++ b/compiler/luci/pass/src/QuantizeDequantizeWeightsPass.cpp @@ -24,33 +24,29 @@ #include <iostream> #include <cmath> - -namespace luci -{ +#include <functional> namespace { -void cal_minmax_per_channel(CircleConst *node, std::vector<float> &min, std::vector<float> &max) +using namespace luci; +using IterFunc = std::function<void(uint32_t *, loco::TensorShape &, int32_t)>; + +void iterate_per_channel(CircleConst *node, IterFunc func) { loco::TensorShape dimension; dimension.rank(4); uint32_t indices[4] = { - 0, + 0, }; - int channel_dim_index{0}; - int size{0}; + int32_t channel_dim_index{0}; if (!get_channel_dim_index(node, dimension, channel_dim_index)) { assert(false); return; } - size = dimension.dim(channel_dim_index).value(); - std::vector<bool> has_min_max_value(size, false); - min.resize(size); - max.resize(size); for (indices[0] = 0; indices[0] < dimension.dim(0).value(); indices[0]++) { for (indices[1] = 0; indices[1] < dimension.dim(1).value(); indices[1]++) @@ -59,25 +55,57 @@ void cal_minmax_per_channel(CircleConst *node, std::vector<float> &min, std::vec { for (indices[3] = 0; indices[3] < dimension.dim(3).value(); indices[3]++) { - int channel_idx = indices[channel_dim_index]; - auto data = node->at<loco::DataType::FLOAT32>(cal_offset(dimension, indices)); - if (has_min_max_value[channel_idx]) - { - min[channel_idx] = data < min[channel_idx] ? data : min[channel_idx]; - max[channel_idx] = data > max[channel_idx] ? data : max[channel_idx]; - } - else - { - min[channel_idx] = data; - max[channel_idx] = data; - has_min_max_value[channel_idx] = true; - } + func(indices, dimension, channel_dim_index); } } } } } +} // namespace + +namespace luci +{ + +namespace +{ + +void cal_minmax_per_channel(CircleConst *node, std::vector<float> &min, std::vector<float> &max) +{ + loco::TensorShape dimension; + dimension.rank(4); + int32_t channel_dim_index{0}; + + if (!get_channel_dim_index(node, dimension, channel_dim_index)) + { + assert(false); + return; + } + auto size = dimension.dim(channel_dim_index).value(); + + std::vector<bool> has_min_max_value(size, false); + min.resize(size); + max.resize(size); + + auto cal_minmax = [&](uint32_t *indices, loco::TensorShape &dimension, int channel_dim_index) { + int channel_idx = indices[channel_dim_index]; + auto data = node->at<loco::DataType::FLOAT32>(cal_offset(dimension, indices)); + if (has_min_max_value[channel_idx]) + { + min[channel_idx] = data < min[channel_idx] ? data : min[channel_idx]; + max[channel_idx] = data > max[channel_idx] ? data : max[channel_idx]; + } + else + { + min[channel_idx] = data; + max[channel_idx] = data; + has_min_max_value[channel_idx] = true; + } + }; + + iterate_per_channel(node, cal_minmax); +} + void sym_wquant_per_channel(CircleConst *node, std::vector<float> &min, std::vector<float> &max, std::vector<float> &scaling_factor, std::vector<int64_t> &zp, std::vector<float> &nudged_min, std::vector<float> &nudged_max) @@ -94,45 +122,24 @@ void sym_wquant_per_channel(CircleConst *node, std::vector<float> &min, std::vec compute_sym_scale_zp(min[i], max[i], scaling_factor[i], zp[i], nudged_min[i], nudged_max[i]); } - loco::TensorShape dimension; - dimension.rank(4); - uint32_t indices[4] = { - 0, + auto quantize = [&](uint32_t *indices, loco::TensorShape &dimension, int channel_dim_index) { + int channel_idx = indices[channel_dim_index]; + const float scaling_factor_inv = 1.0 / scaling_factor[channel_idx]; + auto data = node->at<loco::DataType::FLOAT32>(cal_offset(dimension, indices)); + data = data < nudged_min[channel_idx] ? nudged_min[channel_idx] : data; + data = data > nudged_max[channel_idx] ? nudged_max[channel_idx] : data; + quantized_values[cal_offset(dimension, indices)] = + static_cast<int32_t>(std::round(data * scaling_factor_inv)); }; - int channel_dim_index{0}; - - if (!get_channel_dim_index(node, dimension, channel_dim_index)) - { - assert(false); - return; - } - for (indices[0] = 0; indices[0] < dimension.dim(0).value(); indices[0]++) - { - for (indices[1] = 0; indices[1] < dimension.dim(1).value(); indices[1]++) - { - for (indices[2] = 0; indices[2] < dimension.dim(2).value(); indices[2]++) - { - for (indices[3] = 0; indices[3] < dimension.dim(3).value(); indices[3]++) - { - int channel_idx = indices[channel_dim_index]; - const float scaling_factor_inv = 1.0 / scaling_factor[channel_idx]; - auto data = node->at<loco::DataType::FLOAT32>(cal_offset(dimension, indices)); - data = data < nudged_min[channel_idx] ? nudged_min[channel_idx] : data; - data = data > nudged_max[channel_idx] ? nudged_max[channel_idx] : data; - quantized_values[cal_offset(dimension, indices)] = - static_cast<int32_t>(std::round(data * scaling_factor_inv)); - } - } - } - } + iterate_per_channel(node, quantize); node->dtype(loco::DataType::S16); // change the type of tensor node->size<loco::DataType::S16>(size); // resize tensor for (uint32_t i = 0; i < size; ++i) { node->at<loco::DataType::S16>(i) = - std::min(kMaxScale, std::max(kMinScale, quantized_values[i])); + std::min(kMaxScale, std::max(kMinScale, quantized_values[i])); } } @@ -142,35 +149,14 @@ void sym_wdequant_per_channel(CircleConst *node, std::vector<float> &scaling_fac uint32_t size = node->size<loco::DataType::S16>(); std::vector<float> dequantized_values(size); - loco::TensorShape dimension; - dimension.rank(4); - uint32_t indices[4] = { - 0, + auto dequantize = [&](uint32_t *indices, loco::TensorShape &dimension, int channel_dim_index) { + int channel_idx = indices[channel_dim_index]; + auto data = node->at<loco::DataType::S16>(cal_offset(dimension, indices)); + dequantized_values[cal_offset(dimension, indices)] = + static_cast<float>(data) * scaling_factor[channel_idx]; }; - int channel_dim_index{0}; - - if (!get_channel_dim_index(node, dimension, channel_dim_index)) - { - assert(false); - return; - } - for (indices[0] = 0; indices[0] < dimension.dim(0).value(); indices[0]++) - { - for (indices[1] = 0; indices[1] < dimension.dim(1).value(); indices[1]++) - { - for (indices[2] = 0; indices[2] < dimension.dim(2).value(); indices[2]++) - { - for (indices[3] = 0; indices[3] < dimension.dim(3).value(); indices[3]++) - { - int channel_idx = indices[channel_dim_index]; - auto data = node->at<loco::DataType::S16>(cal_offset(dimension, indices)); - dequantized_values[cal_offset(dimension, indices)] = - static_cast<float>(data) * scaling_factor[channel_idx]; - } - } - } - } + iterate_per_channel(node, dequantize); node->dtype(loco::DataType::FLOAT32); // change the type of tensor node->size<loco::DataType::FLOAT32>(size); // resize tensor @@ -198,38 +184,17 @@ void asymmetric_wquant_per_channel(CircleConst *node, std::vector<float> &min, compute_asym_scale_zp(min[i], max[i], scaling_factor[i], zp[i], nudged_min[i], nudged_max[i]); } - loco::TensorShape dimension; - dimension.rank(4); - uint32_t indices[4] = { - 0, + auto quantize = [&](uint32_t *indices, loco::TensorShape &dimension, int channel_dim_index) { + int channel_idx = indices[channel_dim_index]; + const float scaling_factor_inv = 1.0 / scaling_factor[channel_idx]; + auto data = node->at<loco::DataType::FLOAT32>(cal_offset(dimension, indices)); + data = data < nudged_min[channel_idx] ? nudged_min[channel_idx] : data; + data = data > nudged_max[channel_idx] ? nudged_max[channel_idx] : data; + quantized_values[cal_offset(dimension, indices)] = + static_cast<int32_t>(std::round((data - nudged_min[channel_idx]) * scaling_factor_inv)); }; - int channel_dim_index{0}; - - if (!get_channel_dim_index(node, dimension, channel_dim_index)) - { - assert(false); - return; - } - for (indices[0] = 0; indices[0] < dimension.dim(0).value(); indices[0]++) - { - for (indices[1] = 0; indices[1] < dimension.dim(1).value(); indices[1]++) - { - for (indices[2] = 0; indices[2] < dimension.dim(2).value(); indices[2]++) - { - for (indices[3] = 0; indices[3] < dimension.dim(3).value(); indices[3]++) - { - int channel_idx = indices[channel_dim_index]; - const float scaling_factor_inv = 1.0 / scaling_factor[channel_idx]; - auto data = node->at<loco::DataType::FLOAT32>(cal_offset(dimension, indices)); - data = data < nudged_min[channel_idx] ? nudged_min[channel_idx] : data; - data = data > nudged_max[channel_idx] ? nudged_max[channel_idx] : data; - quantized_values[cal_offset(dimension, indices)] = static_cast<int32_t>( - std::round((data - nudged_min[channel_idx]) * scaling_factor_inv)); - } - } - } - } + iterate_per_channel(node, quantize); node->dtype(loco::DataType::U8); // change the type of tensor node->size<loco::DataType::U8>(size); // resize tensor @@ -246,35 +211,14 @@ void asymmetric_wdequant_per_channel(CircleConst *node, std::vector<float> &scal uint32_t size = node->size<loco::DataType::U8>(); std::vector<float> dequantized_values(size); - loco::TensorShape dimension; - dimension.rank(4); - uint32_t indices[4] = { - 0, + auto dequantize = [&](uint32_t *indices, loco::TensorShape &dimension, int channel_dim_index) { + int channel_idx = indices[channel_dim_index]; + auto data = node->at<loco::DataType::U8>(cal_offset(dimension, indices)); + dequantized_values[cal_offset(dimension, indices)] = + static_cast<float>(data) * scaling_factor[channel_idx] + nudged_min[channel_idx]; }; - int channel_dim_index{0}; - - if (!get_channel_dim_index(node, dimension, channel_dim_index)) - { - assert(false); - return; - } - for (indices[0] = 0; indices[0] < dimension.dim(0).value(); indices[0]++) - { - for (indices[1] = 0; indices[1] < dimension.dim(1).value(); indices[1]++) - { - for (indices[2] = 0; indices[2] < dimension.dim(2).value(); indices[2]++) - { - for (indices[3] = 0; indices[3] < dimension.dim(3).value(); indices[3]++) - { - int channel_idx = indices[channel_dim_index]; - auto data = node->at<loco::DataType::U8>(cal_offset(dimension, indices)); - dequantized_values[cal_offset(dimension, indices)] = - static_cast<float>(data) * scaling_factor[channel_idx] + nudged_min[channel_idx]; - } - } - } - } + iterate_per_channel(node, dequantize); node->dtype(loco::DataType::FLOAT32); // change the type of tensor node->size<loco::DataType::FLOAT32>(size); // resize tensor @@ -311,7 +255,7 @@ struct QuantizeDequantizeWeights final : public luci::CircleNodeMutableVisitor<b { QuantizeDequantizeWeights(loco::DataType input, loco::DataType output, QuantizationGranularity granularity) - : input_type(input), output_type(output), granularity(granularity) + : input_type(input), output_type(output), granularity(granularity) { } diff --git a/compiler/luci/pass/src/QuantizeDequantizeWeightsPass.test.cpp b/compiler/luci/pass/src/QuantizeDequantizeWeightsPass.test.cpp new file mode 100644 index 000000000..f226253c2 --- /dev/null +++ b/compiler/luci/pass/src/QuantizeDequantizeWeightsPass.test.cpp @@ -0,0 +1,27 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Pass/QuantizeDequantizeWeightsPass.h" + +#include <gtest/gtest.h> + +TEST(QuantizeDequantizeWeightsPassTest, name) +{ + luci::QuantizeDequantizeWeightsPass pass(loco::DataType::FLOAT32, loco::DataType::U8, + luci::QuantizationGranularity::LayerWise); + auto const name = pass.name(); + ASSERT_NE(nullptr, name); +} diff --git a/compiler/luci/pass/src/QuantizeWithMinMaxPass.cpp b/compiler/luci/pass/src/QuantizeWithMinMaxPass.cpp index f6eebe3b9..4707ad0e9 100644 --- a/compiler/luci/pass/src/QuantizeWithMinMaxPass.cpp +++ b/compiler/luci/pass/src/QuantizeWithMinMaxPass.cpp @@ -19,12 +19,51 @@ #include <luci/IR/CircleNodes.h> #include <luci/IR/CircleNodeVisitor.h> +#include <luci/Service/Nodes/CircleConst.h> #include <luci/Log.h> #include <oops/UserExn.h> #include <iostream> #include <cmath> +#include <functional> + +namespace +{ + +using namespace luci; +using IterFunc = std::function<void(uint32_t *, loco::TensorShape &, int32_t)>; + +void iterate_per_channel(CircleConst *node, int32_t &channel_dim_index, IterFunc func) +{ + loco::TensorShape dimension; + dimension.rank(4); + uint32_t indices[4] = { + 0, + }; + + if (!get_channel_dim_index(node, dimension, channel_dim_index)) + { + assert(false); + return; + } + + for (indices[0] = 0; indices[0] < dimension.dim(0).value(); indices[0]++) + { + for (indices[1] = 0; indices[1] < dimension.dim(1).value(); indices[1]++) + { + for (indices[2] = 0; indices[2] < dimension.dim(2).value(); indices[2]++) + { + for (indices[3] = 0; indices[3] < dimension.dim(3).value(); indices[3]++) + { + func(indices, dimension, channel_dim_index); + } + } + } + } +} + +} // namespace namespace luci { @@ -32,6 +71,30 @@ namespace luci namespace { +// Create a new const node from an existing node. +// The new node has the following characteristics +// type: T +// shape: same with 'node' (given as an argument) +// buffer size: 'size' (given as an argument) +// Note that contents are not filled in this function. +template <loco::DataType T> +luci::CircleConst *create_empty_const_from(luci::CircleConst *node, uint32_t size) +{ + auto new_node = node->graph()->nodes()->create<CircleConst>(); + // TODO: We don't have any naming convention for quantized nodes yet. + // Fix this when we have one. + new_node->name(node->name()); + new_node->dtype(T); + new_node->rank(node->rank()); + for (uint32_t i = 0; i < node->rank(); i++) + new_node->dim(i).set(node->dim(i).value()); + + new_node->size<T>(size); + new_node->shape_status(luci::ShapeStatus::VALID); + + return new_node; +} + void overwrite_quantparam(luci::CircleConcatenation *concat, luci::CircleNode *target) { auto concat_qparam = concat->quantparam(); @@ -44,6 +107,9 @@ void overwrite_quantparam(luci::CircleConcatenation *concat, luci::CircleNode *t auto quantparam = std::make_unique<CircleQuantParam>(); target->quantparam(std::move(quantparam)); target_qparam = target->quantparam(); + + if (target_qparam == nullptr) + throw std::runtime_error("Creating new quant param failed"); } target_qparam->min = concat_qparam->min; target_qparam->max = concat_qparam->max; @@ -79,7 +145,7 @@ void quant_const_values(luci::CircleConst *const_node, float scaling_factor, flo const_node->size<loco::DataType::S16>(size); // resize tensor for (uint32_t i = 0; i < size; ++i) const_node->at<loco::DataType::S16>(i) = - std::min(32767, std::max(-32767, quantized_values[i])); + std::min(32767, std::max(-32767, quantized_values[i])); break; default: throw std::runtime_error("Unsupported data type"); @@ -219,17 +285,16 @@ void quant_const(CircleConst *node, loco::DataType quant_type) } // Check if the node is the bias of Conv2D, DepthwiseConv2D, FullyConnected, or TransposeConv layer -// If true, return <input, weight> pair of the successor node (used to quantize bias) -// If flase, return <nullptr, nullptr> -std::pair<loco::Node *, loco::Node *> get_input_weight_of_bias(CircleNode *node) +// Returns a list of <input, weights, output> vectors for the above operators. +// Note that it returns a 'list' because bias can be used by multiple operators. +std::vector<std::vector<loco::Node *>> get_input_weight_output_of_bias(CircleNode *node) { + std::vector<std::vector<loco::Node *>> result; auto circle_const = dynamic_cast<CircleConst *>(node); if (circle_const == nullptr) - return std::make_pair(nullptr, nullptr); + return result; auto succs = loco::succs(node); - if (succs.size() != 1) // assume bias is used by only one node - return std::make_pair(nullptr, nullptr); for (auto out : succs) { @@ -238,35 +303,39 @@ std::pair<loco::Node *, loco::Node *> get_input_weight_of_bias(CircleNode *node) { assert(conv->input() != nullptr); assert(conv->filter() != nullptr); - return std::make_pair(conv->input(), conv->filter()); + result.push_back({conv->input(), conv->filter(), conv}); + continue; } auto dw_conv = dynamic_cast<CircleDepthwiseConv2D *>(out); if (dw_conv != nullptr && dw_conv->bias() == circle_const) { assert(dw_conv->input() != nullptr); assert(dw_conv->filter() != nullptr); - return std::make_pair(dw_conv->input(), dw_conv->filter()); + result.push_back({dw_conv->input(), dw_conv->filter(), dw_conv}); + continue; } auto fc = dynamic_cast<CircleFullyConnected *>(out); if (fc != nullptr && fc->bias() == circle_const) { assert(fc->input() != nullptr); assert(fc->weights() != nullptr); - return std::make_pair(fc->input(), fc->weights()); + result.push_back({fc->input(), fc->weights(), fc}); + continue; } auto tconv = dynamic_cast<CircleTransposeConv *>(out); if (tconv != nullptr && tconv->bias() == circle_const) { assert(tconv->outBackprop() != nullptr); assert(tconv->filter() != nullptr); - return std::make_pair(tconv->outBackprop(), tconv->filter()); + result.push_back({tconv->outBackprop(), tconv->filter(), tconv}); + continue; } } - return std::make_pair(nullptr, nullptr); + return result; } -void asym_quant_bias_per_layer(CircleConst *node, float input_scale, float weight_scale, - float *scaling_factor, int64_t *zp) +CircleConst *asym_quant_bias_per_layer(CircleConst *node, float input_scale, float weight_scale, + float *scaling_factor, int64_t *zp) { float scale = input_scale * weight_scale; const float scaling_factor_inv = (scale == 0) ? 0 : 1.0 / scale; @@ -276,24 +345,27 @@ void asym_quant_bias_per_layer(CircleConst *node, float input_scale, float weigh for (uint32_t i = 0; i < size; ++i) { quantized_values[i] = - static_cast<int32_t>(std::round(node->at<loco::DataType::FLOAT32>(i) * scaling_factor_inv)); + static_cast<int32_t>(std::round(node->at<loco::DataType::FLOAT32>(i) * scaling_factor_inv)); } - node->dtype(loco::DataType::S32); // change the type of tensor - node->size<loco::DataType::S32>(size); // resize tensor + auto new_bias = create_empty_const_from<loco::DataType::S32>(node, size); + const int32_t kMinScale = std::numeric_limits<int32_t>::lowest(); const int32_t kMaxScale = std::numeric_limits<int32_t>::max(); for (uint32_t i = 0; i < size; ++i) { - node->at<loco::DataType::S32>(i) = - std::min(kMaxScale, std::max(kMinScale, quantized_values[i])); + new_bias->at<loco::DataType::S32>(i) = + std::min(kMaxScale, std::max(kMinScale, quantized_values[i])); } *scaling_factor = scale; *zp = 0; + + return new_bias; } -void quant_bias_per_channel(CircleConst *node, float input_scale, std::vector<float> &weight_scale, - std::vector<float> &scaling_factor, std::vector<int64_t> &zp) +CircleConst *quant_bias_per_channel(CircleConst *node, float input_scale, + std::vector<float> &weight_scale, + std::vector<float> &scaling_factor, std::vector<int64_t> &zp) { float scaling_factor_inv{0}; @@ -305,24 +377,27 @@ void quant_bias_per_channel(CircleConst *node, float input_scale, std::vector<fl scaling_factor[i] = input_scale * weight_scale[i]; scaling_factor_inv = (scaling_factor[i] == 0) ? 0 : 1.0 / scaling_factor[i]; quantized_values[i] = - static_cast<int32_t>(std::round(node->at<loco::DataType::FLOAT32>(i) * scaling_factor_inv)); + static_cast<int32_t>(std::round(node->at<loco::DataType::FLOAT32>(i) * scaling_factor_inv)); zp[i] = 0; } - node->dtype(loco::DataType::S32); // change the type of tensor - node->size<loco::DataType::S32>(size); // resize tensor + auto new_bias = create_empty_const_from<loco::DataType::S32>(node, size); + const int32_t kMinScale = std::numeric_limits<int32_t>::lowest(); const int32_t kMaxScale = std::numeric_limits<int32_t>::max(); for (uint32_t i = 0; i < size; ++i) { - node->at<loco::DataType::S32>(i) = - std::min(kMaxScale, std::max(kMinScale, quantized_values[i])); + new_bias->at<loco::DataType::S32>(i) = + std::min(kMaxScale, std::max(kMinScale, quantized_values[i])); } + + return new_bias; } -void int16_quant_bias_per_channel(CircleConst *node, float input_scale, - std::vector<float> &weight_scale, - std::vector<float> &scaling_factor, std::vector<int64_t> &zp) +CircleConst *int16_quant_bias_per_channel(CircleConst *node, float input_scale, + std::vector<float> &weight_scale, + std::vector<float> &scaling_factor, + std::vector<int64_t> &zp) { float scaling_factor_inv{0}; @@ -334,16 +409,18 @@ void int16_quant_bias_per_channel(CircleConst *node, float input_scale, scaling_factor[i] = input_scale * weight_scale[i]; scaling_factor_inv = (scaling_factor[i] == 0) ? 0 : 1.0 / scaling_factor[i]; quantized_values[i] = - static_cast<int64_t>(std::round(node->at<loco::DataType::FLOAT32>(i) * scaling_factor_inv)); + static_cast<int64_t>(std::round(node->at<loco::DataType::FLOAT32>(i) * scaling_factor_inv)); zp[i] = 0; } - node->dtype(loco::DataType::S64); // change the type of tensor - node->size<loco::DataType::S64>(size); // resize tensor + auto new_bias = create_empty_const_from<loco::DataType::S64>(node, size); + for (uint32_t i = 0; i < size; ++i) { - node->at<loco::DataType::S64>(i) = quantized_values[i]; + new_bias->at<loco::DataType::S64>(i) = quantized_values[i]; } + + return new_bias; } bool has_min_max(const CircleNode *node) @@ -362,42 +439,22 @@ void sym_wquant_per_channel(CircleConst *node, std::vector<float> &scaling_facto uint32_t size = node->size<loco::DataType::FLOAT32>(); std::vector<int32_t> quantized_values(size); - loco::TensorShape dimension; - dimension.rank(4); - uint32_t indices[4] = { - 0, + auto quantize = [&](uint32_t *indices, loco::TensorShape &dimension, int32_t channel_dim_index) { + int channel_idx = indices[channel_dim_index]; + const float scaling_factor_inv = 1.0 / scaling_factor[channel_idx]; + auto data = node->at<loco::DataType::FLOAT32>(cal_offset(dimension, indices)); + quantized_values[cal_offset(dimension, indices)] = + static_cast<int32_t>(std::round(data * scaling_factor_inv)); }; - if (!get_channel_dim_index(node, dimension, channel_dim_index)) - { - assert(false); - return; - } - - for (indices[0] = 0; indices[0] < dimension.dim(0).value(); indices[0]++) - { - for (indices[1] = 0; indices[1] < dimension.dim(1).value(); indices[1]++) - { - for (indices[2] = 0; indices[2] < dimension.dim(2).value(); indices[2]++) - { - for (indices[3] = 0; indices[3] < dimension.dim(3).value(); indices[3]++) - { - int channel_idx = indices[channel_dim_index]; - const float scaling_factor_inv = 1.0 / scaling_factor[channel_idx]; - auto data = node->at<loco::DataType::FLOAT32>(cal_offset(dimension, indices)); - quantized_values[cal_offset(dimension, indices)] = - static_cast<int32_t>(std::round(data * scaling_factor_inv)); - } - } - } - } + iterate_per_channel(node, channel_dim_index, quantize); node->dtype(loco::DataType::S16); // change the type of tensor node->size<loco::DataType::S16>(size); // resize tensor for (uint32_t i = 0; i < size; ++i) { node->at<loco::DataType::S16>(i) = - std::min(kMaxScale, std::max(kMinScale, quantized_values[i])); + std::min(kMaxScale, std::max(kMinScale, quantized_values[i])); } } @@ -412,35 +469,15 @@ void asym_wquant_per_channel(CircleConst *node, std::vector<float> &min, uint32_t size = node->size<loco::DataType::FLOAT32>(); std::vector<int32_t> quantized_values(size); - loco::TensorShape dimension; - dimension.rank(4); - uint32_t indices[4] = { - 0, + auto quantize = [&](uint32_t *indices, loco::TensorShape &dimension, int32_t channel_dim_index) { + int channel_idx = indices[channel_dim_index]; + const float scaling_factor_inv = 1.0 / scaling_factor[channel_idx]; + auto data = node->at<loco::DataType::FLOAT32>(cal_offset(dimension, indices)); + quantized_values[cal_offset(dimension, indices)] = + static_cast<int32_t>(std::round((data - min[channel_idx]) * scaling_factor_inv)); }; - if (!get_channel_dim_index(node, dimension, channel_dim_index)) - { - assert(false); - return; - } - - for (indices[0] = 0; indices[0] < dimension.dim(0).value(); indices[0]++) - { - for (indices[1] = 0; indices[1] < dimension.dim(1).value(); indices[1]++) - { - for (indices[2] = 0; indices[2] < dimension.dim(2).value(); indices[2]++) - { - for (indices[3] = 0; indices[3] < dimension.dim(3).value(); indices[3]++) - { - int channel_idx = indices[channel_dim_index]; - const float scaling_factor_inv = 1.0 / scaling_factor[channel_idx]; - auto data = node->at<loco::DataType::FLOAT32>(cal_offset(dimension, indices)); - quantized_values[cal_offset(dimension, indices)] = - static_cast<int32_t>(std::round((data - min[channel_idx]) * scaling_factor_inv)); - } - } - } - } + iterate_per_channel(node, channel_dim_index, quantize); node->dtype(loco::DataType::U8); // change the type of tensor node->size<loco::DataType::U8>(size); // resize tensor @@ -473,6 +510,21 @@ void asym_wquant_per_layer(CircleConst *node, float min, float scaling_factor) } } +void set_bias(luci::CircleNode *node, luci::CircleConst *bias) +{ + if (auto conv = dynamic_cast<CircleConv2D *>(node)) + conv->bias(bias); + else if (auto dconv = dynamic_cast<CircleDepthwiseConv2D *>(node)) + dconv->bias(bias); + else if (auto tconv = dynamic_cast<CircleTransposeConv *>(node)) + tconv->bias(bias); + else if (auto fc = dynamic_cast<CircleFullyConnected *>(node)) + fc->bias(bias); + else + throw std::runtime_error("Only convolution, depthwise convolution, transposed convolution, and " + "fully-connected layer have bias"); +} + /** * @brief QuantizeActivation quantizes tensors for activations * @details Quantize using recorded min/max values @@ -480,7 +532,7 @@ void asym_wquant_per_layer(CircleConst *node, float min, float scaling_factor) struct QuantizeActivation final : public luci::CircleNodeMutableVisitor<bool> { QuantizeActivation(loco::DataType input, loco::DataType output) - : input_type(input), output_type(output) + : input_type(input), output_type(output) { } @@ -503,8 +555,12 @@ struct QuantizeActivation final : public luci::CircleNodeMutableVisitor<bool> continue; // Check if this is bias (bias is quantized later) - auto iw = get_input_weight_of_bias(circle_node); - if (iw.first != nullptr && iw.second != nullptr) + auto iwo = get_input_weight_output_of_bias(circle_node); + if (iwo.size() > 0) + continue; + + // Check if this is bool type (bool type is not quantized) + if (circle_node->dtype() == loco::DataType::BOOL) continue; // Check if this is activation @@ -547,7 +603,7 @@ struct QuantizeActivation final : public luci::CircleNodeMutableVisitor<bool> struct QuantizeBias final : public luci::CircleNodeMutableVisitor<bool> { QuantizeBias(loco::DataType input, loco::DataType output, QuantizationGranularity gr) - : input_type(input), output_type(output), granularity(gr) + : input_type(input), output_type(output), granularity(gr) { } @@ -562,65 +618,77 @@ struct QuantizeBias final : public luci::CircleNodeMutableVisitor<bool> if (is_quantized(node)) return false; - // Check if this is bias - auto iw = get_input_weight_of_bias(node); - if (iw.first == nullptr || iw.second == nullptr) - return false; - - auto input = loco::must_cast<luci::CircleNode *>(iw.first); - auto weight = loco::must_cast<luci::CircleNode *>(iw.second); + auto iwo_list = get_input_weight_output_of_bias(node); - if (granularity == QuantizationGranularity::ChannelWise) + for (auto iwo : iwo_list) { - assert(input->quantparam()->scale.size() == 1); // input scale's layer-wise - auto input_scale = input->quantparam()->scale[0]; + assert(iwo.size() == 3); - assert(weight->quantparam() != nullptr); // weight scale's channel-wise - auto weight_scale = weight->quantparam()->scale; + auto input = loco::must_cast<luci::CircleNode *>(iwo[0]); + auto weight = loco::must_cast<luci::CircleNode *>(iwo[1]); + auto output = loco::must_cast<luci::CircleNode *>(iwo[2]); - auto circle_const = loco::must_cast<luci::CircleConst *>(node); + auto const_bias = loco::must_cast<luci::CircleConst *>(node); + assert(const_bias->dtype() == loco::DataType::FLOAT32); - uint32_t size = circle_const->size<loco::DataType::FLOAT32>(); - assert(size == weight_scale.size()); - std::vector<float> scaling_factor(size); - std::vector<int64_t> zp(size); + CircleConst *new_bias = nullptr; - if (output_type == loco::DataType::U8) - { - quant_bias_per_channel(circle_const, input_scale, weight_scale, scaling_factor, zp); - } - else if (output_type == loco::DataType::S16) + if (granularity == QuantizationGranularity::ChannelWise) { - int16_quant_bias_per_channel(circle_const, input_scale, weight_scale, scaling_factor, zp); + assert(input->quantparam()->scale.size() == 1); // input scale's layer-wise + auto input_scale = input->quantparam()->scale[0]; + + assert(weight->quantparam() != nullptr); // weight scale's channel-wise + auto weight_scale = weight->quantparam()->scale; + + uint32_t size = const_bias->size<loco::DataType::FLOAT32>(); + assert(size == weight_scale.size()); + std::vector<float> scaling_factor(size); + std::vector<int64_t> zp(size); + + if (output_type == loco::DataType::U8) + { + new_bias = + quant_bias_per_channel(const_bias, input_scale, weight_scale, scaling_factor, zp); + } + else if (output_type == loco::DataType::S16) + { + new_bias = + int16_quant_bias_per_channel(const_bias, input_scale, weight_scale, scaling_factor, zp); + } + else + { + throw std::runtime_error("Unsupported quantization type."); + } + + auto quantparam = std::make_unique<CircleQuantParam>(); + quantparam->scale = scaling_factor; + quantparam->zerop = zp; + assert(new_bias->quantparam() == nullptr); // bias should not be quantized before + new_bias->quantparam(std::move(quantparam)); + + set_bias(output, new_bias); } else { - throw std::runtime_error("Unsupported quantization type."); - } + assert(input->quantparam()->scale.size() == 1); // Only support per-layer quant + auto input_scale = input->quantparam()->scale[0]; - auto quantparam = std::make_unique<CircleQuantParam>(); - quantparam->scale = scaling_factor; - quantparam->zerop = zp; - assert(circle_const->quantparam() == nullptr); // bias should not be quantized before - circle_const->quantparam(std::move(quantparam)); - } - else - { - assert(input->quantparam()->scale.size() == 1); // Only support per-layer quant - auto input_scale = input->quantparam()->scale[0]; - - assert(weight->quantparam()->scale.size() == 1); // Only support per-layer quant - auto weight_scale = weight->quantparam()->scale[0]; - - auto circle_const = loco::must_cast<luci::CircleConst *>(node); - float scaling_factor{0}; - int64_t zp{0}; - asym_quant_bias_per_layer(circle_const, input_scale, weight_scale, &scaling_factor, &zp); - auto quantparam = std::make_unique<CircleQuantParam>(); - quantparam->scale.push_back(scaling_factor); - quantparam->zerop.push_back(zp); - assert(circle_const->quantparam() == nullptr); // bias should not be quantized before - circle_const->quantparam(std::move(quantparam)); + assert(weight->quantparam()->scale.size() == 1); // Only support per-layer quant + auto weight_scale = weight->quantparam()->scale[0]; + + float scaling_factor{0}; + int64_t zp{0}; + new_bias = + asym_quant_bias_per_layer(const_bias, input_scale, weight_scale, &scaling_factor, &zp); + auto quantparam = std::make_unique<CircleQuantParam>(); + quantparam->scale.push_back(scaling_factor); + quantparam->zerop.push_back(zp); + assert(new_bias->quantparam() == nullptr); // bias should not be quantized before + new_bias->quantparam(std::move(quantparam)); + + set_bias(output, new_bias); + } } return false; } @@ -633,7 +701,7 @@ struct QuantizeBias final : public luci::CircleNodeMutableVisitor<bool> struct QuantizeWeights final : public luci::CircleNodeMutableVisitor<bool> { QuantizeWeights(loco::DataType input, loco::DataType output, QuantizationGranularity gr) - : input_type(input), output_type(output), granularity(gr) + : input_type(input), output_type(output), granularity(gr) { } @@ -641,116 +709,179 @@ struct QuantizeWeights final : public luci::CircleNodeMutableVisitor<bool> loco::DataType output_type; QuantizationGranularity granularity; - // Quantize input tensors of each node - bool visit(luci::CircleNode *node) +private: + void quantize_weights(luci::CircleConst *weights) { - LOGGER(l); - INFO(l) << "QuantizeWeights visit node: " << node->name() << std::endl; - auto arity = node->arity(); - for (uint32_t i = 0; i < arity; i++) + // Find min/max per channel-wise + if (granularity == QuantizationGranularity::ChannelWise) { - auto input_node = node->arg(i); - auto circle_node = loco::must_cast<luci::CircleNode *>(input_node); + auto quantparam = weights->quantparam(); + if (quantparam == nullptr) + { + assert(false && "quantparam is nullptr"); + return; + } - // Check if this is already quantized - if (is_quantized(circle_node)) - continue; + auto min = quantparam->min; + auto scaling_factor = quantparam->scale; + int32_t channel_dim_index = 0; - if (is_weights(circle_node)) + if (output_type == loco::DataType::U8) { - auto circle_const = loco::must_cast<luci::CircleConst *>(circle_node); - - // Find min/max per channel-wise - if (granularity == QuantizationGranularity::ChannelWise) - { - auto quantparam = circle_node->quantparam(); - if (quantparam == nullptr) - { - assert(false && "quantparam is nullptr"); - return false; - } - - auto min = quantparam->min; - auto scaling_factor = quantparam->scale; - int32_t channel_dim_index = 0; - - if (output_type == loco::DataType::U8) - { - asym_wquant_per_channel(circle_const, min, scaling_factor, channel_dim_index); - } - else - { - sym_wquant_per_channel(circle_const, scaling_factor, channel_dim_index); - } - quantparam->min.clear(); - quantparam->max.clear(); - quantparam->quantized_dimension = channel_dim_index; - } - // Find min/max per layer-wise - else - { - // Quantize using recorded quantparam - auto quantparam = circle_node->quantparam(); - assert(quantparam != nullptr); - assert(quantparam->min.size() == 1); // only support layer-wise quant - assert(quantparam->scale.size() == 1); // only support layer-wise quant - auto min = quantparam->min[0]; - auto scaling_factor = quantparam->scale[0]; - asym_wquant_per_layer(circle_const, min, scaling_factor); - quantparam->min.clear(); - quantparam->max.clear(); - } + asym_wquant_per_channel(weights, min, scaling_factor, channel_dim_index); + } + else + { + sym_wquant_per_channel(weights, scaling_factor, channel_dim_index); } + quantparam->min.clear(); + quantparam->max.clear(); + quantparam->quantized_dimension = channel_dim_index; + } + // Find min/max per layer-wise + else + { + // Quantize using recorded quantparam + auto quantparam = weights->quantparam(); + assert(quantparam != nullptr); + assert(quantparam->min.size() == 1); // only support layer-wise quant + assert(quantparam->scale.size() == 1); // only support layer-wise quant + auto min = quantparam->min[0]; + auto scaling_factor = quantparam->scale[0]; + asym_wquant_per_layer(weights, min, scaling_factor); + quantparam->min.clear(); + quantparam->max.clear(); } - return false; } -}; -void quant_instnorm(luci::CircleInstanceNorm *node, loco::DataType output_type, - QuantizationGranularity granularity) -{ - auto gamma = loco::must_cast<luci::CircleConst *>(node->gamma()); - auto beta = loco::must_cast<luci::CircleConst *>(node->beta()); - assert(gamma->dtype() == loco::DataType::FLOAT32); - assert(beta->dtype() == loco::DataType::FLOAT32); + bool visit(luci::CircleConv2D *node) + { + LOGGER(l); + INFO(l) << "QuantizeWeights visit node: " << node->name() << std::endl; - if (granularity == QuantizationGranularity::LayerWise) + auto weights = loco::must_cast<luci::CircleConst *>(node->filter()); + if (!is_quantized(weights)) + { + auto new_weights = luci::clone(weights); + node->filter(new_weights); + quantize_weights(new_weights); + return true; + } + return false; + } + + bool visit(luci::CircleDepthwiseConv2D *node) { - quant_const(gamma, output_type); - quant_const(beta, output_type); + LOGGER(l); + INFO(l) << "QuantizeWeights visit node: " << node->name() << std::endl; + + auto weights = loco::must_cast<luci::CircleConst *>(node->filter()); + if (!is_quantized(weights)) + { + auto new_weights = luci::clone(weights); + node->filter(new_weights); + quantize_weights(new_weights); + return true; + } + return false; } - else if (granularity == QuantizationGranularity::ChannelWise) + + bool visit(luci::CircleInstanceNorm *node) { - quant_const_per_channel(gamma, output_type); - quant_const_per_channel(beta, output_type); + LOGGER(l); + INFO(l) << "QuantizeWeights visit node: " << node->name() << std::endl; + + auto gamma = loco::must_cast<luci::CircleConst *>(node->gamma()); + auto beta = loco::must_cast<luci::CircleConst *>(node->beta()); + + bool changed = false; + if (!is_quantized(gamma)) + { + assert(gamma->dtype() == loco::DataType::FLOAT32); + auto new_gamma = luci::clone(gamma); + if (granularity == QuantizationGranularity::LayerWise) + quant_const(new_gamma, output_type); + else if (granularity == QuantizationGranularity::ChannelWise) + quant_const_per_channel(new_gamma, output_type); + node->gamma(new_gamma); + changed = true; + } + if (!is_quantized(beta)) + { + assert(beta->dtype() == loco::DataType::FLOAT32); + auto new_beta = luci::clone(beta); + if (granularity == QuantizationGranularity::LayerWise) + quant_const(new_beta, output_type); + else if (granularity == QuantizationGranularity::ChannelWise) + quant_const_per_channel(new_beta, output_type); + node->beta(new_beta); + changed = true; + } + + return changed; } - else - throw std::runtime_error("Quantization granularity must be either 'layer' or 'channel'"); -} -void quant_prelu(luci::CirclePRelu *node, loco::DataType output_type, - QuantizationGranularity granularity) -{ - auto alpha = loco::must_cast<luci::CircleConst *>(node->alpha()); - assert(alpha->dtype() == loco::DataType::FLOAT32); + bool visit(luci::CirclePRelu *node) + { + LOGGER(l); + INFO(l) << "QuantizeWeights visit node: " << node->name() << std::endl; + + auto alpha = loco::must_cast<luci::CircleConst *>(node->alpha()); + + if (!is_quantized(alpha)) + { + assert(alpha->dtype() == loco::DataType::FLOAT32); + auto new_alpha = luci::clone(alpha); + if (granularity == QuantizationGranularity::LayerWise) + quant_const(new_alpha, output_type); + else if (granularity == QuantizationGranularity::ChannelWise) + quant_const_per_channel(new_alpha, output_type); + node->alpha(new_alpha); + return true; + } - if (granularity == QuantizationGranularity::LayerWise) + return false; + } + + bool visit(luci::CircleTransposeConv *node) { - quant_const(alpha, output_type); + LOGGER(l); + INFO(l) << "QuantizeWeights visit node: " << node->name() << std::endl; + + auto weights = loco::must_cast<luci::CircleConst *>(node->filter()); + if (!is_quantized(weights)) + { + auto new_weights = luci::clone(weights); + node->filter(new_weights); + quantize_weights(new_weights); + return true; + } + return false; } - else if (granularity == QuantizationGranularity::ChannelWise) + + bool visit(luci::CircleFullyConnected *node) { - quant_const_per_channel(alpha, output_type); + LOGGER(l); + INFO(l) << "QuantizeWeights visit node: " << node->name() << std::endl; + + auto weights = loco::must_cast<luci::CircleConst *>(node->weights()); + if (!is_quantized(weights)) + { + auto new_weights = luci::clone(weights); + node->weights(new_weights); + quantize_weights(new_weights); + return true; + } + return false; } - else - throw std::runtime_error("Quantization granularity must be either 'layer' or 'channel'"); -} + + bool visit(luci::CircleNode *) { return false; } +}; /** * @brief Quantize const input tensors using min/max of const values */ -void quantize_const_inputs(luci::CircleNode *node, loco::DataType output_type, - QuantizationGranularity granularity) +void quantize_const_inputs(luci::CircleNode *node, loco::DataType output_type) { auto opcode = node->opcode(); auto arity = node->arity(); @@ -763,6 +894,8 @@ void quantize_const_inputs(luci::CircleNode *node, loco::DataType output_type, case luci::CircleOpcode::CONV_2D: case luci::CircleOpcode::DEPTHWISE_CONV_2D: case luci::CircleOpcode::FULLY_CONNECTED: + case luci::CircleOpcode::INSTANCE_NORM: + case luci::CircleOpcode::PRELU: case luci::CircleOpcode::TRANSPOSE_CONV: // Handled in QuantizeWeights and QuantizeBias break; @@ -771,8 +904,13 @@ void quantize_const_inputs(luci::CircleNode *node, loco::DataType output_type, // Handled in propagate_concat_quantparam break; + case luci::CircleOpcode::LOGICAL_OR: + // Inputs of logical Ops are bool, thus not quantized + break; + case luci::CircleOpcode::ARG_MAX: case luci::CircleOpcode::ARG_MIN: + case luci::CircleOpcode::BATCH_TO_SPACE_ND: case luci::CircleOpcode::MEAN: case luci::CircleOpcode::PAD: case luci::CircleOpcode::REDUCE_ANY: @@ -783,6 +921,9 @@ void quantize_const_inputs(luci::CircleNode *node, loco::DataType output_type, case luci::CircleOpcode::RESIZE_BILINEAR: case luci::CircleOpcode::RESIZE_NEAREST_NEIGHBOR: case luci::CircleOpcode::REVERSE_SEQUENCE: + case luci::CircleOpcode::SLICE: + case luci::CircleOpcode::SPACE_TO_BATCH_ND: + case luci::CircleOpcode::STRIDED_SLICE: case luci::CircleOpcode::SUM: case luci::CircleOpcode::TILE: case luci::CircleOpcode::TOPK_V2: @@ -791,41 +932,53 @@ void quantize_const_inputs(luci::CircleNode *node, loco::DataType output_type, // Ex: axis, paddings input_node = node->arg(0); const_node = dynamic_cast<luci::CircleConst *>(input_node); - if (const_node != nullptr) + if (const_node != nullptr && !is_quantized(const_node)) quant_const(const_node, output_type); break; - case luci::CircleOpcode::INSTANCE_NORM: - quant_instnorm(loco::must_cast<luci::CircleInstanceNorm *>(node), output_type, granularity); - break; - - case luci::CircleOpcode::PRELU: - quant_prelu(loco::must_cast<luci::CirclePRelu *>(node), output_type, granularity); - break; - case luci::CircleOpcode::ADD: case luci::CircleOpcode::ADD_N: + case luci::CircleOpcode::DEPTH_TO_SPACE: case luci::CircleOpcode::DIV: + case luci::CircleOpcode::ELU: case luci::CircleOpcode::EQUAL: + case luci::CircleOpcode::FLOOR: + case luci::CircleOpcode::FLOOR_DIV: case luci::CircleOpcode::GREATER: case luci::CircleOpcode::GREATER_EQUAL: case luci::CircleOpcode::LESS: case luci::CircleOpcode::LESS_EQUAL: + case luci::CircleOpcode::LOGISTIC: case luci::CircleOpcode::MAXIMUM: case luci::CircleOpcode::MINIMUM: case luci::CircleOpcode::MUL: case luci::CircleOpcode::NOT_EQUAL: + case luci::CircleOpcode::POW: + case luci::CircleOpcode::RSQRT: + case luci::CircleOpcode::SOFTMAX: + case luci::CircleOpcode::SPACE_TO_DEPTH: + case luci::CircleOpcode::SQRT: case luci::CircleOpcode::SUB: + case luci::CircleOpcode::TANH: // Quantize all const inputs using their values for (uint32_t i = 0; i < arity; i++) { input_node = node->arg(i); const_node = dynamic_cast<luci::CircleConst *>(input_node); - if (const_node != nullptr) + if (const_node != nullptr && !is_quantized(const_node)) quant_const(const_node, output_type); } break; + case luci::CircleOpcode::SPLIT: + // Only the second input is quantized + // First input should not be quantized (e.g., split_dim) + input_node = node->arg(1); + const_node = dynamic_cast<luci::CircleConst *>(input_node); + if (const_node != nullptr && !is_quantized(const_node)) + quant_const(const_node, output_type); + break; + default: for (uint32_t i = 0; i < arity; i++) { @@ -850,8 +1003,8 @@ void quantize_const_inputs(luci::CircleNode *node, loco::DataType output_type, * (U8 qparam2) * * AFTER - * [CircleNode] [CircleConst] - * (U8 qparam2) (U8 qparam2) + * [CircleNode] [CircleConst] [CircleConst] <- Dead node + * (U8 qparam2) (U8 qparam2) (FP32) * \ / * \ / * [CircleConcatenation] @@ -871,7 +1024,11 @@ void propagate_concat_quantparam(luci::CircleConcatenation *concat, loco::DataTy auto node = concat->arg(i); auto const_node = dynamic_cast<luci::CircleConst *>(node); if (const_node != nullptr) - quant_const(const_node, quant_type); + { + auto new_const = luci::clone(const_node); + quant_const(new_const, quant_type); + concat->values(i, new_const); + } } return; } @@ -884,20 +1041,6 @@ void propagate_concat_quantparam(luci::CircleConcatenation *concat, loco::DataTy if (node->opcode() == luci::CircleOpcode::CONCATENATION) continue; - // Skip if this input is used by other Ops - auto succs = loco::succs(node); - if (succs.size() != 1) - { - if (node->opcode() == luci::CircleOpcode::CIRCLECONST) - { - luci::CircleConst *const_node = loco::must_cast<luci::CircleConst *>(node); - quant_const(const_node, quant_type); - } - continue; - } - - assert(succs.find(concat) != succs.end()); - // Quantize constant values if (node->opcode() == luci::CircleOpcode::CIRCLECONST) { @@ -913,15 +1056,21 @@ void propagate_concat_quantparam(luci::CircleConcatenation *concat, loco::DataTy const auto scaling_factor = concat_qparam->scale[0]; const auto zerop = concat_qparam->zerop[0]; - quant_const_values(const_node, scaling_factor, zerop, quant_type); + auto new_const = luci::clone(const_node); + quant_const_values(new_const, scaling_factor, zerop, quant_type); + concat->values(i, new_const); + overwrite_quantparam(concat, new_const); } else { + const auto succs = loco::succs(node); + if (succs.size() > 1) + continue; + // Non-const input must have been quantized assert(node->quantparam() != nullptr); + overwrite_quantparam(concat, node); } - - overwrite_quantparam(concat, node); } } @@ -954,13 +1103,6 @@ bool QuantizeWithMinMaxPass::run(loco::Graph *g) circle_node->accept(&qb); } - // Quantize const inputs other than weights and bias - for (auto node : loco::active_nodes(loco::output_nodes(g))) - { - auto circle_node = loco::must_cast<luci::CircleNode *>(node); - quantize_const_inputs(circle_node, _output_dtype, _granularity); - } - // Propagate quantization parameters of concat Op for (auto node : loco::active_nodes(loco::output_nodes(g))) { @@ -976,6 +1118,13 @@ bool QuantizeWithMinMaxPass::run(loco::Graph *g) propagate_concat_quantparam(concat, _output_dtype); } + // Quantize const inputs other than weights and bias + for (auto node : loco::active_nodes(loco::output_nodes(g))) + { + auto circle_node = loco::must_cast<luci::CircleNode *>(node); + quantize_const_inputs(circle_node, _output_dtype); + } + // Update output dtype auto graph_outputs = g->outputs(); for (auto node : loco::output_nodes(g)) diff --git a/compiler/luci/pass/src/QuantizeWithMinMaxPass.test.cpp b/compiler/luci/pass/src/QuantizeWithMinMaxPass.test.cpp new file mode 100644 index 000000000..75ec0cfd8 --- /dev/null +++ b/compiler/luci/pass/src/QuantizeWithMinMaxPass.test.cpp @@ -0,0 +1,27 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Pass/QuantizeWithMinMaxPass.h" + +#include <gtest/gtest.h> + +TEST(QuantizeWithMinMaxPassTest, name) +{ + luci::QuantizeWithMinMaxPass pass(loco::DataType::FLOAT32, loco::DataType::U8, + luci::QuantizationGranularity::LayerWise); + auto const name = pass.name(); + ASSERT_NE(nullptr, name); +} diff --git a/compiler/luci/pass/src/QuantizedModelVerifier.cpp b/compiler/luci/pass/src/QuantizedModelVerifier.cpp new file mode 100644 index 000000000..5ea803cc9 --- /dev/null +++ b/compiler/luci/pass/src/QuantizedModelVerifier.cpp @@ -0,0 +1,71 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "QuantizedModelVerifier.h" + +#include "VerifyQuantizedNodeLayerWiseGranularity.h" +#include "VerifyQuantizedNodeChannelWiseGranularity.h" +#include "VerifyQuantizedNodeU8Type.h" +#include "VerifyQuantizedNodeS16Type.h" + +#include <luci/IR/CircleNodes.h> +#include <luci/IR/CircleNodeVisitor.h> + +namespace luci +{ + +void QuantizedModelVerifier::verify(loco::Graph *g) +{ + if (_quantized_dtype != Type::U8 && _quantized_dtype != Type::S16) + throw std::runtime_error("Unsupported quantized dtype"); + + if (_granularity != Granularity::ChannelWise && _granularity != Granularity::LayerWise) + throw std::runtime_error("Unsupported granularity"); + + for (auto node : loco::active_nodes(loco::output_nodes(g))) + { + auto circle_node = loco::must_cast<luci::CircleNode *>(node); + + // Verify Type + if (_quantized_dtype == Type::U8) + { + VerifyQuantizedNodeU8Type vt; + if (!circle_node->accept(&vt)) + throw std::runtime_error("Wrong data type"); + } + else if (_quantized_dtype == Type::S16) + { + VerifyQuantizedNodeS16Type vt; + if (!circle_node->accept(&vt)) + throw std::runtime_error("Wrong data type"); + } + + // Verify Granularity + if (_granularity == Granularity::LayerWise) + { + VerifyQuantizedNodeLayerWiseGranularity vg; + if (!circle_node->accept(&vg)) + throw std::runtime_error("Wrong granularity"); + } + else if (_granularity == Granularity::ChannelWise) + { + VerifyQuantizedNodeChannelWiseGranularity vg; + if (!circle_node->accept(&vg)) + throw std::runtime_error("Wrong granularity"); + } + } +} + +} // namespace luci diff --git a/compiler/luci/pass/src/QuantizedModelVerifier.h b/compiler/luci/pass/src/QuantizedModelVerifier.h new file mode 100644 index 000000000..d5fbb8e74 --- /dev/null +++ b/compiler/luci/pass/src/QuantizedModelVerifier.h @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_QUANTIZED_MODEL_VERIFIER_H__ +#define __LUCI_QUANTIZED_MODEL_VERIFIER_H__ + +#include "luci/Pass/QuantizationParameters.h" + +#include <loco.h> + +namespace luci +{ + +/** + * @brief Class to verify quantized model + * + * TODO Move this to luci/service + */ +struct QuantizedModelVerifier +{ + +public: + QuantizedModelVerifier(loco::DataType quantized_dtype, QuantizationGranularity granularity) + : _quantized_dtype(quantized_dtype), _granularity(granularity) + { + } + + void verify(loco::Graph *g); + +private: + loco::DataType _quantized_dtype; + QuantizationGranularity _granularity; +}; + +} // namespace luci + +#endif // __LUCI_QUANTIZED_MODEL_VERIFIER_H__ diff --git a/compiler/luci/pass/src/QuantizedModelVerifier.test.cpp b/compiler/luci/pass/src/QuantizedModelVerifier.test.cpp new file mode 100644 index 000000000..eae1b0c1f --- /dev/null +++ b/compiler/luci/pass/src/QuantizedModelVerifier.test.cpp @@ -0,0 +1,1668 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "QuantizedModelVerifier.h" + +#include "luci/Pass/QuantizeWithMinMaxPass.h" + +#include <luci/test/TestIOGraph.h> + +#include <gtest/gtest.h> + +using Type = loco::DataType; +using Granularity = luci::QuantizationGranularity; + +namespace +{ + +/** + * @brief A helper function to create dummy const node + */ +template <Type T> luci::CircleConst *create_dummy_const(loco::Graph *g, luci::test::ShapeU32 shape) +{ + auto node = g->nodes()->create<luci::CircleConst>(); + { + node->dtype(T); + node->shape(shape); + node->size<T>(luci::test::num_elements(shape)); + + for (int32_t i = 0; i < luci::test::num_elements(shape); i++) + { + // DESIGN NOTE + // + // Filling with any random numbers are fine + // Q. Should it include minus numbers? + switch (T) + { + case Type::FLOAT32: + // Fill with index + node->at<T>(i) = static_cast<float>(i); + break; + case Type::BOOL: + // Fill by flip + node->at<T>(i) = (i % 2) ? true : false; + break; + case Type::U8: + // Fill with index + node->at<T>(i) = static_cast<uint8_t>(i); + break; + case Type::S16: + // Fill with index + node->at<T>(i) = static_cast<int16_t>(i); + break; + } + } + } + + return node; +} + +/** + * @brief A helper function to create const node with value + */ +template <Type DT, typename T> +luci::CircleConst *create_const(loco::Graph *g, luci::test::ShapeU32 shape, + std::initializer_list<T> values) +{ + auto node = g->nodes()->create<luci::CircleConst>(); + { + node->dtype(DT); + node->shape(shape); + node->size<DT>(luci::test::num_elements(shape)); + + assert(values.size() == node->size<DT>()); + + uint32_t index = 0; + for (auto val : values) + { + node->at<DT>(index++) = static_cast<T>(val); + } + } + + return node; +} + +void insert_scale_zp(luci::CircleNode *node, float scale, int64_t zp) +{ + auto qparam = node->quantparam(); + assert(qparam != nullptr); // FIX_CALLER_UNLESS + qparam->scale.push_back(scale); + qparam->zerop.push_back(zp); +} + +void quantize_and_verify(loco::Graph *g, Type quantized_dtype, Granularity granularity) +{ + luci::QuantizeWithMinMaxPass pass(Type::FLOAT32, quantized_dtype, granularity); + pass.run(g); + + luci::QuantizedModelVerifier verifier(quantized_dtype, granularity); + verifier.verify(g); +} + +// Helper function to reduce duplicate test codes +// Assumption: g->output()->from() is the target node +void quantize_and_verify_with_wrong_type(luci::test::TestIOGraph *g, Type quantized_dtype, + Granularity granularity, Type wrong_dtype) +{ + luci::QuantizeWithMinMaxPass pass(Type::FLOAT32, quantized_dtype, granularity); + pass.run(g->g()); + + auto node = loco::must_cast<luci::CircleNode *>(g->output()->from()); + node->dtype(wrong_dtype); + + luci::QuantizedModelVerifier verifier(quantized_dtype, granularity); + verifier.verify(g->g()); +} + +// Helper function to reduce duplicate test codes +// Assumption: g->output()->from() is the target node +void quantize_and_verify_with_wrong_granularity(luci::test::TestIOGraph *g, Type quantized_dtype, + Granularity granularity) +{ + luci::QuantizeWithMinMaxPass pass(Type::FLOAT32, quantized_dtype, granularity); + pass.run(g->g()); + + auto node = loco::must_cast<luci::CircleNode *>(g->output()->from()); + insert_scale_zp(node, 1.0, 1); + + luci::QuantizedModelVerifier verifier(quantized_dtype, granularity); + verifier.verify(g->g()); +} + +// Helper function to reduce duplicate test codes +void quantize_and_verify_with_wrong_granularity(luci::test::TestIOGraph *g, Type quantized_dtype, + Granularity granularity, luci::CircleNode *target) +{ + luci::QuantizeWithMinMaxPass pass(Type::FLOAT32, quantized_dtype, granularity); + pass.run(g->g()); + + insert_scale_zp(target, 1.0, 1); + + luci::QuantizedModelVerifier verifier(quantized_dtype, granularity); + verifier.verify(g->g()); +} + +// Set min/max for all non-const nodes in the graph +void set_minmax_to_non_const(loco::Graph *g, float min, float max) +{ + for (auto node : loco::all_nodes(g)) + { + auto const_node = dynamic_cast<luci::CircleConst *>(node); + if (const_node != nullptr) + continue; + + // Min/Max is not recorded for ArgMax + // See MinMaxObserver.cpp in record_minmax module + auto argmax_node = dynamic_cast<luci::CircleArgMax *>(node); + if (argmax_node != nullptr) + continue; + + // Min/Max is not recorded for Split + // See MinMaxObserver.cpp in record_minmax module + auto split_node = dynamic_cast<luci::CircleSplit *>(node); + if (split_node != nullptr) + continue; + + auto circle_node = loco::must_cast<luci::CircleNode *>(node); + auto qparam = std::make_unique<luci::CircleQuantParam>(); + { + qparam->min.emplace_back(min); + qparam->max.emplace_back(max); + } + circle_node->quantparam(std::move(qparam)); + } +} + +/** + * @brief Simple Test Graph + * @note + * The simple test graph's nodes are initialized with + * simple shapes and values. + */ +class SimpleTestGraph : public luci::test::TestIOGraph +{ +public: + virtual void init(void) = 0; +}; + +class InstanceNormTestGraph final : public SimpleTestGraph +{ +public: + void init(void) override + { + TestIOGraph::init({32}, {32}); + _gamma = create_dummy_const<Type::FLOAT32>(g(), {32}); + _beta = create_dummy_const<Type::FLOAT32>(g(), {32}); + _instnorm = g()->nodes()->create<luci::CircleInstanceNorm>(); + { + _instnorm->input(input()); + _instnorm->gamma(_gamma); + _instnorm->beta(_beta); + } + output()->from(_instnorm); + + set_minmax_to_non_const(g(), -1, 1); + } + +public: + loco::Node *gamma(void) const { return _instnorm->gamma(); } + loco::Node *beta(void) const { return _instnorm->beta(); } + +public: + luci::CircleInstanceNorm *_instnorm = nullptr; + luci::CircleConst *_input = nullptr; + luci::CircleConst *_gamma = nullptr; + luci::CircleConst *_beta = nullptr; +}; + +class LogisticTestGraph final : public SimpleTestGraph +{ +public: + void init(void) override + { + TestIOGraph::init({32}, {32}); + _logistic = g()->nodes()->create<luci::CircleLogistic>(); + { + _logistic->x(input()); + } + output()->from(_logistic); + + set_minmax_to_non_const(g(), -1, 1); + } + +public: + luci::CircleLogistic *_logistic = nullptr; +}; + +class SoftmaxTestGraph final : public SimpleTestGraph +{ +public: + void init(void) override + { + TestIOGraph::init({32}, {32}); + _softmax = g()->nodes()->create<luci::CircleSoftmax>(); + { + _softmax->logits(input()); + _softmax->beta(0.1); + } + output()->from(_softmax); + + set_minmax_to_non_const(g(), -1, 1); + } + +public: + luci::CircleSoftmax *_softmax = nullptr; +}; + +class SpaceToBatchNDTestGraph final : public SimpleTestGraph +{ +public: + void init(void) override + { + TestIOGraph::init({1, 2, 2, 1}, {4, 1, 1, 1}); + _block_shape = create_dummy_const<Type::S32>(g(), {2}); + for (uint32_t i = 0; i < 2; i++) + _block_shape->at<Type::S32>(i) = 2; + + _paddings = create_dummy_const<Type::S32>(g(), {2, 2}); + for (uint32_t i = 0; i < 4; i++) + _paddings->at<Type::S32>(i) = 0; + + _stob = g()->nodes()->create<luci::CircleSpaceToBatchND>(); + { + _stob->input(input()); + _stob->block_shape(_block_shape); + _stob->paddings(_paddings); + } + output()->from(_stob); + + set_minmax_to_non_const(g(), -1, 1); + } + +public: + luci::CircleSpaceToBatchND *_stob = nullptr; + luci::CircleConst *_block_shape = nullptr; + luci::CircleConst *_paddings = nullptr; +}; + +class SpaceToDepthTestGraph final : public SimpleTestGraph +{ +public: + void init(void) override + { + TestIOGraph::init({1, 2, 2, 1}, {1, 1, 1, 4}); + _stod = g()->nodes()->create<luci::CircleSpaceToDepth>(); + { + _stod->input(input()); + _stod->block_size(2); + } + output()->from(_stod); + + set_minmax_to_non_const(g(), -1, 1); + } + +public: + luci::CircleSpaceToDepth *_stod = nullptr; +}; + +template <Type indexT> class SliceTestGraph final : public SimpleTestGraph +{ +public: + void init(void) override + { + TestIOGraph::init({32}, {32}); + _begin = g()->nodes()->create<luci::CircleConst>(); + { + _begin->dtype(indexT); + } + _size = g()->nodes()->create<luci::CircleConst>(); + { + _size->dtype(indexT); + } + _slice = g()->nodes()->create<luci::CircleSlice>(); + { + _slice->input(input()); + _slice->begin(_begin); + _slice->size(_size); + } + output()->from(_slice); + + set_minmax_to_non_const(g(), -1, 1); + } + +public: + luci::CircleSlice *_slice = nullptr; + luci::CircleConst *_begin = nullptr; + luci::CircleConst *_size = nullptr; +}; + +class SplitTestGraph final : public luci::test::TestIOGraph +{ +public: + void init(void) + { + TestIOGraph::init({1, 32}, {32}); + _split_dim = create_dummy_const<Type::S32>(g(), {1}); + _split = g()->nodes()->create<luci::CircleSplit>(); + { + _split->input(input()); + _split->split_dim(_split_dim); + } + _split_o1 = g()->nodes()->create<luci::CircleSplitOut>(); + { + _split_o1->input(_split); + _split_o1->index(0); + } + + output()->from(_split_o1); + + set_minmax_to_non_const(g(), -1, 1); + } + +public: + luci::CircleSplit *_split = nullptr; + luci::CircleSplitOut *_split_o1 = nullptr; + luci::CircleConst *_split_dim = nullptr; +}; + +class StridedSliceTestGraph final : public SimpleTestGraph +{ +public: + void init(void) override + { + TestIOGraph::init({32}, {32}); + _begin = g()->nodes()->create<luci::CircleConst>(); + { + _begin->dtype(Type::S32); + } + _end = g()->nodes()->create<luci::CircleConst>(); + { + _end->dtype(Type::S32); + } + _strides = g()->nodes()->create<luci::CircleConst>(); + { + _strides->dtype(Type::S32); + } + _slice = g()->nodes()->create<luci::CircleStridedSlice>(); + { + _slice->input(input()); + _slice->begin(_begin); + _slice->end(_end); + _slice->strides(_strides); + } + output()->from(_slice); + + set_minmax_to_non_const(g(), -1, 1); + } + +public: + luci::CircleStridedSlice *_slice = nullptr; + luci::CircleConst *_begin = nullptr; + luci::CircleConst *_end = nullptr; + luci::CircleConst *_strides = nullptr; +}; + +class ReshapeTestGraph final : public SimpleTestGraph +{ +public: + void init(void) override + { + TestIOGraph::init({32}, {32}); + _shape = g()->nodes()->create<luci::CircleConst>(); + { + _shape->dtype(Type::S32); + } + _reshape = g()->nodes()->create<luci::CircleReshape>(); + { + _reshape->tensor(input()); + _reshape->shape(_shape); + } + output()->from(_reshape); + + set_minmax_to_non_const(g(), -1, 1); + } + +public: + luci::CircleReshape *_reshape = nullptr; + luci::CircleConst *_shape = nullptr; +}; + +class TanhTestGraph final : public SimpleTestGraph +{ +public: + void init(void) override + { + TestIOGraph::init({32}, {32}); + _tanh = g()->nodes()->create<luci::CircleTanh>(); + { + _tanh->x(input()); + } + output()->from(_tanh); + + set_minmax_to_non_const(g(), -1, 1); + } + +public: + luci::CircleTanh *_tanh = nullptr; +}; + +class FloorTestGraph final : public SimpleTestGraph +{ +public: + void init(void) override + { + TestIOGraph::init({32}, {32}); + _floor = g()->nodes()->create<luci::CircleFloor>(); + { + _floor->x(input()); + } + output()->from(_floor); + + set_minmax_to_non_const(g(), -1, 1); + } + +public: + luci::CircleFloor *_floor = nullptr; +}; + +template <Type indexT> class ArgMaxTestGraph final : public SimpleTestGraph +{ +public: + void init(void) override + { + TestIOGraph::init({32}, {1}); + // output dtype is float by default, but ArgMax should have indexType (s32/s64) + output()->dtype(indexT); + _dimension = g()->nodes()->create<luci::CircleConst>(); + { + _dimension->dtype(indexT); + } + _argmax = g()->nodes()->create<luci::CircleArgMax>(); + { + _argmax->input(input()); + _argmax->dimension(_dimension); + _argmax->output_type(indexT); + _argmax->dtype(indexT); + } + output()->from(_argmax); + + set_minmax_to_non_const(g(), -1, 1); + } + +public: + luci::CircleArgMax *_argmax = nullptr; + luci::CircleConst *_dimension = nullptr; +}; + +class BatchToSpaceNDTestGraph final : public SimpleTestGraph +{ +public: + void init(void) override + { + TestIOGraph::init({32}, {32}); + _block_shape = g()->nodes()->create<luci::CircleConst>(); + { + _block_shape->dtype(Type::S32); + } + _crops = g()->nodes()->create<luci::CircleConst>(); + { + _crops->dtype(Type::S32); + } + _btos = g()->nodes()->create<luci::CircleBatchToSpaceND>(); + { + _btos->input(input()); + _btos->block_shape(_block_shape); + _btos->crops(_crops); + } + output()->from(_btos); + + set_minmax_to_non_const(g(), -1, 1); + } + +public: + luci::CircleBatchToSpaceND *_btos = nullptr; + luci::CircleConst *_block_shape = nullptr; + luci::CircleConst *_crops = nullptr; +}; + +class DepthToSpaceTestGraph final : public SimpleTestGraph +{ +public: + void init(void) override + { + TestIOGraph::init({1, 1, 1, 4}, {1, 2, 2, 1}); + _dtos = g()->nodes()->create<luci::CircleDepthToSpace>(); + { + _dtos->input(input()); + _dtos->block_size(2); + } + output()->from(_dtos); + + set_minmax_to_non_const(g(), -1, 1); + } + +public: + luci::CircleDepthToSpace *_dtos = nullptr; +}; + +class PadTestGraph final : public SimpleTestGraph +{ +public: + void init(void) override + { + TestIOGraph::init({32}, {32}); + _paddings = g()->nodes()->create<luci::CircleConst>(); + { + _paddings->dtype(Type::S32); + } + _pad = g()->nodes()->create<luci::CirclePad>(); + { + _pad->input(input()); + _pad->paddings(_paddings); + } + output()->from(_pad); + + set_minmax_to_non_const(g(), -1, 1); + } + +public: + luci::CirclePad *_pad = nullptr; + luci::CircleConst *_paddings = nullptr; +}; + +class TransposeTestGraph final : public SimpleTestGraph +{ +public: + void init(void) override + { + TestIOGraph::init({32}, {32}); + _perm = g()->nodes()->create<luci::CircleConst>(); + { + _perm->dtype(Type::S32); + } + _transpose = g()->nodes()->create<luci::CircleTranspose>(); + { + _transpose->a(input()); + _transpose->perm(_perm); + } + output()->from(_transpose); + + set_minmax_to_non_const(g(), -1, 1); + } + +public: + luci::CircleTranspose *_transpose = nullptr; + luci::CircleConst *_perm = nullptr; +}; + +class ConcatenationTestGraph final : public SimpleTestGraph +{ +public: + void init(void) override + { + TestIOGraph::init({16}, {32}); + _param = create_dummy_const<Type::FLOAT32>(g(), {16}); + _concat = g()->nodes()->create<luci::CircleConcatenation>(2); + { + _concat->values(0, input()); + _concat->values(1, _param); + _concat->axis(0); + } + output()->from(_concat); + + set_minmax_to_non_const(g(), -1, 1); + } + +public: + luci::CircleConcatenation *_concat = nullptr; + luci::CircleConst *_param = nullptr; +}; + +// Test graph for comparison Ops +// GREATER, GREATER_EQUAL, LESS, LESS_EQUAL, EQUAL, NOT_EQUAL +template <class Op> class ComparisonOpTestGraph final : public SimpleTestGraph +{ +public: + void init(void) override + { + TestIOGraph::init({32}, {32}); + output()->dtype(loco::DataType::BOOL); + _y = create_dummy_const<Type::FLOAT32>(g(), {32}); + _op = g()->nodes()->create<Op>(); + { + _op->x(input()); + _op->y(_y); + _op->dtype(loco::DataType::BOOL); + } + output()->from(_op); + + set_minmax_to_non_const(g(), -1, 1); + } + + loco::Node *x(void) const { return _op->x(); } + loco::Node *y(void) const { return _op->y(); } + +public: + Op *_op = nullptr; + luci::CircleConst *_y = nullptr; +}; + +// Test graph for binary logical Ops +// LOGICAL_OR, LOGICAL_AND +template <class Op> class BinaryLogicalOpTestGraph final : public SimpleTestGraph +{ +public: + void init(void) override + { + TestIOGraph::init({32}, {32}); + input()->dtype(loco::DataType::BOOL); + output()->dtype(loco::DataType::BOOL); + _y = create_dummy_const<Type::BOOL>(g(), {32}); + _op = g()->nodes()->create<Op>(); + { + _op->x(input()); + _op->y(_y); + _op->dtype(loco::DataType::BOOL); + } + output()->from(_op); + + set_minmax_to_non_const(g(), -1, 1); + } + + loco::Node *x(void) const { return _op->x(); } + loco::Node *y(void) const { return _op->y(); } + +public: + Op *_op = nullptr; + luci::CircleConst *_y = nullptr; +}; + +class DivTestGraph final : public SimpleTestGraph +{ +public: + void init(void) override + { + TestIOGraph::init({32}, {32}); + + _const = create_dummy_const<Type::FLOAT32>(g(), {32}); + _div = g()->nodes()->create<luci::CircleDiv>(); + { + _div->x(input()); + _div->y(_const); + } + output()->from(_div); + + set_minmax_to_non_const(g(), -1, 1); + } + + loco::Node *x() { return _div->x(); } + + loco::Node *y() { return _div->y(); } + +private: + luci::CircleDiv *_div = nullptr; + luci::CircleConst *_const = nullptr; +}; + +class FloorDivTestGraph final : public SimpleTestGraph +{ +public: + void init(void) override + { + TestIOGraph::init({32}, {32}); + + _const = create_dummy_const<Type::FLOAT32>(g(), {32}); + _floor_div = g()->nodes()->create<luci::CircleFloorDiv>(); + { + _floor_div->x(input()); + _floor_div->y(_const); + } + output()->from(_floor_div); + + set_minmax_to_non_const(g(), -1, 1); + } + + loco::Node *x() { return _floor_div->x(); } + + loco::Node *y() { return _floor_div->y(); } + +private: + luci::CircleFloorDiv *_floor_div = nullptr; + luci::CircleConst *_const = nullptr; +}; + +class RsqrtTestGraph final : public SimpleTestGraph +{ +public: + void init(void) override + { + TestIOGraph::init({32}, {32}); + _rsqrt = g()->nodes()->create<luci::CircleRsqrt>(); + { + _rsqrt->x(input()); + } + output()->from(_rsqrt); + + set_minmax_to_non_const(g(), -1, 1); + } + +public: + luci::CircleRsqrt *_rsqrt = nullptr; +}; + +class SqrtTestGraph final : public SimpleTestGraph +{ +public: + void init(void) override + { + TestIOGraph::init({32}, {32}); + _sqrt = g()->nodes()->create<luci::CircleSqrt>(); + { + _sqrt->x(input()); + } + output()->from(_sqrt); + + set_minmax_to_non_const(g(), -1, 1); + } + +public: + luci::CircleSqrt *_sqrt = nullptr; +}; + +class EluTestGraph final : public SimpleTestGraph +{ +public: + void init(void) override + { + TestIOGraph::init({32}, {32}); + _elu = g()->nodes()->create<luci::CircleElu>(); + { + _elu->features(input()); + } + output()->from(_elu); + + set_minmax_to_non_const(g(), -1, 1); + } + +public: + luci::CircleElu *_elu = nullptr; +}; + +class PowTestGraph final : public SimpleTestGraph +{ +public: + void init(void) override + { + TestIOGraph::init({32}, {32}); + + _const = create_dummy_const<Type::FLOAT32>(g(), {32}); + _pow = g()->nodes()->create<luci::CirclePow>(); + { + _pow->x(input()); + _pow->y(_const); + } + output()->from(_pow); + + set_minmax_to_non_const(g(), -1, 1); + } + + loco::Node *x() { return _pow->x(); } + + loco::Node *y() { return _pow->y(); } + +private: + luci::CirclePow *_pow = nullptr; + luci::CircleConst *_const = nullptr; +}; + +class ResizeBilinearTestGraph final : public SimpleTestGraph +{ +public: + void init(void) override + { + TestIOGraph::init({1, 4, 4, 1}, {1, 8, 8, 1}); + + _size = create_const<Type::S32, int32_t>(g(), {2}, {8, 8}); + _resize_bilinear = g()->nodes()->create<luci::CircleResizeBilinear>(); + { + _resize_bilinear->input(input()); + _resize_bilinear->size(_size); + } + output()->from(_resize_bilinear); + + set_minmax_to_non_const(g(), -1, 1); + } + +private: + luci::CircleResizeBilinear *_resize_bilinear = nullptr; + luci::CircleConst *_size = nullptr; +}; + +} // namespace + +// Quantize and verify with given configurations +#define TEST_WITH_GRAPH(graph, type, granularity) \ + do \ + { \ + graph g; \ + g.init(); \ + EXPECT_NO_THROW(quantize_and_verify(g.g(), type, granularity)); \ + } while (0) + +// Quantize and verify with wrong type +#define TEST_WITH_WRONG_TYPE(graph, type, granularity, wrong_dtype) \ + do \ + { \ + graph g; \ + g.init(); \ + EXPECT_ANY_THROW(quantize_and_verify_with_wrong_type(&g, type, granularity, wrong_dtype)); \ + } while (0) + +// Quantize and verify with wrong granularity +#define TEST_WITH_WRONG_GRANULARITY(graph, type, granularity) \ + do \ + { \ + graph g; \ + g.init(); \ + EXPECT_ANY_THROW(quantize_and_verify_with_wrong_granularity(&g, type, granularity)); \ + } while (0) + +// Quantize and verify with wrong granularity +// Users can specify the test target +#define TEST_WITH_WRONG_GRANULARITY_TARGET(graph, type, granularity, target) \ + do \ + { \ + graph g; \ + g.init(); \ + auto node = loco::must_cast<luci::CircleNode *>(target); \ + EXPECT_ANY_THROW(quantize_and_verify_with_wrong_granularity(&g, type, granularity, node)); \ + } while (0) + +// Test a local helper function +TEST(QuantizedModelVerifierTest, LocalCreateDummyConst) +{ + loco::Graph g; + + EXPECT_NO_THROW(create_dummy_const<Type::FLOAT32>(&g, {32, 32})); +} + +TEST(QuantizedModelVerifierTest, LocalCreateConst) +{ + loco::Graph g; + std::initializer_list<float> values = {0.1, 0, -5, 100}; + luci::CircleConst *node = create_const<Type::FLOAT32, float>(&g, {2, 2}, values); + + uint32_t index = 0; + for (auto val : values) + { + EXPECT_EQ(node->at<Type::FLOAT32>(index++), val); + } +} + +TEST(QuantizedModelVerifierTest, InstanceNorm) +{ + TEST_WITH_GRAPH(InstanceNormTestGraph, Type::U8, Granularity::LayerWise); + TEST_WITH_GRAPH(InstanceNormTestGraph, Type::U8, Granularity::ChannelWise); + TEST_WITH_GRAPH(InstanceNormTestGraph, Type::S16, Granularity::ChannelWise); + SUCCEED(); +} + +TEST(QuantizedModelVerifierTest, InstanceNorm_wrong_type_NEG) +{ + TEST_WITH_WRONG_TYPE(InstanceNormTestGraph, Type::U8, Granularity::LayerWise, Type::S16); + TEST_WITH_WRONG_TYPE(InstanceNormTestGraph, Type::U8, Granularity::ChannelWise, Type::S16); + TEST_WITH_WRONG_TYPE(InstanceNormTestGraph, Type::S16, Granularity::ChannelWise, Type::U8); + SUCCEED(); +} + +TEST(QuantizedModelVerifierTest, InstanceNorm_wrong_granularity_NEG) +{ + TEST_WITH_WRONG_GRANULARITY(InstanceNormTestGraph, Type::U8, Granularity::LayerWise); + TEST_WITH_WRONG_GRANULARITY(InstanceNormTestGraph, Type::U8, Granularity::ChannelWise); + TEST_WITH_WRONG_GRANULARITY(InstanceNormTestGraph, Type::S16, Granularity::ChannelWise); + SUCCEED(); +} + +TEST(QuantizedModelVerifierTest, Logistic) +{ + TEST_WITH_GRAPH(LogisticTestGraph, Type::U8, Granularity::LayerWise); + TEST_WITH_GRAPH(LogisticTestGraph, Type::U8, Granularity::ChannelWise); + TEST_WITH_GRAPH(LogisticTestGraph, Type::S16, Granularity::ChannelWise); + SUCCEED(); +} + +TEST(QuantizedModelVerifierTest, Logistic_wrong_type_NEG) +{ + TEST_WITH_WRONG_TYPE(LogisticTestGraph, Type::U8, Granularity::LayerWise, Type::S16); + TEST_WITH_WRONG_TYPE(LogisticTestGraph, Type::U8, Granularity::ChannelWise, Type::S16); + TEST_WITH_WRONG_TYPE(LogisticTestGraph, Type::S16, Granularity::ChannelWise, Type::U8); + SUCCEED(); +} + +TEST(QuantizedModelVerifierTest, Logistic_wrong_granularity_NEG) +{ + TEST_WITH_WRONG_GRANULARITY(LogisticTestGraph, Type::U8, Granularity::LayerWise); + TEST_WITH_WRONG_GRANULARITY(LogisticTestGraph, Type::U8, Granularity::ChannelWise); + TEST_WITH_WRONG_GRANULARITY(LogisticTestGraph, Type::S16, Granularity::ChannelWise); + SUCCEED(); +} + +TEST(QuantizedModelVerifierTest, Softmax) +{ + TEST_WITH_GRAPH(SoftmaxTestGraph, Type::U8, Granularity::LayerWise); + TEST_WITH_GRAPH(SoftmaxTestGraph, Type::U8, Granularity::ChannelWise); + TEST_WITH_GRAPH(SoftmaxTestGraph, Type::S16, Granularity::ChannelWise); + SUCCEED(); +} + +TEST(QuantizedModelVerifierTest, Softmax_wrong_type_NEG) +{ + TEST_WITH_WRONG_TYPE(SoftmaxTestGraph, Type::U8, Granularity::LayerWise, Type::S16); + TEST_WITH_WRONG_TYPE(SoftmaxTestGraph, Type::U8, Granularity::ChannelWise, Type::S16); + TEST_WITH_WRONG_TYPE(SoftmaxTestGraph, Type::S16, Granularity::ChannelWise, Type::U8); + SUCCEED(); +} + +TEST(QuantizedModelVerifierTest, Softmax_wrong_granularity_NEG) +{ + TEST_WITH_WRONG_GRANULARITY(SoftmaxTestGraph, Type::U8, Granularity::LayerWise); + TEST_WITH_WRONG_GRANULARITY(SoftmaxTestGraph, Type::U8, Granularity::ChannelWise); + TEST_WITH_WRONG_GRANULARITY(SoftmaxTestGraph, Type::S16, Granularity::ChannelWise); + SUCCEED(); +} + +TEST(QuantizedModelVerifierTest, SpaceToBatchND) +{ + TEST_WITH_GRAPH(SpaceToBatchNDTestGraph, Type::U8, Granularity::LayerWise); + TEST_WITH_GRAPH(SpaceToBatchNDTestGraph, Type::U8, Granularity::ChannelWise); + TEST_WITH_GRAPH(SpaceToBatchNDTestGraph, Type::S16, Granularity::ChannelWise); + SUCCEED(); +} + +TEST(QuantizedModelVerifierTest, SpaceToBatchND_wrong_type_NEG) +{ + TEST_WITH_WRONG_TYPE(SpaceToBatchNDTestGraph, Type::U8, Granularity::LayerWise, Type::S16); + TEST_WITH_WRONG_TYPE(SpaceToBatchNDTestGraph, Type::U8, Granularity::ChannelWise, Type::S16); + TEST_WITH_WRONG_TYPE(SpaceToBatchNDTestGraph, Type::S16, Granularity::ChannelWise, Type::U8); + SUCCEED(); +} + +TEST(QuantizedModelVerifierTest, SpaceToBatchND_wrong_granularity_NEG) +{ + TEST_WITH_WRONG_GRANULARITY(SpaceToBatchNDTestGraph, Type::U8, Granularity::LayerWise); + TEST_WITH_WRONG_GRANULARITY(SpaceToBatchNDTestGraph, Type::U8, Granularity::ChannelWise); + TEST_WITH_WRONG_GRANULARITY(SpaceToBatchNDTestGraph, Type::S16, Granularity::ChannelWise); + SUCCEED(); +} + +TEST(QuantizedModelVerifierTest, SpaceToDepth) +{ + TEST_WITH_GRAPH(SpaceToDepthTestGraph, Type::U8, Granularity::LayerWise); + TEST_WITH_GRAPH(SpaceToDepthTestGraph, Type::U8, Granularity::ChannelWise); + TEST_WITH_GRAPH(SpaceToDepthTestGraph, Type::S16, Granularity::ChannelWise); + SUCCEED(); +} + +TEST(QuantizedModelVerifierTest, SpaceToDepth_wrong_type_NEG) +{ + TEST_WITH_WRONG_TYPE(SpaceToDepthTestGraph, Type::U8, Granularity::LayerWise, Type::S16); + TEST_WITH_WRONG_TYPE(SpaceToDepthTestGraph, Type::U8, Granularity::ChannelWise, Type::S16); + TEST_WITH_WRONG_TYPE(SpaceToDepthTestGraph, Type::S16, Granularity::ChannelWise, Type::U8); + SUCCEED(); +} + +TEST(QuantizedModelVerifierTest, SpaceToDepth_wrong_granularity_NEG) +{ + TEST_WITH_WRONG_GRANULARITY(SpaceToDepthTestGraph, Type::U8, Granularity::LayerWise); + TEST_WITH_WRONG_GRANULARITY(SpaceToDepthTestGraph, Type::U8, Granularity::ChannelWise); + TEST_WITH_WRONG_GRANULARITY(SpaceToDepthTestGraph, Type::S16, Granularity::ChannelWise); + SUCCEED(); +} + +TEST(QuantizedModelVerifierTest, Slice) +{ + TEST_WITH_GRAPH(SliceTestGraph<Type::S32>, Type::U8, Granularity::LayerWise); + TEST_WITH_GRAPH(SliceTestGraph<Type::S32>, Type::U8, Granularity::ChannelWise); + TEST_WITH_GRAPH(SliceTestGraph<Type::S32>, Type::S16, Granularity::ChannelWise); + + TEST_WITH_GRAPH(SliceTestGraph<Type::S64>, Type::U8, Granularity::LayerWise); + TEST_WITH_GRAPH(SliceTestGraph<Type::S64>, Type::U8, Granularity::ChannelWise); + TEST_WITH_GRAPH(SliceTestGraph<Type::S64>, Type::S16, Granularity::ChannelWise); + SUCCEED(); +} + +TEST(QuantizedModelVerifierTest, Slice_wrong_type_NEG) +{ + TEST_WITH_WRONG_TYPE(SliceTestGraph<Type::S32>, Type::U8, Granularity::LayerWise, Type::S16); + TEST_WITH_WRONG_TYPE(SliceTestGraph<Type::S32>, Type::U8, Granularity::ChannelWise, Type::S16); + TEST_WITH_WRONG_TYPE(SliceTestGraph<Type::S32>, Type::S16, Granularity::ChannelWise, Type::U8); + + TEST_WITH_WRONG_TYPE(SliceTestGraph<Type::S64>, Type::U8, Granularity::LayerWise, Type::S16); + TEST_WITH_WRONG_TYPE(SliceTestGraph<Type::S64>, Type::U8, Granularity::ChannelWise, Type::S16); + TEST_WITH_WRONG_TYPE(SliceTestGraph<Type::S64>, Type::S16, Granularity::ChannelWise, Type::U8); + SUCCEED(); +} + +TEST(QuantizedModelVerifierTest, Slice_wrong_granularity_NEG) +{ + TEST_WITH_WRONG_GRANULARITY(SliceTestGraph<Type::S32>, Type::U8, Granularity::LayerWise); + TEST_WITH_WRONG_GRANULARITY(SliceTestGraph<Type::S32>, Type::U8, Granularity::ChannelWise); + TEST_WITH_WRONG_GRANULARITY(SliceTestGraph<Type::S32>, Type::S16, Granularity::ChannelWise); + + TEST_WITH_WRONG_GRANULARITY(SliceTestGraph<Type::S64>, Type::U8, Granularity::LayerWise); + TEST_WITH_WRONG_GRANULARITY(SliceTestGraph<Type::S64>, Type::U8, Granularity::ChannelWise); + TEST_WITH_WRONG_GRANULARITY(SliceTestGraph<Type::S64>, Type::S16, Granularity::ChannelWise); + SUCCEED(); +} + +TEST(QuantizedModelVerifierTest, Split) +{ + TEST_WITH_GRAPH(SplitTestGraph, Type::U8, Granularity::LayerWise); + TEST_WITH_GRAPH(SplitTestGraph, Type::U8, Granularity::ChannelWise); + TEST_WITH_GRAPH(SplitTestGraph, Type::S16, Granularity::ChannelWise); + SUCCEED(); +} + +TEST(QuantizedModelVerifierTest, Split_wrong_type_NEG) +{ + TEST_WITH_WRONG_TYPE(SplitTestGraph, Type::U8, Granularity::LayerWise, Type::S16); + TEST_WITH_WRONG_TYPE(SplitTestGraph, Type::U8, Granularity::ChannelWise, Type::S16); + TEST_WITH_WRONG_TYPE(SplitTestGraph, Type::S16, Granularity::ChannelWise, Type::U8); + SUCCEED(); +} + +TEST(QuantizedModelVerifierTest, Split_wrong_granularity_NEG) +{ + TEST_WITH_WRONG_GRANULARITY(SplitTestGraph, Type::U8, Granularity::LayerWise); + TEST_WITH_WRONG_GRANULARITY(SplitTestGraph, Type::U8, Granularity::ChannelWise); + TEST_WITH_WRONG_GRANULARITY(SplitTestGraph, Type::S16, Granularity::ChannelWise); + SUCCEED(); +} + +TEST(QuantizedModelVerifierTest, StridedSlice) +{ + TEST_WITH_GRAPH(StridedSliceTestGraph, Type::U8, Granularity::LayerWise); + TEST_WITH_GRAPH(StridedSliceTestGraph, Type::U8, Granularity::ChannelWise); + TEST_WITH_GRAPH(StridedSliceTestGraph, Type::S16, Granularity::ChannelWise); + SUCCEED(); +} + +TEST(QuantizedModelVerifierTest, StridedSlice_wrong_type_NEG) +{ + TEST_WITH_WRONG_TYPE(StridedSliceTestGraph, Type::U8, Granularity::LayerWise, Type::S16); + TEST_WITH_WRONG_TYPE(StridedSliceTestGraph, Type::U8, Granularity::ChannelWise, Type::S16); + TEST_WITH_WRONG_TYPE(StridedSliceTestGraph, Type::S16, Granularity::ChannelWise, Type::U8); + SUCCEED(); +} + +TEST(QuantizedModelVerifierTest, StridedSlice_wrong_granularity_NEG) +{ + TEST_WITH_WRONG_GRANULARITY(StridedSliceTestGraph, Type::U8, Granularity::LayerWise); + TEST_WITH_WRONG_GRANULARITY(StridedSliceTestGraph, Type::U8, Granularity::ChannelWise); + TEST_WITH_WRONG_GRANULARITY(StridedSliceTestGraph, Type::S16, Granularity::ChannelWise); + SUCCEED(); +} + +TEST(QuantizedModelVerifierTest, ArgMax) +{ + TEST_WITH_GRAPH(ArgMaxTestGraph<Type::S32>, Type::U8, Granularity::LayerWise); + TEST_WITH_GRAPH(ArgMaxTestGraph<Type::S32>, Type::U8, Granularity::ChannelWise); + TEST_WITH_GRAPH(ArgMaxTestGraph<Type::S32>, Type::S16, Granularity::ChannelWise); + + TEST_WITH_GRAPH(ArgMaxTestGraph<Type::S64>, Type::U8, Granularity::LayerWise); + TEST_WITH_GRAPH(ArgMaxTestGraph<Type::S64>, Type::U8, Granularity::ChannelWise); + TEST_WITH_GRAPH(ArgMaxTestGraph<Type::S64>, Type::S16, Granularity::ChannelWise); + SUCCEED(); +} + +TEST(QuantizedModelVerifierTest, ArgMax_wrong_dimension_type_NEG) +{ + ArgMaxTestGraph<Type::S32> g; + g.init(); + luci::QuantizeWithMinMaxPass pass(Type::FLOAT32, Type::U8, Granularity::LayerWise); + pass.run(g.g()); + + g._dimension->dtype(Type::U8); + + luci::QuantizedModelVerifier verifier(Type::U8, Granularity::LayerWise); + EXPECT_ANY_THROW(verifier.verify(g.g())); +} + +TEST(QuantizedModelVerifierTest, ArgMax_wrong_input_granularity_NEG) +{ + ArgMaxTestGraph<Type::S32> g; + g.init(); + + luci::QuantizeWithMinMaxPass pass(Type::FLOAT32, Type::U8, Granularity::LayerWise); + pass.run(g.g()); + + insert_scale_zp(loco::must_cast<luci::CircleNode *>(g._argmax->input()), 1.0, 1); + + luci::QuantizedModelVerifier verifier(Type::U8, Granularity::LayerWise); + EXPECT_ANY_THROW(verifier.verify(g.g())); +} + +TEST(QuantizedModelVerifierTest, BatchToSpaceND) +{ + TEST_WITH_GRAPH(BatchToSpaceNDTestGraph, Type::U8, Granularity::LayerWise); + TEST_WITH_GRAPH(BatchToSpaceNDTestGraph, Type::U8, Granularity::ChannelWise); + TEST_WITH_GRAPH(BatchToSpaceNDTestGraph, Type::S16, Granularity::ChannelWise); + SUCCEED(); +} + +TEST(QuantizedModelVerifierTest, BatchToSpaceND_wrong_type_NEG) +{ + TEST_WITH_WRONG_TYPE(BatchToSpaceNDTestGraph, Type::U8, Granularity::LayerWise, Type::S16); + TEST_WITH_WRONG_TYPE(BatchToSpaceNDTestGraph, Type::U8, Granularity::ChannelWise, Type::S16); + TEST_WITH_WRONG_TYPE(BatchToSpaceNDTestGraph, Type::S16, Granularity::ChannelWise, Type::U8); + SUCCEED(); +} + +TEST(QuantizedModelVerifierTest, BatchToSpaceND_wrong_granularity_NEG) +{ + TEST_WITH_WRONG_GRANULARITY(BatchToSpaceNDTestGraph, Type::U8, Granularity::LayerWise); + TEST_WITH_WRONG_GRANULARITY(BatchToSpaceNDTestGraph, Type::U8, Granularity::ChannelWise); + TEST_WITH_WRONG_GRANULARITY(BatchToSpaceNDTestGraph, Type::S16, Granularity::ChannelWise); + SUCCEED(); +} + +TEST(QuantizedModelVerifierTest, DepthToSpace) +{ + TEST_WITH_GRAPH(DepthToSpaceTestGraph, Type::U8, Granularity::LayerWise); + TEST_WITH_GRAPH(DepthToSpaceTestGraph, Type::U8, Granularity::ChannelWise); + TEST_WITH_GRAPH(DepthToSpaceTestGraph, Type::S16, Granularity::ChannelWise); + SUCCEED(); +} + +TEST(QuantizedModelVerifierTest, DepthToSpace_wrong_type_NEG) +{ + TEST_WITH_WRONG_TYPE(DepthToSpaceTestGraph, Type::U8, Granularity::LayerWise, Type::S16); + TEST_WITH_WRONG_TYPE(DepthToSpaceTestGraph, Type::U8, Granularity::ChannelWise, Type::S16); + TEST_WITH_WRONG_TYPE(DepthToSpaceTestGraph, Type::S16, Granularity::ChannelWise, Type::U8); + SUCCEED(); +} + +TEST(QuantizedModelVerifierTest, DepthToSpace_wrong_granularity_NEG) +{ + TEST_WITH_WRONG_GRANULARITY(DepthToSpaceTestGraph, Type::U8, Granularity::LayerWise); + TEST_WITH_WRONG_GRANULARITY(DepthToSpaceTestGraph, Type::U8, Granularity::ChannelWise); + TEST_WITH_WRONG_GRANULARITY(DepthToSpaceTestGraph, Type::S16, Granularity::ChannelWise); + SUCCEED(); +} + +TEST(QuantizedModelVerifierTest, Concatenation) +{ + TEST_WITH_GRAPH(ConcatenationTestGraph, Type::U8, Granularity::LayerWise); + TEST_WITH_GRAPH(ConcatenationTestGraph, Type::U8, Granularity::ChannelWise); + TEST_WITH_GRAPH(ConcatenationTestGraph, Type::S16, Granularity::ChannelWise); + SUCCEED(); +} + +TEST(QuantizedModelVerifierTest, Concatenation_wrong_type_NEG) +{ + TEST_WITH_WRONG_TYPE(ConcatenationTestGraph, Type::U8, Granularity::LayerWise, Type::S16); + TEST_WITH_WRONG_TYPE(ConcatenationTestGraph, Type::U8, Granularity::ChannelWise, Type::S16); + TEST_WITH_WRONG_TYPE(ConcatenationTestGraph, Type::S16, Granularity::ChannelWise, Type::U8); + SUCCEED(); +} + +TEST(QuantizedModelVerifierTest, Concatenation_wrong_granularity_NEG) +{ + TEST_WITH_WRONG_GRANULARITY(ConcatenationTestGraph, Type::U8, Granularity::LayerWise); + TEST_WITH_WRONG_GRANULARITY(ConcatenationTestGraph, Type::U8, Granularity::ChannelWise); + TEST_WITH_WRONG_GRANULARITY(ConcatenationTestGraph, Type::S16, Granularity::ChannelWise); + SUCCEED(); +} + +TEST(QuantizedModelVerifierTest, LogicalOr) +{ + TEST_WITH_GRAPH(BinaryLogicalOpTestGraph<luci::CircleLogicalOr>, Type::U8, + Granularity::LayerWise); + TEST_WITH_GRAPH(BinaryLogicalOpTestGraph<luci::CircleLogicalOr>, Type::U8, + Granularity::ChannelWise); + TEST_WITH_GRAPH(BinaryLogicalOpTestGraph<luci::CircleLogicalOr>, Type::S16, + Granularity::ChannelWise); + SUCCEED(); +} + +TEST(QuantizedModelVerifierTest, LogicalOr_wrong_type_NEG) +{ + TEST_WITH_WRONG_TYPE(BinaryLogicalOpTestGraph<luci::CircleLogicalOr>, Type::U8, + Granularity::LayerWise, Type::U8); + TEST_WITH_WRONG_TYPE(BinaryLogicalOpTestGraph<luci::CircleLogicalOr>, Type::U8, + Granularity::ChannelWise, Type::U8); + TEST_WITH_WRONG_TYPE(BinaryLogicalOpTestGraph<luci::CircleLogicalOr>, Type::S16, + Granularity::ChannelWise, Type::S16); + SUCCEED(); +} + +TEST(QuantizedModelVerifierTest, Reshape) +{ + TEST_WITH_GRAPH(ReshapeTestGraph, Type::U8, Granularity::LayerWise); + TEST_WITH_GRAPH(ReshapeTestGraph, Type::U8, Granularity::ChannelWise); + TEST_WITH_GRAPH(ReshapeTestGraph, Type::S16, Granularity::ChannelWise); + SUCCEED(); +} + +TEST(QuantizedModelVerifierTest, Reshape_wrong_type_NEG) +{ + TEST_WITH_WRONG_TYPE(ReshapeTestGraph, Type::U8, Granularity::LayerWise, Type::S16); + TEST_WITH_WRONG_TYPE(ReshapeTestGraph, Type::U8, Granularity::ChannelWise, Type::S16); + TEST_WITH_WRONG_TYPE(ReshapeTestGraph, Type::S16, Granularity::ChannelWise, Type::U8); + SUCCEED(); +} + +TEST(QuantizedModelVerifierTest, Reshape_wrong_granularity_NEG) +{ + TEST_WITH_WRONG_GRANULARITY(ReshapeTestGraph, Type::U8, Granularity::LayerWise); + TEST_WITH_WRONG_GRANULARITY(ReshapeTestGraph, Type::U8, Granularity::ChannelWise); + TEST_WITH_WRONG_GRANULARITY(ReshapeTestGraph, Type::S16, Granularity::ChannelWise); + SUCCEED(); +} + +TEST(QuantizedModelVerifierTest, Tanh) +{ + TEST_WITH_GRAPH(TanhTestGraph, Type::U8, Granularity::LayerWise); + TEST_WITH_GRAPH(TanhTestGraph, Type::U8, Granularity::ChannelWise); + TEST_WITH_GRAPH(TanhTestGraph, Type::S16, Granularity::ChannelWise); + SUCCEED(); +} + +TEST(QuantizedModelVerifierTest, Tanh_wrong_type_NEG) +{ + TEST_WITH_WRONG_TYPE(TanhTestGraph, Type::U8, Granularity::LayerWise, Type::S16); + TEST_WITH_WRONG_TYPE(TanhTestGraph, Type::U8, Granularity::ChannelWise, Type::S16); + TEST_WITH_WRONG_TYPE(TanhTestGraph, Type::S16, Granularity::ChannelWise, Type::U8); + SUCCEED(); +} + +TEST(QuantizedModelVerifierTest, Tanh_wrong_granularity_NEG) +{ + TEST_WITH_WRONG_GRANULARITY(TanhTestGraph, Type::U8, Granularity::LayerWise); + TEST_WITH_WRONG_GRANULARITY(TanhTestGraph, Type::U8, Granularity::ChannelWise); + TEST_WITH_WRONG_GRANULARITY(TanhTestGraph, Type::S16, Granularity::ChannelWise); + SUCCEED(); +} + +TEST(QuantizedModelVerifierTest, Pad) +{ + TEST_WITH_GRAPH(PadTestGraph, Type::U8, Granularity::LayerWise); + TEST_WITH_GRAPH(PadTestGraph, Type::U8, Granularity::ChannelWise); + TEST_WITH_GRAPH(PadTestGraph, Type::S16, Granularity::ChannelWise); + SUCCEED(); +} + +TEST(QuantizedModelVerifierTest, Pad_wrong_type_NEG) +{ + TEST_WITH_WRONG_TYPE(PadTestGraph, Type::U8, Granularity::LayerWise, Type::S16); + TEST_WITH_WRONG_TYPE(PadTestGraph, Type::U8, Granularity::ChannelWise, Type::S16); + TEST_WITH_WRONG_TYPE(PadTestGraph, Type::S16, Granularity::ChannelWise, Type::U8); + SUCCEED(); +} + +TEST(QuantizedModelVerifierTest, Pad_wrong_granularity_NEG) +{ + TEST_WITH_WRONG_GRANULARITY(PadTestGraph, Type::U8, Granularity::LayerWise); + TEST_WITH_WRONG_GRANULARITY(PadTestGraph, Type::U8, Granularity::ChannelWise); + TEST_WITH_WRONG_GRANULARITY(PadTestGraph, Type::S16, Granularity::ChannelWise); + SUCCEED(); +} + +TEST(QuantizedModelVerifierTest, Transpose) +{ + TEST_WITH_GRAPH(TransposeTestGraph, Type::U8, Granularity::LayerWise); + TEST_WITH_GRAPH(TransposeTestGraph, Type::U8, Granularity::ChannelWise); + TEST_WITH_GRAPH(TransposeTestGraph, Type::S16, Granularity::ChannelWise); + SUCCEED(); +} + +TEST(QuantizedModelVerifierTest, Transpose_wrong_type_NEG) +{ + TEST_WITH_WRONG_TYPE(TransposeTestGraph, Type::U8, Granularity::LayerWise, Type::S16); + TEST_WITH_WRONG_TYPE(TransposeTestGraph, Type::U8, Granularity::ChannelWise, Type::S16); + TEST_WITH_WRONG_TYPE(TransposeTestGraph, Type::S16, Granularity::ChannelWise, Type::U8); + SUCCEED(); +} + +TEST(QuantizedModelVerifierTest, Transpose_wrong_granularity_NEG) +{ + TEST_WITH_WRONG_GRANULARITY(TransposeTestGraph, Type::U8, Granularity::LayerWise); + TEST_WITH_WRONG_GRANULARITY(TransposeTestGraph, Type::U8, Granularity::ChannelWise); + TEST_WITH_WRONG_GRANULARITY(TransposeTestGraph, Type::S16, Granularity::ChannelWise); + SUCCEED(); +} + +TEST(QuantizedModelVerifierTest, Floor) +{ + TEST_WITH_GRAPH(FloorTestGraph, Type::U8, Granularity::LayerWise); + TEST_WITH_GRAPH(FloorTestGraph, Type::U8, Granularity::ChannelWise); + TEST_WITH_GRAPH(FloorTestGraph, Type::S16, Granularity::ChannelWise); + SUCCEED(); +} + +TEST(QuantizedModelVerifierTest, Floor_wrong_type_NEG) +{ + TEST_WITH_WRONG_TYPE(FloorTestGraph, Type::U8, Granularity::LayerWise, Type::S16); + TEST_WITH_WRONG_TYPE(FloorTestGraph, Type::U8, Granularity::ChannelWise, Type::S16); + TEST_WITH_WRONG_TYPE(FloorTestGraph, Type::S16, Granularity::ChannelWise, Type::U8); + SUCCEED(); +} + +TEST(QuantizedModelVerifierTest, Floor_wrong_granularity_NEG) +{ + TEST_WITH_WRONG_GRANULARITY(FloorTestGraph, Type::U8, Granularity::LayerWise); + TEST_WITH_WRONG_GRANULARITY(FloorTestGraph, Type::U8, Granularity::ChannelWise); + TEST_WITH_WRONG_GRANULARITY(FloorTestGraph, Type::S16, Granularity::ChannelWise); + SUCCEED(); +} + +TEST(QuantizedModelVerifierTest, GreaterEqual) +{ + TEST_WITH_GRAPH(ComparisonOpTestGraph<luci::CircleGreaterEqual>, Type::U8, + Granularity::LayerWise); + TEST_WITH_GRAPH(ComparisonOpTestGraph<luci::CircleGreaterEqual>, Type::U8, + Granularity::ChannelWise); + TEST_WITH_GRAPH(ComparisonOpTestGraph<luci::CircleGreaterEqual>, Type::S16, + Granularity::ChannelWise); + SUCCEED(); +} + +TEST(QuantizedModelVerifierTest, GreaterEqual_wrong_type_NEG) +{ + TEST_WITH_WRONG_TYPE(ComparisonOpTestGraph<luci::CircleGreaterEqual>, Type::U8, + Granularity::LayerWise, Type::U8); + TEST_WITH_WRONG_TYPE(ComparisonOpTestGraph<luci::CircleGreaterEqual>, Type::U8, + Granularity::ChannelWise, Type::U8); + TEST_WITH_WRONG_TYPE(ComparisonOpTestGraph<luci::CircleGreaterEqual>, Type::S16, + Granularity::ChannelWise, Type::S16); + SUCCEED(); +} + +TEST(QuantizedModelVerifierTest, GreaterEqual_wrong_granularity_NEG) +{ + TEST_WITH_WRONG_GRANULARITY_TARGET(ComparisonOpTestGraph<luci::CircleGreaterEqual>, Type::U8, + Granularity::LayerWise, g.x()); + TEST_WITH_WRONG_GRANULARITY_TARGET(ComparisonOpTestGraph<luci::CircleGreaterEqual>, Type::U8, + Granularity::ChannelWise, g.x()); + TEST_WITH_WRONG_GRANULARITY_TARGET(ComparisonOpTestGraph<luci::CircleGreaterEqual>, Type::S16, + Granularity::ChannelWise, g.x()); + + TEST_WITH_WRONG_GRANULARITY_TARGET(ComparisonOpTestGraph<luci::CircleGreaterEqual>, Type::U8, + Granularity::LayerWise, g.y()); + TEST_WITH_WRONG_GRANULARITY_TARGET(ComparisonOpTestGraph<luci::CircleGreaterEqual>, Type::U8, + Granularity::ChannelWise, g.y()); + TEST_WITH_WRONG_GRANULARITY_TARGET(ComparisonOpTestGraph<luci::CircleGreaterEqual>, Type::S16, + Granularity::ChannelWise, g.y()); + SUCCEED(); +} + +TEST(QuantizedModelVerifierTest, Greater) +{ + TEST_WITH_GRAPH(ComparisonOpTestGraph<luci::CircleGreater>, Type::U8, Granularity::LayerWise); + TEST_WITH_GRAPH(ComparisonOpTestGraph<luci::CircleGreater>, Type::U8, Granularity::ChannelWise); + TEST_WITH_GRAPH(ComparisonOpTestGraph<luci::CircleGreater>, Type::S16, Granularity::ChannelWise); + SUCCEED(); +} + +TEST(QuantizedModelVerifierTest, Greater_wrong_type_NEG) +{ + TEST_WITH_WRONG_TYPE(ComparisonOpTestGraph<luci::CircleGreater>, Type::U8, Granularity::LayerWise, + Type::U8); + TEST_WITH_WRONG_TYPE(ComparisonOpTestGraph<luci::CircleGreater>, Type::U8, + Granularity::ChannelWise, Type::U8); + TEST_WITH_WRONG_TYPE(ComparisonOpTestGraph<luci::CircleGreater>, Type::S16, + Granularity::ChannelWise, Type::S16); + SUCCEED(); +} + +TEST(QuantizedModelVerifierTest, Greater_wrong_granularity_NEG) +{ + TEST_WITH_WRONG_GRANULARITY_TARGET(ComparisonOpTestGraph<luci::CircleGreater>, Type::U8, + Granularity::LayerWise, g.x()); + TEST_WITH_WRONG_GRANULARITY_TARGET(ComparisonOpTestGraph<luci::CircleGreater>, Type::U8, + Granularity::ChannelWise, g.x()); + TEST_WITH_WRONG_GRANULARITY_TARGET(ComparisonOpTestGraph<luci::CircleGreater>, Type::S16, + Granularity::ChannelWise, g.x()); + + TEST_WITH_WRONG_GRANULARITY_TARGET(ComparisonOpTestGraph<luci::CircleGreater>, Type::U8, + Granularity::LayerWise, g.y()); + TEST_WITH_WRONG_GRANULARITY_TARGET(ComparisonOpTestGraph<luci::CircleGreater>, Type::U8, + Granularity::ChannelWise, g.y()); + TEST_WITH_WRONG_GRANULARITY_TARGET(ComparisonOpTestGraph<luci::CircleGreater>, Type::S16, + Granularity::ChannelWise, g.y()); + SUCCEED(); +} + +TEST(QuantizedModelVerifierTest, NotEqual) +{ + TEST_WITH_GRAPH(ComparisonOpTestGraph<luci::CircleNotEqual>, Type::U8, Granularity::LayerWise); + TEST_WITH_GRAPH(ComparisonOpTestGraph<luci::CircleNotEqual>, Type::U8, Granularity::ChannelWise); + TEST_WITH_GRAPH(ComparisonOpTestGraph<luci::CircleNotEqual>, Type::S16, Granularity::ChannelWise); + SUCCEED(); +} + +TEST(QuantizedModelVerifierTest, NotEqual_wrong_type_NEG) +{ + TEST_WITH_WRONG_TYPE(ComparisonOpTestGraph<luci::CircleNotEqual>, Type::U8, + Granularity::LayerWise, Type::U8); + TEST_WITH_WRONG_TYPE(ComparisonOpTestGraph<luci::CircleNotEqual>, Type::U8, + Granularity::ChannelWise, Type::U8); + TEST_WITH_WRONG_TYPE(ComparisonOpTestGraph<luci::CircleNotEqual>, Type::S16, + Granularity::ChannelWise, Type::S16); + SUCCEED(); +} + +TEST(QuantizedModelVerifierTest, NotEqual_wrong_granularity_NEG) +{ + TEST_WITH_WRONG_GRANULARITY_TARGET(ComparisonOpTestGraph<luci::CircleNotEqual>, Type::U8, + Granularity::LayerWise, g.x()); + TEST_WITH_WRONG_GRANULARITY_TARGET(ComparisonOpTestGraph<luci::CircleNotEqual>, Type::U8, + Granularity::ChannelWise, g.x()); + TEST_WITH_WRONG_GRANULARITY_TARGET(ComparisonOpTestGraph<luci::CircleNotEqual>, Type::S16, + Granularity::ChannelWise, g.x()); + + TEST_WITH_WRONG_GRANULARITY_TARGET(ComparisonOpTestGraph<luci::CircleNotEqual>, Type::U8, + Granularity::LayerWise, g.y()); + TEST_WITH_WRONG_GRANULARITY_TARGET(ComparisonOpTestGraph<luci::CircleNotEqual>, Type::U8, + Granularity::ChannelWise, g.y()); + TEST_WITH_WRONG_GRANULARITY_TARGET(ComparisonOpTestGraph<luci::CircleNotEqual>, Type::S16, + Granularity::ChannelWise, g.y()); + SUCCEED(); +} + +TEST(QuantizedModelVerifierTest, Div) +{ + TEST_WITH_GRAPH(DivTestGraph, Type::U8, Granularity::LayerWise); + TEST_WITH_GRAPH(DivTestGraph, Type::U8, Granularity::ChannelWise); + TEST_WITH_GRAPH(DivTestGraph, Type::S16, Granularity::ChannelWise); + SUCCEED(); +} + +TEST(QuantizedModelVerifierTest, Div_wrong_type_NEG) +{ + TEST_WITH_WRONG_TYPE(DivTestGraph, Type::U8, Granularity::LayerWise, Type::S16); + TEST_WITH_WRONG_TYPE(DivTestGraph, Type::U8, Granularity::ChannelWise, Type::S16); + TEST_WITH_WRONG_TYPE(DivTestGraph, Type::S16, Granularity::ChannelWise, Type::U8); + SUCCEED(); +} + +TEST(QuantizedModelVerifierTest, Div_wrong_granularity_NEG) +{ + TEST_WITH_WRONG_GRANULARITY_TARGET(DivTestGraph, Type::U8, Granularity::LayerWise, g.x()); + TEST_WITH_WRONG_GRANULARITY_TARGET(DivTestGraph, Type::U8, Granularity::ChannelWise, g.x()); + TEST_WITH_WRONG_GRANULARITY_TARGET(DivTestGraph, Type::S16, Granularity::ChannelWise, g.x()); + + TEST_WITH_WRONG_GRANULARITY_TARGET(DivTestGraph, Type::U8, Granularity::LayerWise, g.y()); + TEST_WITH_WRONG_GRANULARITY_TARGET(DivTestGraph, Type::U8, Granularity::ChannelWise, g.y()); + TEST_WITH_WRONG_GRANULARITY_TARGET(DivTestGraph, Type::S16, Granularity::ChannelWise, g.y()); + SUCCEED(); +} + +TEST(QuantizedModelVerifierTest, FloorDiv) +{ + TEST_WITH_GRAPH(FloorDivTestGraph, Type::U8, Granularity::LayerWise); + TEST_WITH_GRAPH(FloorDivTestGraph, Type::U8, Granularity::ChannelWise); + TEST_WITH_GRAPH(FloorDivTestGraph, Type::S16, Granularity::ChannelWise); + SUCCEED(); +} + +TEST(QuantizedModelVerifierTest, FloorDiv_wrong_type_NEG) +{ + TEST_WITH_WRONG_TYPE(FloorDivTestGraph, Type::U8, Granularity::LayerWise, Type::S16); + TEST_WITH_WRONG_TYPE(FloorDivTestGraph, Type::U8, Granularity::ChannelWise, Type::S16); + TEST_WITH_WRONG_TYPE(FloorDivTestGraph, Type::S16, Granularity::ChannelWise, Type::U8); + SUCCEED(); +} + +TEST(QuantizedModelVerifierTest, FloorDiv_wrong_granularity_NEG) +{ + TEST_WITH_WRONG_GRANULARITY_TARGET(FloorDivTestGraph, Type::U8, Granularity::LayerWise, g.x()); + TEST_WITH_WRONG_GRANULARITY_TARGET(FloorDivTestGraph, Type::U8, Granularity::ChannelWise, g.x()); + TEST_WITH_WRONG_GRANULARITY_TARGET(FloorDivTestGraph, Type::S16, Granularity::ChannelWise, g.x()); + + TEST_WITH_WRONG_GRANULARITY_TARGET(FloorDivTestGraph, Type::U8, Granularity::LayerWise, g.y()); + TEST_WITH_WRONG_GRANULARITY_TARGET(FloorDivTestGraph, Type::U8, Granularity::ChannelWise, g.y()); + TEST_WITH_WRONG_GRANULARITY_TARGET(FloorDivTestGraph, Type::S16, Granularity::ChannelWise, g.y()); + SUCCEED(); +} + +TEST(QuantizedModelVerifierTest, Rsqrt) +{ + TEST_WITH_GRAPH(RsqrtTestGraph, Type::U8, Granularity::LayerWise); + TEST_WITH_GRAPH(RsqrtTestGraph, Type::U8, Granularity::ChannelWise); + TEST_WITH_GRAPH(RsqrtTestGraph, Type::S16, Granularity::ChannelWise); + SUCCEED(); +} + +TEST(QuantizedModelVerifierTest, Rsqrt_wrong_type_NEG) +{ + TEST_WITH_WRONG_TYPE(RsqrtTestGraph, Type::U8, Granularity::LayerWise, Type::S16); + TEST_WITH_WRONG_TYPE(RsqrtTestGraph, Type::U8, Granularity::ChannelWise, Type::S16); + TEST_WITH_WRONG_TYPE(RsqrtTestGraph, Type::S16, Granularity::ChannelWise, Type::U8); + SUCCEED(); +} + +TEST(QuantizedModelVerifierTest, Rsqrt_wrong_granularity_NEG) +{ + TEST_WITH_WRONG_GRANULARITY(RsqrtTestGraph, Type::U8, Granularity::LayerWise); + TEST_WITH_WRONG_GRANULARITY(RsqrtTestGraph, Type::U8, Granularity::ChannelWise); + TEST_WITH_WRONG_GRANULARITY(RsqrtTestGraph, Type::S16, Granularity::ChannelWise); + SUCCEED(); +} + +TEST(QuantizedModelVerifierTest, Sqrt) +{ + TEST_WITH_GRAPH(SqrtTestGraph, Type::U8, Granularity::LayerWise); + TEST_WITH_GRAPH(SqrtTestGraph, Type::U8, Granularity::ChannelWise); + TEST_WITH_GRAPH(SqrtTestGraph, Type::S16, Granularity::ChannelWise); + SUCCEED(); +} + +TEST(QuantizedModelVerifierTest, Sqrt_wrong_type_NEG) +{ + TEST_WITH_WRONG_TYPE(SqrtTestGraph, Type::U8, Granularity::LayerWise, Type::S16); + TEST_WITH_WRONG_TYPE(SqrtTestGraph, Type::U8, Granularity::ChannelWise, Type::S16); + TEST_WITH_WRONG_TYPE(SqrtTestGraph, Type::S16, Granularity::ChannelWise, Type::U8); + SUCCEED(); +} + +TEST(QuantizedModelVerifierTest, Sqrt_wrong_granularity_NEG) +{ + TEST_WITH_WRONG_GRANULARITY(SqrtTestGraph, Type::U8, Granularity::LayerWise); + TEST_WITH_WRONG_GRANULARITY(SqrtTestGraph, Type::U8, Granularity::ChannelWise); + TEST_WITH_WRONG_GRANULARITY(SqrtTestGraph, Type::S16, Granularity::ChannelWise); + SUCCEED(); +} + +TEST(QuantizedModelVerifierTest, Elu) +{ + TEST_WITH_GRAPH(EluTestGraph, Type::U8, Granularity::LayerWise); + TEST_WITH_GRAPH(EluTestGraph, Type::U8, Granularity::ChannelWise); + TEST_WITH_GRAPH(EluTestGraph, Type::S16, Granularity::ChannelWise); + SUCCEED(); +} + +TEST(QuantizedModelVerifierTest, Elu_wrong_type_NEG) +{ + TEST_WITH_WRONG_TYPE(EluTestGraph, Type::U8, Granularity::LayerWise, Type::S16); + TEST_WITH_WRONG_TYPE(EluTestGraph, Type::U8, Granularity::ChannelWise, Type::S16); + TEST_WITH_WRONG_TYPE(EluTestGraph, Type::S16, Granularity::ChannelWise, Type::U8); + SUCCEED(); +} + +TEST(QuantizedModelVerifierTest, Elu_wrong_granularity_NEG) +{ + TEST_WITH_WRONG_GRANULARITY(EluTestGraph, Type::U8, Granularity::LayerWise); + TEST_WITH_WRONG_GRANULARITY(EluTestGraph, Type::U8, Granularity::ChannelWise); + TEST_WITH_WRONG_GRANULARITY(EluTestGraph, Type::S16, Granularity::ChannelWise); + SUCCEED(); +} + +TEST(QuantizedModelVerifierTest, Pow) +{ + TEST_WITH_GRAPH(PowTestGraph, Type::U8, Granularity::LayerWise); + TEST_WITH_GRAPH(PowTestGraph, Type::U8, Granularity::ChannelWise); + TEST_WITH_GRAPH(PowTestGraph, Type::S16, Granularity::ChannelWise); + SUCCEED(); +} + +TEST(QuantizedModelVerifierTest, Pow_wrong_type_NEG) +{ + TEST_WITH_WRONG_TYPE(PowTestGraph, Type::U8, Granularity::LayerWise, Type::S16); + TEST_WITH_WRONG_TYPE(PowTestGraph, Type::U8, Granularity::ChannelWise, Type::S16); + TEST_WITH_WRONG_TYPE(PowTestGraph, Type::S16, Granularity::ChannelWise, Type::U8); + SUCCEED(); +} + +TEST(QuantizedModelVerifierTest, Pow_wrong_granularity_NEG) +{ + TEST_WITH_WRONG_GRANULARITY_TARGET(PowTestGraph, Type::U8, Granularity::LayerWise, g.x()); + TEST_WITH_WRONG_GRANULARITY_TARGET(PowTestGraph, Type::U8, Granularity::ChannelWise, g.x()); + TEST_WITH_WRONG_GRANULARITY_TARGET(PowTestGraph, Type::S16, Granularity::ChannelWise, g.x()); + + TEST_WITH_WRONG_GRANULARITY_TARGET(PowTestGraph, Type::U8, Granularity::LayerWise, g.y()); + TEST_WITH_WRONG_GRANULARITY_TARGET(PowTestGraph, Type::U8, Granularity::ChannelWise, g.y()); + TEST_WITH_WRONG_GRANULARITY_TARGET(PowTestGraph, Type::S16, Granularity::ChannelWise, g.y()); + SUCCEED(); +} + +TEST(QuantizedModelVerifierTest, ResizeBilinear) +{ + TEST_WITH_GRAPH(ResizeBilinearTestGraph, Type::U8, Granularity::LayerWise); + TEST_WITH_GRAPH(ResizeBilinearTestGraph, Type::U8, Granularity::ChannelWise); + TEST_WITH_GRAPH(ResizeBilinearTestGraph, Type::S16, Granularity::ChannelWise); + SUCCEED(); +} + +TEST(QuantizedModelVerifierTest, ResizeBilinear_wrong_type_NEG) +{ + TEST_WITH_WRONG_TYPE(ResizeBilinearTestGraph, Type::U8, Granularity::LayerWise, Type::S16); + TEST_WITH_WRONG_TYPE(ResizeBilinearTestGraph, Type::U8, Granularity::ChannelWise, Type::S16); + TEST_WITH_WRONG_TYPE(ResizeBilinearTestGraph, Type::S16, Granularity::ChannelWise, Type::U8); + SUCCEED(); +} + +TEST(QuantizedModelVerifierTest, ResizeBilinear_wrong_granularity_NEG) +{ + TEST_WITH_WRONG_GRANULARITY(ResizeBilinearTestGraph, Type::U8, Granularity::LayerWise); + TEST_WITH_WRONG_GRANULARITY(ResizeBilinearTestGraph, Type::U8, Granularity::ChannelWise); + TEST_WITH_WRONG_GRANULARITY(ResizeBilinearTestGraph, Type::S16, Granularity::ChannelWise); + SUCCEED(); +} + +#undef TEST_WITH_GRAPH +#undef TEST_WITH_WRONG_TYPE +#undef TEST_WITH_WRONG_GRANULARITY diff --git a/compiler/luci/pass/src/RemoveRedundantReshape.cpp b/compiler/luci/pass/src/RemoveRedundantReshape.cpp new file mode 100644 index 000000000..2f0b22ae6 --- /dev/null +++ b/compiler/luci/pass/src/RemoveRedundantReshape.cpp @@ -0,0 +1,72 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Pass/RemoveRedundantReshapePass.h" + +#include <luci/IR/CircleNodes.h> + +namespace +{ + +bool remove_redundant_reshape(luci::CircleReshape *node) +{ + auto pred_node = dynamic_cast<luci::CircleReshape *>(node->tensor()); + if (pred_node == nullptr) + return false; + + node->tensor(pred_node->tensor()); + return true; +} + +} // namespace + +namespace luci +{ + +/** + * BEFORE + * + * [CircleNode] + * | + * [CircleReshape_1] + * | + * [CircleReshape_2] + * | + * [CircleNode] + * + * AFTER + * + * [CircleNode] + * / \ + * [CircleReshape_1] [CircleReshape_2] + * | + * [CircleNode] + **/ +bool RemoveRedundantReshapePass::run(loco::Graph *g) +{ + bool changed = false; + for (auto node : loco::active_nodes(loco::output_nodes(g))) + { + if (auto reshape_node = dynamic_cast<luci::CircleReshape *>(node)) + { + if (remove_redundant_reshape(reshape_node)) + changed = true; + } + } + return changed; +} + +} // namespace luci diff --git a/compiler/luci/pass/src/RemoveRedundantReshape.test.cpp b/compiler/luci/pass/src/RemoveRedundantReshape.test.cpp new file mode 100644 index 000000000..617840f3a --- /dev/null +++ b/compiler/luci/pass/src/RemoveRedundantReshape.test.cpp @@ -0,0 +1,110 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "luci/Pass/RemoveRedundantReshapePass.h" + +#include <luci/IR/CircleNodes.h> + +#include <gtest/gtest.h> + +namespace +{ + +class RemoveRedundantReshape : public ::testing::Test +{ +public: + RemoveRedundantReshape() {} + + void createReshapeConst(luci::CircleReshape *target, const std::vector<int32_t> shape) + { + auto shape_const = g.nodes()->create<luci::CircleConst>(); + shape_const->dtype(loco::DataType::S32); + shape_const->size<loco::DataType::S32>(shape.size()); + shape_const->shape_status(luci::ShapeStatus::VALID); + shape_const->rank(1); + shape_const->dim(0).set(shape.size()); + for (int32_t i = 0; i < shape.size(); i++) + { + shape_const->at<loco::DataType::S32>(i) = shape.at(i); + } + shape_const->name("shape_const"); + target->shape(shape_const); + } + + void buildGraph(const std::initializer_list<uint32_t> base_shape, + const std::vector<int32_t> first_shape, const std::vector<int32_t> second_shape) + { + // Input Create. + input = g.nodes()->create<luci::CircleInput>(); + auto graph_input = g.inputs()->create(); + input->index(graph_input->index()); + input->shape_status(luci::ShapeStatus::VALID); + input->rank(base_shape.size()); + input->shape(base_shape); + input->name("input"); + + // Create first reshape. + first_reshape = g.nodes()->create<luci::CircleReshape>(); + first_reshape->tensor(input); + first_reshape->name("Reshape"); + createReshapeConst(first_reshape, first_shape); + + // Create second reshape. + second_reshape = g.nodes()->create<luci::CircleReshape>(); + second_reshape->tensor(first_reshape); + second_reshape->name("second_reshape"); + createReshapeConst(second_reshape, second_shape); + + // Output Connect. + output = g.nodes()->create<luci::CircleOutput>(); + output->from(second_reshape); + output->name("output"); + auto graph_output = g.outputs()->create(); + output->index(graph_output->index()); + } + +public: + loco::Graph g; + luci::CircleInput *input = nullptr; + luci::CircleReshape *first_reshape = nullptr; + luci::CircleReshape *second_reshape = nullptr; + luci::CircleOutput *output = nullptr; +}; + +} // namespace + +TEST(RemoveRedundantReshapePassTest, name) +{ + luci::RemoveRedundantReshapePass pass; + auto const name = pass.name(); + ASSERT_NE(nullptr, name); +} + +TEST_F(RemoveRedundantReshape, simple_case) +{ + buildGraph({4, 6}, {-1, 4, 6}, {1, -1, 2, 3}); + luci::RemoveRedundantReshapePass pass; + while (pass.run(&g)) + ; + int count = 0; + for (auto node : loco::active_nodes(loco::output_nodes(&g))) + { + if (auto reshape = dynamic_cast<luci::CircleReshape *>(node)) + { + count++; + } + } + ASSERT_EQ(1, count); +} diff --git a/compiler/luci/pass/src/RemoveRedundantTranspose.test.cpp b/compiler/luci/pass/src/RemoveRedundantTranspose.test.cpp deleted file mode 100644 index db608b674..000000000 --- a/compiler/luci/pass/src/RemoveRedundantTranspose.test.cpp +++ /dev/null @@ -1,156 +0,0 @@ -/* - * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "luci/Pass/RemoveRedundantTransposePass.h" - -#include <luci/IR/CircleNodes.h> - -#include <vector> - -#include <gtest/gtest.h> - -namespace -{ - -void setValue(luci::CircleConst *node, const std::vector<int> &v) -{ - node->dtype(loco::DataType::S32); - node->size<loco::DataType::S32>(v.size()); - node->rank(1); - node->dim(0).set(v.size()); - for (int i = 0; i < v.size(); ++i) - { - node->at<loco::DataType::S32>(i) = v[i]; - } -} - -/** - * Type1 - * BEFORE - * | - * [CircleNode] [CircleConst] - * \ / - * [CircleTranspose] [CircleConst] - * \ / - * [CircleTranspose] - * | - * - * AFTER - * | - * [CircleNode] - * | Remove Both - * - * -------------------------------------------- - * - * Type2 - * BEFORE - * | - * [CircleNode] [CircleConst] - * \ / - * [CircleTranspose] [CircleConst] - * \ / - * [CircleTranspose] - * | - * - * AFTER - * | | - * [CircleNode] [CircleConst] - * \ / - * [CircleTranspose] - * | - * - */ -void create_redundunt_transpose(loco::Graph *g, const std::vector<int32_t> &perm1, - const std::vector<int32_t> &perm2) -{ - assert(g); - - auto input = g->nodes()->create<luci::CircleInput>(); - auto graph_input = g->inputs()->create(); - input->index(graph_input->index()); - - // Create perm1 - auto perm1_node = g->nodes()->create<luci::CircleConst>(); - setValue(perm1_node, perm1); - - auto transpose1 = g->nodes()->create<luci::CircleTranspose>(); - transpose1->dtype(loco::DataType::FLOAT32); - transpose1->a(input); - transpose1->perm(perm1_node); - - // Create perm2 - auto perm2_node = g->nodes()->create<luci::CircleConst>(); - setValue(perm2_node, perm2); - - auto transpose2 = g->nodes()->create<luci::CircleTranspose>(); - transpose2->dtype(loco::DataType::FLOAT32); - transpose2->a(transpose1); - transpose2->perm(perm2_node); - - // Output - auto output = g->nodes()->create<luci::CircleOutput>(); - output->from(transpose2); - auto graph_output = g->outputs()->create(); - output->index(graph_output->index()); -} - -} // namespace - -TEST(RemoveRedundantTransposePass, remove_consecutive_transpose_function_type1) -{ - auto graph = loco::make_graph(); - create_redundunt_transpose(graph.get(), {1, 0, 2, 3}, {1, 0, 2, 3}); - - luci::RemoveRedundantTransposePass pass; - while (pass.run(graph.get())) - ; - luci::CircleTranspose *transpose_node = nullptr; - for (auto node : loco::active_nodes(loco::output_nodes(graph.get()))) - { - auto trans = dynamic_cast<luci::CircleTranspose *>(node); - if (not trans) - continue; - transpose_node = trans; - break; - } - // No transpose node is in graph. - ASSERT_EQ(nullptr, transpose_node); -} - -TEST(RemoveRedundantTransposePass, remove_consecutive_transpose_function_type2) -{ - auto graph = loco::make_graph(); - create_redundunt_transpose(graph.get(), {0, 1, 3, 2}, {1, 0, 2, 3}); - - luci::RemoveRedundantTransposePass pass; - while (pass.run(graph.get())) - ; - luci::CircleTranspose *transpose_node = nullptr; - for (auto node : loco::active_nodes(loco::output_nodes(graph.get()))) - { - auto trans = dynamic_cast<luci::CircleTranspose *>(node); - if (not trans) - continue; - transpose_node = trans; - break; - } - // Just one transpose node, with updated perm constant. - ASSERT_NE(nullptr, transpose_node); - auto perm = loco::must_cast<luci::CircleConst *>(transpose_node->perm()); - ASSERT_EQ(1, perm->at<loco::DataType::S32>(0)); - ASSERT_EQ(0, perm->at<loco::DataType::S32>(1)); - ASSERT_EQ(3, perm->at<loco::DataType::S32>(2)); - ASSERT_EQ(2, perm->at<loco::DataType::S32>(3)); -} diff --git a/compiler/luci/pass/src/RemoveRedundantTranspose.cpp b/compiler/luci/pass/src/RemoveRedundantTransposePass.cpp index 33cb76520..71c51ecda 100644 --- a/compiler/luci/pass/src/RemoveRedundantTranspose.cpp +++ b/compiler/luci/pass/src/RemoveRedundantTransposePass.cpp @@ -17,6 +17,7 @@ #include "luci/Pass/RemoveRedundantTransposePass.h" #include <luci/IR/CircleNodes.h> +#include <luci/Profile/CircleNodeOrigin.h> namespace { @@ -35,47 +36,54 @@ bool check_perm(const luci::CircleConst *first_perm, const luci::CircleConst *se return true; } -bool remove_consecutive_transpose_function(luci::CircleNode *node) +bool remove_consecutive_transpose_function(luci::CircleTranspose *target_node) { - auto target_node = dynamic_cast<luci::CircleTranspose *>(node); - if (target_node == nullptr) - return false; auto pred_node = dynamic_cast<luci::CircleTranspose *>(target_node->a()); if (pred_node == nullptr) return false; - if (loco::succs(pred_node).size() != 1) - return false; - auto pred_perm = dynamic_cast<luci::CircleConst *>(target_node->perm()); - if (pred_perm == nullptr) + auto target_perm = dynamic_cast<luci::CircleConst *>(target_node->perm()); + if (target_perm == nullptr) return false; - auto main_perm = dynamic_cast<luci::CircleConst *>(pred_node->perm()); - if (main_perm == nullptr) + auto pred_perm = dynamic_cast<luci::CircleConst *>(pred_node->perm()); + if (pred_perm == nullptr) return false; auto main_node = loco::must_cast<luci::CircleNode *>(pred_node->a()); - if (check_perm(pred_perm, main_perm)) + if (check_perm(target_perm, pred_perm)) { - replace(node).with(main_node); + replace(target_node).with(main_node); } else { - auto g = main_perm->graph(); + auto name = target_node->name(); + assert(name.length() > 0); + + auto g = pred_perm->graph(); auto new_const_node = g->nodes()->create<luci::CircleConst>(); new_const_node->dtype(loco::DataType::S32); new_const_node->rank(1); - new_const_node->dim(0) = main_perm->dim(0); - new_const_node->size<loco::DataType::S32>(main_perm->dim(0).value()); + new_const_node->dim(0) = pred_perm->dim(0); + new_const_node->size<loco::DataType::S32>(pred_perm->dim(0).value()); new_const_node->shape_status(luci::ShapeStatus::VALID); - for (uint32_t i = 0; i < main_perm->size<loco::DataType::S32>(); i++) + for (uint32_t i = 0; i < pred_perm->size<loco::DataType::S32>(); i++) { new_const_node->at<loco::DataType::S32>(i) = - pred_perm->at<loco::DataType::S32>(main_perm->at<loco::DataType::S32>(i)); + target_perm->at<loco::DataType::S32>(pred_perm->at<loco::DataType::S32>(i)); } - pred_node->perm(new_const_node); - replace(node).with(pred_node); + new_const_node->name(name + "/Transpose/perm"); + + // Create New Transpose Node + auto new_transpose_node = g->nodes()->create<luci::CircleTranspose>(); + new_transpose_node->dtype(target_node->dtype()); + new_transpose_node->a(main_node); + new_transpose_node->perm(new_const_node); + new_transpose_node->name(name + "/Transpose"); + luci::add_origin(new_transpose_node, luci::get_origin(target_node)); + + replace(target_node).with(new_transpose_node); } return true; } @@ -84,41 +92,36 @@ bool remove_consecutive_transpose_function(luci::CircleNode *node) namespace luci { + /** * BEFORE * | * [CircleNode] [CircleConst] - * (main_node) (main_perm) - * \ / + * | (pred_perm) + * \ / * [CircleTranspose] [CircleConst] - * (pred_node) (pred_perm) + * (pred_node) (target_perm) * \ / * [CircleTranspose] * (target_node) * | * * AFTER - * <Optional Case> - * - * | | | - * [CircleNode] [CircleConst] | - * (main_node) (new_const_node) | - * \ / or [CircleNode] - * [CircleTranspose] (main_node) - * (pred_node) | + * | | + * [CircleNode] [CircleConst](new) | + * \ / or [CircleNode] + * [CircleTranspose](new) | * | | - * */ bool RemoveRedundantTransposePass::run(loco::Graph *g) { bool changed = false; for (auto node : loco::active_nodes(loco::output_nodes(g))) { - auto circle_node = loco::must_cast<luci::CircleNode *>(node); - if (remove_consecutive_transpose_function(circle_node)) + if (auto transpose = dynamic_cast<luci::CircleTranspose *>(node)) { - changed = true; - break; + if (remove_consecutive_transpose_function(transpose)) + changed = true; } } return changed; diff --git a/compiler/luci/pass/src/RemoveRedundantTransposePass.test.cpp b/compiler/luci/pass/src/RemoveRedundantTransposePass.test.cpp new file mode 100644 index 000000000..e80623499 --- /dev/null +++ b/compiler/luci/pass/src/RemoveRedundantTransposePass.test.cpp @@ -0,0 +1,321 @@ +/* + * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "luci/Pass/RemoveRedundantTransposePass.h" + +#include <luci/IR/CircleNodes.h> + +#include <vector> + +#include <gtest/gtest.h> + +namespace +{ + +void setValue(luci::CircleConst *node, const std::vector<int> &v) +{ + node->dtype(loco::DataType::S32); + node->size<loco::DataType::S32>(v.size()); + node->rank(1); + node->dim(0).set(v.size()); + for (int i = 0; i < v.size(); ++i) + { + node->at<loco::DataType::S32>(i) = v[i]; + } +} + +/** + * Remove for consecutive Transpose + * + * Type1: Remove both Transpose + * BEFORE + * | + * [CircleNode] [CircleConst] + * \ / + * [CircleTranspose] [CircleConst] + * \ / + * [CircleTranspose] + * | + * + * AFTER + * | + * [CircleNode] + * | + * + * -------------------------------------------- + * + * Type2: Merge to one Transpose + * BEFORE + * | + * [CircleNode] [CircleConst] + * \ / + * [CircleTranspose] [CircleConst] + * \ / + * [CircleTranspose] + * | + * + * AFTER + * | + * [CircleNode] [CircleConst] + * \ / + * [CircleTranspose] + * | + * + */ +void create_redundunt_transpose(loco::Graph *g, const std::vector<int32_t> &perm1, + const std::vector<int32_t> &perm2) +{ + assert(g); + + auto input = g->nodes()->create<luci::CircleInput>(); + auto graph_input = g->inputs()->create(); + input->index(graph_input->index()); + input->name("input"); + + // Create perm1 + auto perm1_node = g->nodes()->create<luci::CircleConst>(); + setValue(perm1_node, perm1); + perm1_node->name("perm1_node"); + + auto transpose1 = g->nodes()->create<luci::CircleTranspose>(); + transpose1->dtype(loco::DataType::FLOAT32); + transpose1->a(input); + transpose1->perm(perm1_node); + transpose1->name("transpose1"); + + // Create perm2 + auto perm2_node = g->nodes()->create<luci::CircleConst>(); + setValue(perm2_node, perm2); + perm2_node->name("perm2_node"); + + auto transpose2 = g->nodes()->create<luci::CircleTranspose>(); + transpose2->dtype(loco::DataType::FLOAT32); + transpose2->a(transpose1); + transpose2->perm(perm2_node); + transpose2->name("transpose2"); + + // Output + auto output = g->nodes()->create<luci::CircleOutput>(); + output->from(transpose2); + auto graph_output = g->outputs()->create(); + output->index(graph_output->index()); + output->name("output"); +} + +/** + * Remove for consecutive Transposes with branching + * + * BEFORE + * | + * [CircleNode] [CircleConst] + * \ / + * [CircleConst] [CircleTranspose] [CircleConst] + * \ / \ / + * [CircleTranspose] [CircleTranspose] + * | | + * [CircleNode] [CircleNode] + * | | + * + * AFTER + * Type 1: Remove all Transpose + * | + * [CircleNode] + * / \ + * [CircleNode] [CircleNode] + * | | + * + * Type 2: Remove both for one side and create new for another side + * | + * [CircleNode] [CircleConst](new) + * / \ / + * / [CircleTranspose](new) + * | | + * [CircleNode] [CircleNode] + * | | + */ +void create_redundunt_transpose_with_branch(loco::Graph *g, const std::vector<int32_t> &perm1, + const std::vector<int32_t> &perm2, + const std::vector<int32_t> &perm3) +{ + assert(g); + + auto input = g->nodes()->create<luci::CircleInput>(); + auto graph_input = g->inputs()->create(); + input->dtype(loco::DataType::FLOAT32); + input->index(graph_input->index()); + input->name("input"); + graph_input->dtype(loco::DataType::FLOAT32); + + graph_input->shape({4, 4, 4, 4}); + input->shape({4, 4, 4, 4}); + + // Create perm1 + auto perm1_node = g->nodes()->create<luci::CircleConst>(); + setValue(perm1_node, perm1); + perm1_node->name("perm1_node"); + + auto transpose1 = g->nodes()->create<luci::CircleTranspose>(); + transpose1->dtype(loco::DataType::FLOAT32); + transpose1->a(input); + transpose1->perm(perm1_node); + transpose1->name("transpose1"); + + // Create perm2 + auto perm2_node = g->nodes()->create<luci::CircleConst>(); + setValue(perm2_node, perm2); + perm2_node->name("perm2_node"); + + auto transpose2 = g->nodes()->create<luci::CircleTranspose>(); + transpose2->dtype(loco::DataType::FLOAT32); + transpose2->a(transpose1); + transpose2->perm(perm2_node); + transpose2->name("transpose2"); + + // create perm3 + auto perm3_node = g->nodes()->create<luci::CircleConst>(); + setValue(perm3_node, perm3); + perm3_node->name("perm3_node"); + + auto transpose3 = g->nodes()->create<luci::CircleTranspose>(); + transpose3->dtype(loco::DataType::FLOAT32); + transpose3->a(transpose1); + transpose3->perm(perm3_node); + transpose3->name("transpose3"); + + // Output + auto output1 = g->nodes()->create<luci::CircleOutput>(); + output1->from(transpose2); + output1->name("output1"); + auto output2 = g->nodes()->create<luci::CircleOutput>(); + output2->from(transpose3); + output2->name("output2"); + auto graph_output1 = g->outputs()->create(); + output1->index(graph_output1->index()); + auto graph_output2 = g->outputs()->create(); + output2->index(graph_output2->index()); + output1->dtype(loco::DataType::FLOAT32); + output2->dtype(loco::DataType::FLOAT32); + graph_output1->dtype(loco::DataType::FLOAT32); + graph_output2->dtype(loco::DataType::FLOAT32); + output1->shape({4, 4, 4, 4}); + output2->shape({4, 4, 4, 4}); + graph_output1->shape({4, 4, 4, 4}); + graph_output2->shape({4, 4, 4, 4}); +} + +} // namespace + +TEST(RemoveRedundantTransposePassTest, name) +{ + luci::RemoveRedundantTransposePass pass; + auto const name = pass.name(); + ASSERT_NE(nullptr, name); +} + +TEST(RemoveRedundantTransposePass, remove_consecutive_transpose_function_type1) +{ + auto graph = loco::make_graph(); + create_redundunt_transpose(graph.get(), {1, 0, 2, 3}, {1, 0, 2, 3}); + + luci::RemoveRedundantTransposePass pass; + while (pass.run(graph.get())) + ; + luci::CircleTranspose *transpose_node = nullptr; + for (auto node : loco::active_nodes(loco::output_nodes(graph.get()))) + { + auto trans = dynamic_cast<luci::CircleTranspose *>(node); + if (not trans) + continue; + transpose_node = trans; + break; + } + // No transpose node is in graph. + ASSERT_EQ(nullptr, transpose_node); +} + +TEST(RemoveRedundantTransposePass, remove_consecutive_transpose_function_type2) +{ + auto graph = loco::make_graph(); + create_redundunt_transpose(graph.get(), {0, 1, 3, 2}, {1, 0, 2, 3}); + + luci::RemoveRedundantTransposePass pass; + while (pass.run(graph.get())) + ; + luci::CircleTranspose *transpose_node = nullptr; + for (auto node : loco::active_nodes(loco::output_nodes(graph.get()))) + { + auto trans = dynamic_cast<luci::CircleTranspose *>(node); + if (not trans) + continue; + transpose_node = trans; + break; + } + // Just one transpose node, with updated perm constant. + ASSERT_NE(nullptr, transpose_node); + auto perm = loco::must_cast<luci::CircleConst *>(transpose_node->perm()); + ASSERT_EQ(1, perm->at<loco::DataType::S32>(0)); + ASSERT_EQ(0, perm->at<loco::DataType::S32>(1)); + ASSERT_EQ(3, perm->at<loco::DataType::S32>(2)); + ASSERT_EQ(2, perm->at<loco::DataType::S32>(3)); +} + +/** + * @brief Test case that first transpose output become input of operations more than one. + */ +TEST(RemoveRedundantTransposePass, remove_consecutive_transpose_function_with_branch_remove_case) +{ + auto graph = loco::make_graph(); + create_redundunt_transpose_with_branch(graph.get(), {1, 0, 2, 3}, {1, 0, 2, 3}, {1, 0, 2, 3}); + + luci::RemoveRedundantTransposePass pass; + while (pass.run(graph.get())) + ; + luci::CircleTranspose *transpose_node = nullptr; + for (auto node : loco::active_nodes(loco::output_nodes(graph.get()))) + { + auto trans = dynamic_cast<luci::CircleTranspose *>(node); + if (not trans) + continue; + transpose_node = trans; + break; + } + // No transpose node is in graph. + ASSERT_EQ(nullptr, transpose_node); +} + +TEST(RemoveRedundantTransposePass, remove_consecutive_transpose_function_with_branch_leave_one) +{ + auto graph = loco::make_graph(); + create_redundunt_transpose_with_branch(graph.get(), {1, 0, 2, 3}, {1, 0, 2, 3}, {0, 1, 3, 2}); + + luci::RemoveRedundantTransposePass pass; + while (pass.run(graph.get())) + ; + luci::CircleTranspose *transpose_node = nullptr; + for (auto node : loco::active_nodes(loco::output_nodes(graph.get()))) + { + auto trans = dynamic_cast<luci::CircleTranspose *>(node); + if (not trans) + continue; + transpose_node = trans; + break; + } + ASSERT_NE(nullptr, transpose_node); + auto perm = loco::must_cast<luci::CircleConst *>(transpose_node->perm()); + ASSERT_EQ(1, perm->at<loco::DataType::S32>(0)); + ASSERT_EQ(0, perm->at<loco::DataType::S32>(1)); + ASSERT_EQ(3, perm->at<loco::DataType::S32>(2)); + ASSERT_EQ(2, perm->at<loco::DataType::S32>(3)); +} diff --git a/compiler/luci/pass/src/RemoveUnnecessaryReshapePass.cpp b/compiler/luci/pass/src/RemoveUnnecessaryReshapePass.cpp new file mode 100644 index 000000000..3f0c4ee82 --- /dev/null +++ b/compiler/luci/pass/src/RemoveUnnecessaryReshapePass.cpp @@ -0,0 +1,75 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Pass/RemoveUnnecessaryReshapePass.h" + +#include <luci/IR/CircleNodes.h> + +namespace +{ + +bool remove_no_effect_reshape(luci::CircleNode *node) +{ + auto target_node = dynamic_cast<luci::CircleReshape *>(node); + if (target_node == nullptr) + return false; + + auto new_shape = dynamic_cast<luci::CircleConst *>(target_node->shape()); + if (new_shape == nullptr) + return false; + + // Compare updated shape and input shape. + auto input_node = loco::must_cast<luci::CircleNode *>(target_node->tensor()); + if (input_node->rank() != new_shape->dim(0).value()) + return false; + for (uint32_t i = 0; i < input_node->rank(); i++) + { + // If update_shape is -1, don't care + // TODO check updated shape has value -1 at most one. + if (new_shape->at<loco::DataType::S32>(i) == -1) + continue; + // If input_shape dynamic, can't remove this. + if (!input_node->dim(i).known()) + return false; + // If input_shape and updated shape differ, also can't remove. + if (input_node->dim(i).value() != static_cast<uint32_t>(new_shape->at<loco::DataType::S32>(i))) + return false; + } + + replace(target_node).with(input_node); + return true; +} + +} // namespace + +namespace luci +{ + +bool RemoveUnnecessaryReshapePass::run(loco::Graph *g) +{ + bool changed = false; + for (auto node : loco::active_nodes(loco::output_nodes(g))) + { + auto circle_node = loco::must_cast<luci::CircleNode *>(node); + if (remove_no_effect_reshape(circle_node)) + { + changed = true; + } + } + return changed; +} + +} // namespace luci diff --git a/compiler/luci/pass/src/RemoveUnnecessaryReshapePass.test.cpp b/compiler/luci/pass/src/RemoveUnnecessaryReshapePass.test.cpp new file mode 100644 index 000000000..9d2e758b4 --- /dev/null +++ b/compiler/luci/pass/src/RemoveUnnecessaryReshapePass.test.cpp @@ -0,0 +1,141 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Pass/RemoveUnnecessaryReshapePass.h" + +#include <luci/IR/CircleNodes.h> + +#include <luci/test/TestIOGraph.h> +#include "test/TestFirstNode.h" + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class ReshapeGraphlet +{ +public: + ReshapeGraphlet() = default; + +public: + void init(loco::Graph *g, const ShapeU32 input_shape, bool remove) + { + std::vector<uint32_t> shape_vector{input_shape}; + + auto dim0_val = remove ? shape_vector.size() : 1; + _reshape_shape = g->nodes()->create<luci::CircleConst>(); + _reshape_shape->rank(1); + _reshape_shape->dim(0).set(dim0_val); + _reshape_shape->shape_status(luci::ShapeStatus::VALID); + _reshape_shape->dtype(loco::DataType::S32); + + _reshape_shape->size<loco::DataType::S32>(dim0_val); + for (uint32_t i = 0; i < dim0_val; i++) + { + if (remove) + _reshape_shape->at<loco::DataType::S32>(i) = static_cast<int32_t>(shape_vector.at(i)); + else + _reshape_shape->at<loco::DataType::S32>(i) = -1; + } + _reshape_shape->name("reshape_shape"); + + // Reshape create + auto newshape_rank = remove ? shape_vector.size() : 1; + _reshape = g->nodes()->create<luci::CircleReshape>(); + _reshape->newShape()->rank(newshape_rank); + for (uint32_t i = 0; i < newshape_rank; i++) + { + if (remove) + _reshape->newShape()->dim(i) = static_cast<int32_t>(shape_vector.at(i)); + else + _reshape->newShape()->dim(i) = -1; + } + _reshape->name("reshape"); + } + +protected: + luci::CircleReshape *_reshape = nullptr; + luci::CircleConst *_reshape_shape = nullptr; +}; + +class ReshapeGraph : public TestIOGraph, public ReshapeGraphlet +{ +public: + ReshapeGraph() = default; + +public: + void init(const ShapeU32 shape, bool remove) + { + TestIOGraph::init(shape, shape); + ReshapeGraphlet::init(g(), shape, remove); + + // connect graph + _reshape->tensor(input()); + _reshape->shape(_reshape_shape); + + output()->from(_reshape); + } +}; + +// TODO use ::testing::Test + +} // namespace + +TEST(RemoveUnnecessaryReshapePassTest, name) +{ + luci::RemoveUnnecessaryReshapePass pass; + auto const name = pass.name(); + ASSERT_NE(nullptr, name); +} + +TEST(RemoveUnnecessaryReshapePass, removed) +{ + ReshapeGraph g; + + g.init({1, 2, 3, 4}, true); + + // confirm graph has Reshape + auto reshape_node = luci::test::first_node<luci::CircleReshape>(g.g()); + ASSERT_NE(nullptr, reshape_node); + luci::RemoveUnnecessaryReshapePass pass; + while (pass.run(g.g())) + ; + + // check Reshape is removed + reshape_node = luci::test::first_node<luci::CircleReshape>(g.g()); + ASSERT_EQ(nullptr, reshape_node); +} + +TEST(RemoveUnnecessaryReshapePass, not_removed_NEG) +{ + ReshapeGraph g; + + g.init({1, 2, 3, 4}, false); + + // confirm graph has Reshape + auto reshape_node = luci::test::first_node<luci::CircleReshape>(g.g()); + ASSERT_NE(nullptr, reshape_node); + luci::RemoveUnnecessaryReshapePass pass; + while (pass.run(g.g())) + ; + + // check Reshape is NOT removed + reshape_node = luci::test::first_node<luci::CircleReshape>(g.g()); + ASSERT_NE(nullptr, reshape_node); +} diff --git a/compiler/luci/pass/src/RemoveUnnecessarySlicePass.cpp b/compiler/luci/pass/src/RemoveUnnecessarySlicePass.cpp new file mode 100644 index 000000000..0720813cd --- /dev/null +++ b/compiler/luci/pass/src/RemoveUnnecessarySlicePass.cpp @@ -0,0 +1,111 @@ +/* + * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Pass/RemoveUnnecessarySlicePass.h" + +#include <luci/IR/CircleNodes.h> + +namespace +{ + +/** + * @brief Return value in CircleConst. + * @details Return value in position on CircleConst with int64 format. + * Begin must be larger than or equal to 0. Size must be larger + * than or equal to -1. + */ +int64_t value_from_circle_const(const luci::CircleConst *node, uint32_t idx) +{ + assert(node->rank() == 1 && node->dim(0).value() > idx); + assert(node->dtype() == loco::DataType::S64 || node->dtype() == loco::DataType::S32); + + if (node->dtype() == loco::DataType::S64) + return node->at<loco::DataType::S64>(idx); + return static_cast<int64_t>(node->at<loco::DataType::S32>(idx)); +} + +bool remove_no_effect_slice(luci::CircleNode *node) +{ + auto target_node = dynamic_cast<luci::CircleSlice *>(node); + if (target_node == nullptr) + return false; + + auto begin_const = dynamic_cast<luci::CircleConst *>(target_node->begin()); + if (begin_const == nullptr) + return false; + + auto size_const = dynamic_cast<luci::CircleConst *>(target_node->size()); + if (size_const == nullptr) + return false; + + // Check input output shape. + auto input_node = loco::must_cast<luci::CircleNode *>(target_node->input()); + for (uint32_t i = 0; i < input_node->rank(); i++) + { + if (value_from_circle_const(begin_const, i) != 0) + return false; + + int64_t size_value = value_from_circle_const(size_const, i); + if (size_value == -1) + continue; + if (size_value != static_cast<int64_t>(input_node->dim(i).value())) + return false; + + if (!input_node->dim(i).known()) + return false; + } + replace(target_node).with(input_node); + return true; +} + +} // namespace + +namespace luci +{ +/** + * BEFORE + * + * [CircleNode] + * | + * [CircleSlice] + * | + * [CircleNode] + * + * AFTER + * + * [CircleNode] + * | + * [CircleNode] + * + * Slice OP has no effect if, + * 1. Static Shape : begin_const[idx] is 0 AND size_const[idx] is (-1 OR input_dimension[idx]) + * 2. Dynamic Shape : begin_const[idx] is 0 AND size_const[idx] is -1 + */ +bool RemoveUnnecessarySlicePass::run(loco::Graph *g) +{ + bool changed = false; + for (auto node : loco::active_nodes(loco::output_nodes(g))) + { + auto circle_node = loco::must_cast<luci::CircleNode *>(node); + if (remove_no_effect_slice(circle_node)) + { + changed = true; + } + } + return changed; +} + +} // namespace luci diff --git a/compiler/luci/pass/src/RemoveUnnecessarySlicePass.test.cpp b/compiler/luci/pass/src/RemoveUnnecessarySlicePass.test.cpp new file mode 100644 index 000000000..80921a93a --- /dev/null +++ b/compiler/luci/pass/src/RemoveUnnecessarySlicePass.test.cpp @@ -0,0 +1,134 @@ +/* + * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "luci/Pass/RemoveUnnecessarySlicePass.h" + +#include <luci/IR/CircleNodes.h> + +#include <luci/test/TestIOGraph.h> +#include "test/TestFirstNode.h" + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class SliceGraphlet +{ +public: + SliceGraphlet() = default; + +public: + void init(loco::Graph *g, const ShapeU32 input_shape, bool remove) + { + // Begin Create. + _begin = g->nodes()->create<luci::CircleConst>(); + _begin->rank(1); + _begin->dim(0).set(input_shape.size()); + _begin->shape_status(luci::ShapeStatus::VALID); + _begin->dtype(loco::DataType::S32); + _begin->size<loco::DataType::S32>(input_shape.size()); + for (int i = 0; i < input_shape.size(); ++i) + _begin->at<loco::DataType::S32>(i) = remove ? 0 : 1; + _begin->name("begin"); + + // Size Create. + _size = g->nodes()->create<luci::CircleConst>(); + _size->rank(1); + _size->dim(0).set(input_shape.size()); + _size->shape_status(luci::ShapeStatus::VALID); + _size->dtype(loco::DataType::S32); + _size->size<loco::DataType::S32>(input_shape.size()); + for (int i = 0; i < input_shape.size(); ++i) + _size->at<loco::DataType::S32>(i) = -1; + _size->name("size"); + + // Slice Node create. + _slice = g->nodes()->create<luci::CircleSlice>(); + _slice->dtype(loco::DataType::S32); + _slice->name("slice"); + } + +protected: + luci::CircleSlice *_slice = nullptr; + luci::CircleConst *_begin = nullptr; + luci::CircleConst *_size = nullptr; +}; + +class SliceGraph : public TestIOGraph, public SliceGraphlet +{ +public: + SliceGraph() = default; + +public: + void init(const ShapeU32 shape, bool remove) + { + TestIOGraph::init(shape, shape); + SliceGraphlet::init(g(), shape, remove); + + _slice->input(input()); + _slice->begin(_begin); + _slice->size(_size); + + output()->from(_slice); + } +}; + +} // namespace + +TEST(RemoveUnnecessarySlicePass, name) +{ + luci::RemoveUnnecessarySlicePass pass; + auto const name = pass.name(); + ASSERT_NE(nullptr, name); +} + +TEST(RemoveUnnecessarySlicePass, removed) +{ + SliceGraph g; + + g.init({2, 4, 2, 3}, true); + + // confirm graph has Slice + auto slice_node = luci::test::first_node<luci::CircleSlice>(g.g()); + ASSERT_NE(nullptr, slice_node); + luci::RemoveUnnecessarySlicePass pass; + while (pass.run(g.g())) + ; + + // check Slice is removed + slice_node = luci::test::first_node<luci::CircleSlice>(g.g()); + ASSERT_EQ(nullptr, slice_node); +} + +TEST(RemoveUnnecessarySlicePass, not_removed_NEG) +{ + SliceGraph g; + + g.init({2, 4, 2, 3}, false); + + // confirm graph has Slice + auto slice_node = luci::test::first_node<luci::CircleSlice>(g.g()); + ASSERT_NE(nullptr, slice_node); + luci::RemoveUnnecessarySlicePass pass; + while (pass.run(g.g())) + ; + + // check Slice is NOT removed + slice_node = luci::test::first_node<luci::CircleSlice>(g.g()); + ASSERT_NE(nullptr, slice_node); +} diff --git a/compiler/luci/pass/src/ShapeSignatureInferencePass.cpp b/compiler/luci/pass/src/RemoveUnnecessarySplitPass.cpp index 115b77a96..3243f6213 100644 --- a/compiler/luci/pass/src/ShapeSignatureInferencePass.cpp +++ b/compiler/luci/pass/src/RemoveUnnecessarySplitPass.cpp @@ -14,49 +14,50 @@ * limitations under the License. */ -#include "luci/Pass/ShapeSignatureInferencePass.h" +#include "luci/Pass/RemoveUnnecessarySplitPass.h" -#include <luci/IR/CircleShapeSignature.h> -#include <luci/Service/CircleShapeSignatureInference.h> +#include <luci/IR/CircleNodes.h> -#include <loco.h> - -namespace luci +namespace { - -bool ShapeSignatureInferencePass::run(luci::Module *m) +bool remove_unnecessary_split(luci::CircleNode *node) { - bool changed = false; + auto target_node = dynamic_cast<luci::CircleSplitOut *>(node); + if (target_node == nullptr) + return false; + + auto split_node = dynamic_cast<luci::CircleSplit *>(target_node->input()); + if (split_node == nullptr) + return false; - for (size_t g = 0; g < m->size(); ++g) + if (loco::succs(split_node).size() != 1) + return false; + + if (split_node->num_split() == 1) { - if (run(m->graph(g))) - changed = true; + auto input_node = loco::must_cast<luci::CircleNode *>(split_node->input()); + replace(target_node).with(input_node); + return true; } - - return changed; + return false; } -bool ShapeSignatureInferencePass::run(loco::Graph *g) +} // namespace + +namespace luci { - luci::ssinf::Rule signature_inference_rule; - bool changed = false; - for (auto node : loco::postorder_traversal(loco::output_nodes(g))) +bool RemoveUnnecessarySplitPass::run(loco::Graph *g) +{ + bool changed = false; + for (auto node : loco::active_nodes(loco::output_nodes(g))) { - luci::ShapeSignature shape_signature; - auto circle_node = loco::must_cast<luci::CircleNode *>(node); - if (signature_inference_rule.infer(circle_node, shape_signature)) + if (remove_unnecessary_split(circle_node)) { - if (!(circle_node->shape_signature() == shape_signature)) - { - circle_node->shape_signature(shape_signature); - changed = true; - } + changed = true; } } - return changed; } diff --git a/compiler/luci/pass/src/RemoveUnnecessarySplitPass.test.cpp b/compiler/luci/pass/src/RemoveUnnecessarySplitPass.test.cpp new file mode 100644 index 000000000..f292b5357 --- /dev/null +++ b/compiler/luci/pass/src/RemoveUnnecessarySplitPass.test.cpp @@ -0,0 +1,149 @@ +/* + * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Pass/RemoveUnnecessarySplitPass.h" + +#include <luci/IR/CircleNodes.h> + +#include <luci/test/TestIOGraph.h> +#include "test/TestFirstNode.h" + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class SplitGraphlet +{ +public: + SplitGraphlet() = default; + +public: + void init(loco::Graph *g, uint32_t nout) + { + assert(nout == 1 || nout == 2); + + _dim = g->nodes()->create<luci::CircleConst>(); + set_shape_vector(_dim, {0}); + _dim->name("dim"); + + _split = g->nodes()->create<luci::CircleSplit>(); + _split->num_split(nout); + _split->name("split"); + + _split_out_0 = g->nodes()->create<luci::CircleSplitOut>(); + _split_out_0->index(0); + _split_out_0->name("split_out_0"); + + if (nout == 2) + { + _split_out_1 = g->nodes()->create<luci::CircleSplitOut>(); + _split_out_1->index(1); + _split_out_1->name("split_out_1"); + } + } + +protected: + luci::CircleSplit *_split = nullptr; + luci::CircleConst *_dim = nullptr; + luci::CircleSplitOut *_split_out_0 = nullptr; + luci::CircleSplitOut *_split_out_1 = nullptr; +}; + +class SplitOneGraph : public TestIGraphlet, public TestOGraphlet, public SplitGraphlet +{ +public: + SplitOneGraph() = default; + +public: + void init() + { + TestIGraphlet::init(g(), {1}); + TestOGraphlet::init(g(), {1}); + SplitGraphlet::init(g(), 1); + + _split->input(input()); + _split->split_dim(_dim); + _split_out_0->input(_split); + + output()->from(_split_out_0); + } +}; + +class SplitTwoGraph : public TestIGraphlet, public TestOsGraphlet<2>, public SplitGraphlet +{ +public: + SplitTwoGraph() = default; + +public: + void init() + { + TestIGraphlet::init(g(), {1}); + TestOsGraphlet<2>::init(g(), {{1}, {1}}); + SplitGraphlet::init(g(), 2); + + _split->input(input()); + _split->split_dim(_dim); + _split_out_0->input(_split); + _split_out_1->input(_split); + + output(0)->from(_split_out_0); + output(1)->from(_split_out_1); + } +}; + +// TODO use ::testing::Test + +} // namespace + +TEST(RemoveUnnecessarySplitPass, name) +{ + luci::RemoveUnnecessarySplitPass pass; + auto const name = pass.name(); + ASSERT_NE(nullptr, name); +} + +TEST(RemoveUnnecessarySplitPass, create_unnecessary_split) +{ + SplitOneGraph g; + + g.init(); + + luci::RemoveUnnecessarySplitPass pass; + while (pass.run(g.g())) + ; + + auto split_node = luci::test::first_node<luci::CircleSplit>(g.g()); + // No Split node is in graph. + ASSERT_EQ(nullptr, split_node); +} + +TEST(RemoveUnnecessarySplitPass, create_unnecessary_split_NEG) +{ + SplitTwoGraph g; + + g.init(); + + luci::RemoveUnnecessarySplitPass pass; + while (pass.run(g.g())) + ; + + auto split_node = luci::test::first_node<luci::CircleSplit>(g.g()); + // Split node is in graph. + ASSERT_NE(nullptr, split_node); +} diff --git a/compiler/luci/pass/src/RemoveUnnecessaryStridedSlicePass.cpp b/compiler/luci/pass/src/RemoveUnnecessaryStridedSlicePass.cpp new file mode 100644 index 000000000..22b1aa64f --- /dev/null +++ b/compiler/luci/pass/src/RemoveUnnecessaryStridedSlicePass.cpp @@ -0,0 +1,124 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Pass/RemoveUnnecessaryStridedSlicePass.h" + +#include <luci/IR/CircleNodes.h> + +namespace +{ + +/** + * @brief Return value in CircleConst. + * @details Return value in position on CircleConst with int64 format. + */ +int64_t value_from_circle_const(const luci::CircleConst *node, uint32_t idx) +{ + assert(node->rank() == 1 && node->dim(0).value() > idx); + assert(node->dtype() == loco::DataType::S64 || node->dtype() == loco::DataType::S32); + + if (node->dtype() == loco::DataType::S64) + return node->at<loco::DataType::S64>(idx); + return static_cast<int64_t>(node->at<loco::DataType::S32>(idx)); +} + +bool remove_no_effect_strided_slice(luci::CircleStridedSlice *target_node) +{ + auto begin_const = dynamic_cast<luci::CircleConst *>(target_node->begin()); + if (begin_const == nullptr) + return false; + + auto strides_const = dynamic_cast<luci::CircleConst *>(target_node->strides()); + if (strides_const == nullptr) + return false; + + auto end_const = dynamic_cast<luci::CircleConst *>(target_node->end()); + if (end_const == nullptr) + return false; + + auto input_node = loco::must_cast<luci::CircleNode *>(target_node->input()); + for (uint32_t i = 0; i < input_node->rank(); i++) + { + if (value_from_circle_const(begin_const, i) != 0) + return false; + + int64_t strides_value = value_from_circle_const(strides_const, i); + if (strides_value != 1) + return false; + + int64_t end_value = value_from_circle_const(end_const, i); + if (end_value == -1) + continue; + + if (end_value != input_node->dim(i).value()) + return false; + + if (!input_node->dim(i).known()) + return false; + } + + /** + * We check additional attributes on zero after shapes + * for skipping wrong StridedSlice operator. + */ + if (target_node->new_axis_mask() != 0 || target_node->shrink_axis_mask() != 0) + return false; + + replace(target_node).with(input_node); + return true; +} + +} // namespace + +namespace luci +{ +/** + * BEFORE + * + * [CircleNode] + * | + * [CircleStridedSlice] + * | + * [CircleNode] + * + * AFTER + * + * [CircleNode] + * | + * [CircleNode] [CircleStridedSlice] + * + * StridedSlice OP has no effect if, + * 1. Static Shape : begin_const[idx] is 0 AND strides_const[idx] is (not 1 OR + * input_dimension[idx]) + * 2. Dynamic Shape : begin_const[idx] is 0 AND strides_const[idx] is not 1 + * + * StridedSlice OP has effect if, + * 1. begin_const[idx] is 0 AND input_shape[idx] are equal to end_shape[idx] + */ +bool RemoveUnnecessaryStridedSlicePass::run(loco::Graph *g) +{ + bool changed = false; + for (auto node : loco::active_nodes(loco::output_nodes(g))) + { + auto target_node = dynamic_cast<luci::CircleStridedSlice *>(node); + if (target_node != nullptr) + if (remove_no_effect_strided_slice(target_node)) + changed = true; + } + return changed; +} + +} // namespace luci diff --git a/compiler/luci/pass/src/RemoveUnnecessaryStridedSlicePass.test.cpp b/compiler/luci/pass/src/RemoveUnnecessaryStridedSlicePass.test.cpp new file mode 100644 index 000000000..7d611c864 --- /dev/null +++ b/compiler/luci/pass/src/RemoveUnnecessaryStridedSlicePass.test.cpp @@ -0,0 +1,142 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "luci/Pass/RemoveUnnecessaryStridedSlicePass.h" + +#include <luci/IR/CircleNodes.h> + +#include <luci/test/TestIOGraph.h> +#include "test/TestFirstNode.h" + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class StridedSliceGraphlet +{ +public: + StridedSliceGraphlet() = default; + +public: + void init(loco::Graph *g, const ShapeU32 input_shape, bool remove) + { + // Begin create + _begin = g->nodes()->create<luci::CircleConst>(); + _begin->rank(1); + _begin->dim(0).set(input_shape.size()); + _begin->shape_status(luci::ShapeStatus::VALID); + _begin->dtype(loco::DataType::S32); + _begin->size<loco::DataType::S32>(input_shape.size()); + for (int i = 0; i < input_shape.size(); ++i) + { + _begin->at<loco::DataType::S32>(i) = remove ? 0 : 1; + } + + // Strides create + _strides = g->nodes()->create<luci::CircleConst>(); + _strides->rank(1); + _strides->dim(0).set(input_shape.size()); + _strides->shape_status(luci::ShapeStatus::VALID); + _strides->dtype(loco::DataType::S32); + _strides->size<loco::DataType::S32>(input_shape.size()); + for (int i = 0; i < input_shape.size(); ++i) + { + _strides->at<loco::DataType::S32>(i) = remove ? 1 : -1; + } + + std::vector<uint32_t> shape_vector{input_shape}; + + _end = g->nodes()->create<luci::CircleConst>(); + _end->rank(1); + _end->dim(0).set(input_shape.size()); + _end->shape_status(luci::ShapeStatus::VALID); + _end->dtype(loco::DataType::S32); + _end->size<loco::DataType::S32>(input_shape.size()); + for (int i = 0; i < input_shape.size(); ++i) + { + if (remove) + _end->at<loco::DataType::S32>(i) = static_cast<int32_t>(shape_vector.at(i)); + else + _end->at<loco::DataType::S32>(i) = -1; + } + + // StridedSlice Node create + _strided_slice = g->nodes()->create<luci::CircleStridedSlice>(); + _strided_slice->dtype(loco::DataType::S32); + } + +protected: + luci::CircleStridedSlice *_strided_slice = nullptr; + luci::CircleConst *_begin = nullptr; + luci::CircleConst *_strides = nullptr; + luci::CircleConst *_end = nullptr; +}; + +class StridedSliceGraph : public TestIOGraph, public StridedSliceGraphlet +{ +public: + StridedSliceGraph() = default; + +public: + void init(const ShapeU32 shape, bool remove) + { + TestIOGraph::init(shape, shape); + StridedSliceGraphlet::init(g(), shape, remove); + + _strided_slice->input(input()); + _strided_slice->begin(_begin); + _strided_slice->strides(_strides); + _strided_slice->end(_end); + + output()->from(_strided_slice); + } +}; + +} // namespace + +TEST(RemoveUnnecessaryStridedSlicePass, basic_case) +{ + StridedSliceGraph g; + + g.init({2, 4, 2, 3}, true); + + auto strided_slice_node = luci::test::first_node<luci::CircleStridedSlice>(g.g()); + ASSERT_NE(nullptr, strided_slice_node); + luci::RemoveUnnecessaryStridedSlicePass pass; + while (pass.run(g.g())) + ; + + strided_slice_node = luci::test::first_node<luci::CircleStridedSlice>(g.g()); + ASSERT_EQ(nullptr, strided_slice_node); +} + +TEST(RemoveUnnecessaryStridedSlicePass, basic_fail_case_NEG) +{ + StridedSliceGraph g; + + g.init({2, 4, 2, 3}, false); + + auto strided_slice_node = luci::test::first_node<luci::CircleStridedSlice>(g.g()); + ASSERT_NE(nullptr, strided_slice_node); + luci::RemoveUnnecessaryStridedSlicePass pass; + while (pass.run(g.g())) + ; + + strided_slice_node = luci::test::first_node<luci::CircleStridedSlice>(g.g()); + ASSERT_NE(nullptr, strided_slice_node); +} diff --git a/compiler/luci/pass/src/ReplaceMulAddWithDepthwiseConvPass.cpp b/compiler/luci/pass/src/ReplaceMulAddWithDepthwiseConvPass.cpp index 7096c2591..a0cc0194f 100644 --- a/compiler/luci/pass/src/ReplaceMulAddWithDepthwiseConvPass.cpp +++ b/compiler/luci/pass/src/ReplaceMulAddWithDepthwiseConvPass.cpp @@ -16,7 +16,10 @@ #include "luci/Pass/ReplaceMulAddWithDepthwiseConvPass.h" +#include "BatchNormPatternFinder.h" + #include <luci/IR/CircleNodes.h> +#include <luci/Profile/CircleNodeOrigin.h> namespace { @@ -26,6 +29,9 @@ luci::CircleConst *create_weights_from_gamma(luci::CircleConst *gamma) assert(gamma->rank() == 1); auto channel_size = gamma->dim(0).value(); + auto name = gamma->name(); + assert(name.length() > 0); + // Channel-wise MUL is the same as DEPTHWISE_CONV2D with filter shape (1,1,1,channel_size) auto weights = gamma->graph()->nodes()->create<luci::CircleConst>(); weights->dtype(loco::DataType::FLOAT32); @@ -40,6 +46,7 @@ luci::CircleConst *create_weights_from_gamma(luci::CircleConst *gamma) { weights->at<loco::DataType::FLOAT32>(i) = gamma->at<loco::DataType::FLOAT32>(i); } + weights->name(name + "_weights"); return weights; } @@ -49,6 +56,9 @@ luci::CircleConst *create_bias_from_beta(luci::CircleConst *beta) assert(beta->rank() == 1); auto channel_size = beta->dim(0).value(); + auto name = beta->name(); + assert(name.length() > 0); + // Channel-wise ADD is the same as bias (shape = (channel_size)) of DEPTHWISE_CONV2D auto bias = beta->graph()->nodes()->create<luci::CircleConst>(); bias->dtype(loco::DataType::FLOAT32); @@ -60,83 +70,11 @@ luci::CircleConst *create_bias_from_beta(luci::CircleConst *beta) { bias->at<loco::DataType::FLOAT32>(i) = beta->at<loco::DataType::FLOAT32>(i); } + bias->name(name + "_bias"); return bias; } -bool is_batchnorm_add(const luci::CircleAdd *add, luci::CircleMul *&mul, luci::CircleConst *&beta) -{ - auto x = loco::must_cast<luci::CircleNode *>(add->x()); - auto y = loco::must_cast<luci::CircleNode *>(add->y()); - - luci::CircleMul *pred = nullptr; - luci::CircleConst *constant = nullptr; - - if (x->opcode() == luci::CircleOpcode::CIRCLECONST && y->opcode() == luci::CircleOpcode::MUL) - { - pred = loco::must_cast<luci::CircleMul *>(y); - constant = loco::must_cast<luci::CircleConst *>(x); - } - else if (x->opcode() == luci::CircleOpcode::MUL && y->opcode() == luci::CircleOpcode::CIRCLECONST) - { - pred = loco::must_cast<luci::CircleMul *>(x); - constant = loco::must_cast<luci::CircleConst *>(y); - } - else - { - return false; - } - - if (constant->rank() != 1) - return false; - - auto channel_dim = constant->dim(0); - // Assumption: Layout is channel-last - if (!(channel_dim == add->dim(add->rank() - 1))) - return false; - - mul = pred; - beta = constant; - return true; -} - -// Check if mul is batchnorm mul -bool is_batchnorm_mul(const luci::CircleMul *mul, luci::CircleNode *&pred_node, - luci::CircleConst *&gamma) -{ - auto x = dynamic_cast<luci::CircleConst *>(mul->x()); - auto y = dynamic_cast<luci::CircleConst *>(mul->y()); - - luci::CircleNode *pred = nullptr; - luci::CircleConst *constant = nullptr; - - if (x != nullptr && y == nullptr) - { - pred = loco::must_cast<luci::CircleNode *>(mul->y()); - constant = x; - } - else if (x == nullptr && y != nullptr) - { - pred = loco::must_cast<luci::CircleNode *>(mul->x()); - constant = y; - } - else - { - return false; - } - - if (constant->rank() != 1) - return false; - - auto channel_dim = constant->dim(0); - if (!(channel_dim == mul->dim(mul->rank() - 1))) - return false; - - pred_node = pred; - gamma = constant; - return true; -} - /** * Replace channel-wise Mul/Add with DepthwiseConv2D * @@ -180,6 +118,9 @@ bool replace_mul_add_with_dwconv(luci::CircleAdd *add) auto weights = create_weights_from_gamma(gamma); auto bias = create_bias_from_beta(beta); + auto name = add->name(); + assert(name.length() > 0); + auto dwconv = add->graph()->nodes()->create<luci::CircleDepthwiseConv2D>(); dwconv->input(pred_node); dwconv->filter(weights); @@ -191,6 +132,8 @@ bool replace_mul_add_with_dwconv(luci::CircleAdd *add) dwconv->dilation()->w(1); dwconv->dilation()->h(1); dwconv->fusedActivationFunction(add->fusedActivationFunction()); + dwconv->name(name + "/DepthwiseConv2D"); + luci::add_origin(dwconv, luci::composite_origin({luci::get_origin(mul), luci::get_origin(add)})); loco::replace(add).with(dwconv); return true; @@ -206,14 +149,10 @@ bool ReplaceMulAddWithDepthwiseConvPass::run(loco::Graph *g) bool changed = false; for (auto node : loco::active_nodes(loco::output_nodes(g))) { - auto add = dynamic_cast<luci::CircleAdd *>(node); - if (not add) - continue; - - if (replace_mul_add_with_dwconv(add)) + if (auto add = dynamic_cast<luci::CircleAdd *>(node)) { - changed = true; - break; + if (replace_mul_add_with_dwconv(add)) + changed = true; } } diff --git a/compiler/luci/pass/src/ReplaceMulAddWithDepthwiseConvPass.test.cpp b/compiler/luci/pass/src/ReplaceMulAddWithDepthwiseConvPass.test.cpp index a90182aaa..903d4dcc9 100644 --- a/compiler/luci/pass/src/ReplaceMulAddWithDepthwiseConvPass.test.cpp +++ b/compiler/luci/pass/src/ReplaceMulAddWithDepthwiseConvPass.test.cpp @@ -85,6 +85,13 @@ public: add->x(mul); add->y(beta); output->from(add); + + input->name("input"); + mul->name("mul"); + gamma->name("gamma"); + add->name("add"); + beta->name("beta"); + output->name("output"); } public: @@ -99,6 +106,13 @@ public: } // namespace +TEST(ReplaceMulAddWithDepthwiseConv, name) +{ + luci::ReplaceMulAddWithDepthwiseConvPass pass; + auto const name = pass.name(); + ASSERT_NE(nullptr, name); +} + TEST(ReplaceMulAddWithDepthwiseConv, simple) { SimpleGraph g; diff --git a/compiler/luci/pass/src/RequantizePass.cpp b/compiler/luci/pass/src/RequantizePass.cpp index fe84e3bc3..a56536251 100644 --- a/compiler/luci/pass/src/RequantizePass.cpp +++ b/compiler/luci/pass/src/RequantizePass.cpp @@ -113,7 +113,7 @@ void requant_const_int8_to_uint8(CircleConst *node) struct RequantizeNonConst final : public luci::CircleNodeMutableVisitor<bool> { RequantizeNonConst(loco::DataType input, loco::DataType output) - : _input_type(input), _output_type(output) + : _input_type(input), _output_type(output) { } @@ -157,7 +157,7 @@ struct RequantizeNonConst final : public luci::CircleNodeMutableVisitor<bool> struct RequantizeConst final : public luci::CircleNodeMutableVisitor<bool> { RequantizeConst(loco::DataType input, loco::DataType output) - : _input_type(input), _output_type(output) + : _input_type(input), _output_type(output) { } diff --git a/compiler/luci/pass/src/RequantizePass.test.cpp b/compiler/luci/pass/src/RequantizePass.test.cpp new file mode 100644 index 000000000..d26743c9d --- /dev/null +++ b/compiler/luci/pass/src/RequantizePass.test.cpp @@ -0,0 +1,26 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Pass/RequantizePass.h" + +#include <gtest/gtest.h> + +TEST(RequantizePassTest, name) +{ + luci::RequantizePass pass(loco::DataType::FLOAT32, loco::DataType::U8); + auto const name = pass.name(); + ASSERT_NE(nullptr, name); +} diff --git a/compiler/luci/pass/src/ResolveCustomOpAddPass.cpp b/compiler/luci/pass/src/ResolveCustomOpAddPass.cpp index e52d667d7..1737e5dd6 100644 --- a/compiler/luci/pass/src/ResolveCustomOpAddPass.cpp +++ b/compiler/luci/pass/src/ResolveCustomOpAddPass.cpp @@ -20,6 +20,7 @@ #include <luci/IR/CircleNodes.h> #include <luci/IR/AttrFusedActFunc.h> +#include <luci/Profile/CircleNodeOrigin.h> namespace { @@ -67,10 +68,17 @@ bool resolve_with_BroadcastTo(luci::CircleCustom *addv2) auto input = loco::must_cast<const luci::CircleCustomOut *>(addv2->inputs(broadcastTo_idx)); auto broadcastTo = loco::must_cast<luci::CircleCustom *>(input->input()); + auto name = addv2->name(); + assert(name.length() > 0); + auto add = addv2->graph()->nodes()->create<luci::CircleAdd>(); add->fusedActivationFunction(luci::FusedActFunc::NONE); add->x(addv2->inputs(1 - broadcastTo_idx)); add->y(broadcastTo->inputs(0)); + add->name(name + "/Add"); + luci::add_origin( + add, luci::composite_origin({luci::get_origin(broadcastTo), luci::get_origin(addv2)})); + auto customOut = loco::succs(addv2); assert(customOut.size() == 1); replace(*customOut.begin()).with(add); @@ -86,13 +94,39 @@ bool resolve_custom_op(luci::CircleCustom *addv2) if (custom_code != "AddV2") return false; + if (addv2->numInputs() != 2) + return false; + + // check if inputs are suppport data types + for (uint32_t i = 0; i < addv2->numInputs(); i++) + { + auto input = loco::must_cast<luci::CircleNode *>(addv2->inputs(i)); + switch (input->dtype()) + { + case loco::DataType::U8: + case loco::DataType::S8: + case loco::DataType::S16: + case loco::DataType::S32: + case loco::DataType::FLOAT32: + break; + default: + return false; + } + } + if (resolve_with_BroadcastTo(addv2)) return true; + auto name = addv2->name(); + assert(name.length() > 0); + auto add = addv2->graph()->nodes()->create<luci::CircleAdd>(); add->fusedActivationFunction(luci::FusedActFunc::NONE); add->x(addv2->inputs(0)); add->y(addv2->inputs(1)); + add->name(name + "/Add"); + luci::add_origin(add, luci::get_origin(addv2)); + auto customOut = loco::succs(addv2); assert(customOut.size() == 1); replace(*customOut.begin()).with(add); @@ -115,7 +149,8 @@ bool ResolveCustomOpAddPass::run(loco::Graph *g) if (not cop) continue; - changed |= resolve_custom_op(cop); + if (resolve_custom_op(cop)) + changed = true; } return changed; diff --git a/compiler/luci/pass/src/ResolveCustomOpAddPass.test.cpp b/compiler/luci/pass/src/ResolveCustomOpAddPass.test.cpp new file mode 100644 index 000000000..31c245b0e --- /dev/null +++ b/compiler/luci/pass/src/ResolveCustomOpAddPass.test.cpp @@ -0,0 +1,26 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Pass/ResolveCustomOpAddPass.h" + +#include <gtest/gtest.h> + +TEST(ResolveCustomOpAddPassTest, name) +{ + luci::ResolveCustomOpAddPass pass; + auto const name = pass.name(); + ASSERT_NE(nullptr, name); +} diff --git a/compiler/luci/pass/src/ResolveCustomOpBatchMatMulPass.cpp b/compiler/luci/pass/src/ResolveCustomOpBatchMatMulPass.cpp index 145e9cb62..5e9466a63 100644 --- a/compiler/luci/pass/src/ResolveCustomOpBatchMatMulPass.cpp +++ b/compiler/luci/pass/src/ResolveCustomOpBatchMatMulPass.cpp @@ -19,6 +19,7 @@ #include "flatbuffers/flexbuffers.h" #include <luci/IR/CircleNodes.h> +#include <luci/Profile/CircleNodeOrigin.h> namespace { @@ -30,6 +31,9 @@ bool resolve_custom_op(luci::CircleCustom *cop) if (custom_code == "BatchMatMulV2") { + auto name = cop->name(); + assert(name.length() > 0); + auto batch_matmul = cop->graph()->nodes()->create<luci::CircleBatchMatMul>(); // input batch_matmul->x(cop->inputs(0)); @@ -39,10 +43,16 @@ bool resolve_custom_op(luci::CircleCustom *cop) auto map = flexbuffers::GetRoot(custom_options).AsMap(); batch_matmul->adj_x(map["adj_x"].AsBool()); batch_matmul->adj_y(map["adj_y"].AsBool()); + batch_matmul->name(name + "/BatchMatMul"); + luci::add_origin(batch_matmul, luci::get_origin(cop)); + + auto customOut = loco::succs(cop); + assert(customOut.size() == 1); + replace(*customOut.begin()).with(batch_matmul); - replace(cop).with(batch_matmul); return true; } + return false; } @@ -51,6 +61,27 @@ bool resolve_custom_op(luci::CircleCustom *cop) namespace luci { +/** + * BEFORE + * | | + * [CircleNode] [CircleNode] + * \ / + * [CircleCustom]("BatchMatMulV2") + * | + * [CircleCustomOut] + * | + * [CircleNode] + * | + * + * AFTER + * | | + * [CircleNode] [CircleNode] + * \ / + * [CircleBatchMatMul] + * | + * [CircleNode] + * | + */ bool ResolveCustomOpBatchMatMulPass::run(loco::Graph *g) { bool changed = false; @@ -60,7 +91,8 @@ bool ResolveCustomOpBatchMatMulPass::run(loco::Graph *g) if (not cop) continue; - changed |= resolve_custom_op(cop); + if (resolve_custom_op(cop)) + changed = true; } return changed; diff --git a/compiler/luci/pass/src/ResolveCustomOpBatchMatMulPass.test.cpp b/compiler/luci/pass/src/ResolveCustomOpBatchMatMulPass.test.cpp new file mode 100644 index 000000000..435016f9d --- /dev/null +++ b/compiler/luci/pass/src/ResolveCustomOpBatchMatMulPass.test.cpp @@ -0,0 +1,169 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Pass/ResolveCustomOpBatchMatMulPass.h" + +#include <luci/IR/CircleNodes.h> + +#include "flatbuffers/flatbuffers.h" +#include "flatbuffers/flexbuffers.h" + +#include <luci/test/TestIOGraph.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +const int N = 1; +const int C = 2; +const int H_X = 1; +const int W_X = 4; +const int H_Y = 4; +const int W_Y = 4; + +/** + * graph having Custom operator BatchMatMulV2 + * + * [CircleInput] [CircleInput] + * \ / + * [CircleCustom] + * | + * [CircleCustomOut] + * | + * [CircleOutput] + */ +class BatchMatmulV2Graphlet +{ +public: + BatchMatmulV2Graphlet() = default; + +public: + void init(loco::Graph *g) + { + // custom option + auto flatbuffer_builder = + std::unique_ptr<flatbuffers::FlatBufferBuilder>(new flatbuffers::FlatBufferBuilder(1024)); + auto flex_buffers = std::make_unique<flexbuffers::Builder>(); + size_t map_start = flex_buffers->StartMap(); + flex_buffers->Bool("adj_x", false); + flex_buffers->Bool("adj_y", false); + flex_buffers->Int("T", 0 /* circle::TensorType_FLOAT32 */); + flex_buffers->EndMap(map_start); + flex_buffers->Finish(); + + // CircleCustom(BatchMatMulV2, adj_x=False, adj_y=False) + _batchmatmulv2 = g->nodes()->create<luci::CircleCustom>(2, 1); + _batchmatmulv2->custom_code("BatchMatMulV2"); + _batchmatmulv2->custom_options(flex_buffers->GetBuffer()); + _batchmatmulv2->shape({N, C, H_X, W_Y}); + _batchmatmulv2->dtype(loco::DataType::FLOAT32); + _batchmatmulv2->name("batchmatmulv2"); + + // CircleCustomOut + _batchmatmulv2_out = g->nodes()->create<luci::CircleCustomOut>(); + _batchmatmulv2_out->shape({N, C, H_X, W_Y}); + _batchmatmulv2_out->dtype(loco::DataType::FLOAT32); + _batchmatmulv2_out->index(0); + } + +public: + luci::CircleCustom *batchmatmulv2() { return _batchmatmulv2; } + +protected: + luci::CircleCustom *_batchmatmulv2 = nullptr; + luci::CircleCustomOut *_batchmatmulv2_out = nullptr; +}; + +class BatchMatmulV2Graph : public TestIsGraphlet<2>, + public TestOGraphlet, + public BatchMatmulV2Graphlet +{ +public: + BatchMatmulV2Graph() = default; + + void init(void) + { + TestIsGraphlet<2>::init(g(), {{N, C, H_X, W_X}, {N, C, H_X, W_X}}); + TestOGraphlet::init(g(), {N, C, H_X, W_Y}); + BatchMatmulV2Graphlet::init(g()); + + // TODO how set multiple of shape vector for TestIsGraphlet? + // update shape for second input + input(1)->shape({N, C, H_Y, W_Y}); + + // connect graph + _batchmatmulv2->inputs(0, input(0)); + _batchmatmulv2->inputs(1, input(1)); + _batchmatmulv2_out->input(_batchmatmulv2); + + output()->from(_batchmatmulv2_out); + } +}; + +class BatchMatmulV2GraphTest : public ::testing::Test +{ +public: + BatchMatmulV2Graph g; + luci::ResolveCustomOpBatchMatMulPass pass; +}; + +} // namespace + +TEST(ResolveCustomOpBatchMatMulPassTest, name) +{ + luci::ResolveCustomOpBatchMatMulPass pass; + auto const name = pass.name(); + ASSERT_NE(nullptr, name); +} + +/** + * Optimized graph looks like below. + * + * [CircleInput] + * | + * [CircleBatchMatMul] + * | + * [CircleOutput] + */ +TEST_F(BatchMatmulV2GraphTest, simple_test) +{ + g.init(); + + auto ret = pass.run(g.g()); + EXPECT_EQ(true, ret); + + auto batchmatmul = dynamic_cast<luci::CircleBatchMatMul *>(g.output()->from()); + EXPECT_NE(nullptr, batchmatmul); + + auto input_0 = dynamic_cast<luci::CircleInput *>(batchmatmul->x()); + auto input_1 = dynamic_cast<luci::CircleInput *>(batchmatmul->y()); + EXPECT_NE(nullptr, input_0); + EXPECT_NE(nullptr, input_1); +} + +TEST_F(BatchMatmulV2GraphTest, wrong_condition_NEG) +{ + g.init(); + + // wrong custom code + g.batchmatmulv2()->custom_code("BatchMatMulv2"); // v is lower case + auto ret = pass.run(g.g()); + + EXPECT_EQ(false, ret); +} diff --git a/compiler/luci/pass/src/ResolveCustomOpMatMulPass.cpp b/compiler/luci/pass/src/ResolveCustomOpMatMulPass.cpp index 547fd22fc..216778066 100644 --- a/compiler/luci/pass/src/ResolveCustomOpMatMulPass.cpp +++ b/compiler/luci/pass/src/ResolveCustomOpMatMulPass.cpp @@ -20,11 +20,10 @@ #include <loco/IR/DataTypeTraits.h> #include <luci/IR/CircleNodes.h> +#include <luci/Profile/CircleNodeOrigin.h> #include <loco.h> #include <oops/InternalExn.h> -#include <loco/Service/ShapeInference.h> -#include <loco/Service/TypeInference.h> namespace { @@ -44,6 +43,7 @@ luci::CircleConst *create_const_node(loco::Graph *g, const loco::DataType dtype, node->dim(i) = shape.at(i); size *= shape.at(i); } + node->shape_status(luci::ShapeStatus::VALID); #define INIT_VALUES(DT) \ { \ @@ -90,6 +90,9 @@ bool resolve_matmul(luci::CircleCustom *cop) const auto S32 = loco::DataType::S32; const auto FLOAT32 = loco::DataType::FLOAT32; + auto name = cop->name(); + assert(name.length() > 0); + bool transpose_a = map["transpose_a"].AsBool(); bool transpose_b = map["transpose_b"].AsBool(); @@ -97,34 +100,38 @@ bool resolve_matmul(luci::CircleCustom *cop) loco::Node *rhs = cop->inputs(1); // Check that the type of the first input is known - CHECK_OR_FALSE(loco::dtype_known(lhs)); - auto lhs_dtype = loco::dtype_get(cop->inputs(0)); + auto lhs_dtype = loco::must_cast<luci::CircleNode *>(cop->inputs(0))->dtype(); + CHECK_OR_FALSE(lhs_dtype != loco::DataType::Unknown); // If transpose of first input is requested, its shape must be known - CHECK_OR_FALSE(!transpose_a || loco::shape_known(lhs)); + auto circle_lhs = loco::must_cast<luci::CircleNode *>(lhs); + CHECK_OR_FALSE(!transpose_a || circle_lhs->shape_status() == luci::ShapeStatus::VALID); // and its rank should be at least 2 - CHECK_OR_FALSE(!transpose_a || loco::shape_get(lhs).as<loco::TensorShape>().rank() >= 2); + CHECK_OR_FALSE(!transpose_a || circle_lhs->rank() >= 2); // Check that the shape of the 2nd input is known - CHECK_OR_FALSE(loco::shape_known(rhs)); + auto circle_rhs = loco::must_cast<luci::CircleNode *>(rhs); + CHECK_OR_FALSE(circle_rhs->shape_status() == luci::ShapeStatus::VALID); // TODO as of 06/23/20 TFLite only supports rank 2 for 2nd input. Fix this once that changes! - CHECK_OR_FALSE(loco::shape_get(rhs).as<loco::TensorShape>().rank() == 2); + CHECK_OR_FALSE(circle_rhs->rank() == 2); // Check that input data type is supported CHECK_OR_THROW(lhs_dtype == U8 || lhs_dtype == S16 || lhs_dtype == FLOAT32, "Only UInt8, Int16 and Float32 data types are supported by MatMul"); if (transpose_a) { - auto a_shape = loco::shape_get(lhs).as<loco::TensorShape>(); // Create a permutation constant node std::vector<uint32_t> perm; - for (uint32_t i = 0; i < a_shape.rank(); ++i) + for (uint32_t i = 0; i < circle_lhs->rank(); ++i) perm.push_back(i); - std::swap(perm[a_shape.rank() - 1], perm[a_shape.rank() - 2]); - auto perm_node = create_const_node(graph, S32, {a_shape.rank()}, perm); + std::swap(perm[circle_lhs->rank() - 1], perm[circle_lhs->rank() - 2]); + auto perm_node = create_const_node(graph, S32, {circle_lhs->rank()}, perm); + perm_node->name(name + "/lhs/Transpose/perm"); // Now make a transpose node auto transpose_node = graph->nodes()->create<luci::CircleTranspose>(); transpose_node->a(lhs); transpose_node->perm(perm_node); + transpose_node->name(name + "/lhs/Transpose"); + luci::add_origin(transpose_node, luci::get_origin(cop)); lhs = transpose_node; } @@ -135,24 +142,29 @@ bool resolve_matmul(luci::CircleCustom *cop) { const std::vector<uint32_t> perm{1, 0}; auto perm_node = create_const_node(graph, S32, {2}, perm); + perm_node->name(name + "/rhs/Transpose/perm"); auto transpose_node = graph->nodes()->create<luci::CircleTranspose>(); transpose_node->a(rhs); transpose_node->perm(perm_node); + transpose_node->name(name + "/rhs/Transpose"); + luci::add_origin(transpose_node, luci::get_origin(cop)); rhs = transpose_node; } - // Make a constant zero-filled bias node - auto b_shape = loco::shape_get(cop->inputs(1)).as<loco::TensorShape>(); - uint32_t bias_size = b_shape.dim(transpose_b ? 1 : 0).value(); - const std::vector<float> val(bias_size, .0f); - auto bias_node = create_const_node(graph, lhs_dtype, {bias_size}, val); + auto empty_bias = graph->nodes()->create<luci::CircleOutputExclude>(); + empty_bias->dtype(loco::DataType::FLOAT32); // Needed for type inference + auto fc_node = graph->nodes()->create<luci::CircleFullyConnected>(); fc_node->input(lhs); fc_node->weights(rhs); - fc_node->bias(bias_node); + fc_node->bias(empty_bias); fc_node->fusedActivationFunction(luci::FusedActFunc::NONE); + fc_node->name(name + "/FullyConnected"); + luci::add_origin(fc_node, luci::get_origin(cop)); - replace(cop).with(fc_node); + auto customOut = loco::succs(cop); + assert(customOut.size() == 1); + replace(*customOut.begin()).with(fc_node); return true; } diff --git a/compiler/luci/pass/src/ResolveCustomOpMatMulPass.test.cpp b/compiler/luci/pass/src/ResolveCustomOpMatMulPass.test.cpp new file mode 100644 index 000000000..c4ea3ea06 --- /dev/null +++ b/compiler/luci/pass/src/ResolveCustomOpMatMulPass.test.cpp @@ -0,0 +1,26 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Pass/ResolveCustomOpMatMulPass.h" + +#include <gtest/gtest.h> + +TEST(ResolveCustomOpMatMulPassTest, name) +{ + luci::ResolveCustomOpMatMulPass pass; + auto const name = pass.name(); + ASSERT_NE(nullptr, name); +} diff --git a/compiler/luci/pass/src/ShapeInferencePass.cpp b/compiler/luci/pass/src/ShapeInferencePass.cpp deleted file mode 100644 index 4bd0aaed4..000000000 --- a/compiler/luci/pass/src/ShapeInferencePass.cpp +++ /dev/null @@ -1,57 +0,0 @@ -/* - * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "luci/Pass/ShapeInferencePass.h" - -#include <luci/IR/CircleDialect.h> -#include <luci/Service/CircleShapeInferenceRule.h> - -#include <loco.h> -#include <loco/IR/CanonicalDialect.h> -#include <loco/Service/CanonicalShapeInferenceRule.h> -#include <loco/Service/ShapeInference.h> -#include <loco/Service/MultiDialectShapeInferenceRule.h> - -namespace luci -{ - -bool ShapeInferencePass::run(luci::Module *m) -{ - bool changed = false; - - for (size_t g = 0; g < m->size(); ++g) - { - if (run(m->graph(g))) - changed = true; - } - - return changed; -} - -bool ShapeInferencePass::run(loco::Graph *g) -{ - loco::CanonicalShapeInferenceRule canonical_rule; - luci::CircleShapeInferenceRule circle_rule; - - loco::MultiDialectShapeInferenceRule rules; - - rules.bind(loco::CanonicalDialect::get(), &canonical_rule) - .bind(luci::CircleDialect::get(), &circle_rule); - - return loco::apply(&rules).to(g); -} - -} // namespace luci diff --git a/compiler/luci/pass/src/ShuffleWeightTo16x1Float32Pass.cpp b/compiler/luci/pass/src/ShuffleWeightTo16x1Float32Pass.cpp index 6a58f18c5..92060f625 100644 --- a/compiler/luci/pass/src/ShuffleWeightTo16x1Float32Pass.cpp +++ b/compiler/luci/pass/src/ShuffleWeightTo16x1Float32Pass.cpp @@ -72,6 +72,9 @@ luci::CircleConst *shuffle_weight(luci::CircleFullyConnected *fc) { auto the_weights = loco::must_cast<luci::CircleConst *>(fc->weights()); + auto name = fc->name(); + assert(name.length() > 0); + // create CircleConst where shuffled data will be stored luci::CircleConst *new_weights = fc->graph()->nodes()->create<luci::CircleConst>(); new_weights->dtype(loco::DataType::FLOAT32); @@ -82,6 +85,7 @@ luci::CircleConst *shuffle_weight(luci::CircleFullyConnected *fc) { new_weights->dim(r).set(the_weights->dim(r).value()); } + new_weights->name(name + "/shuffle_weight"); // suffle weight const uint32_t MULTIPLE = 16; @@ -96,7 +100,7 @@ luci::CircleConst *shuffle_weight(luci::CircleFullyConnected *fc) for (uint32_t i = 0; i < MULTIPLE; i++) { new_weights->at<loco::DataType::FLOAT32>(index++) = - the_weights->at<loco::DataType::FLOAT32>((r * MULTIPLE + i) * cols + c); + the_weights->at<loco::DataType::FLOAT32>((r * MULTIPLE + i) * cols + c); } } } @@ -131,6 +135,8 @@ bool ShuffleWeightTo16x1Float32Pass::run(loco::Graph *g) fc->weights(new_weights); fc->weights_format(luci::CircleFullyConnected::WeightsFormat::SHUFFLED16x1FLOAT32); } + + changed = true; } return changed; diff --git a/compiler/luci/pass/src/ShuffleWeightTo16x1Float32Pass.test.cpp b/compiler/luci/pass/src/ShuffleWeightTo16x1Float32Pass.test.cpp index 9745e5754..077985977 100644 --- a/compiler/luci/pass/src/ShuffleWeightTo16x1Float32Pass.test.cpp +++ b/compiler/luci/pass/src/ShuffleWeightTo16x1Float32Pass.test.cpp @@ -18,61 +18,86 @@ #include <luci/IR/CircleNodes.h> +#include <luci/test/TestIOGraph.h> +#include "test/TestFirstNode.h" + #include <gtest/gtest.h> -void create_fc_net(loco::Graph *g) +namespace { - assert(g); - - const uint32_t ROW = 16; - const uint32_t COL = 2; - const uint32_t elements_num = ROW * COL; - - // input - auto input = g->nodes()->create<luci::CircleInput>(); - auto graph_input = g->inputs()->create(); - input->index(graph_input->index()); - - // fc weights - auto weights = g->nodes()->create<luci::CircleConst>(); - weights->dtype(loco::DataType::FLOAT32); - weights->size<loco::DataType::FLOAT32>(elements_num); - weights->rank(2); - weights->dim(0).set(ROW); - weights->dim(1).set(COL); - for (uint32_t idx = 0; idx < elements_num; idx++) + +using namespace luci::test; + +class FCGraphlet +{ +public: + FCGraphlet() = default; + +public: + void init(loco::Graph *g, const ShapeU32 wshape) { - weights->at<loco::DataType::FLOAT32>(idx) = idx; + const uint32_t elements_num = num_elements(wshape); + + // fc weights + _weights = g->nodes()->create<luci::CircleConst>(); + _weights->dtype(loco::DataType::FLOAT32); + _weights->shape(wshape); + _weights->size<loco::DataType::FLOAT32>(elements_num); + for (uint32_t idx = 0; idx < elements_num; idx++) + { + _weights->at<loco::DataType::FLOAT32>(idx) = idx; + } + _weights->name("weights"); + + // fc + _fc = g->nodes()->create<luci::CircleFullyConnected>(); + _fc->dtype(loco::DataType::FLOAT32); + _fc->name("fc"); } - // fc - auto fc = g->nodes()->create<luci::CircleFullyConnected>(); - fc->dtype(loco::DataType::FLOAT32); - fc->input(input); - fc->weights(weights); - - // output - auto output = g->nodes()->create<luci::CircleOutput>(); - output->from(fc); - auto graph_output = g->outputs()->create(); - output->index(graph_output->index()); -} +protected: + luci::CircleFullyConnected *_fc = nullptr; + luci::CircleConst *_weights = nullptr; +}; -TEST(ShuffleWeightTo16x1Float32PassTest, SimpleTest1) +class FCGraph : public TestIGraphlet, public TestOGraphlet, public FCGraphlet { - auto graph = loco::make_graph(); - create_fc_net(graph.get()); +public: + FCGraph() = default; - luci::CircleFullyConnected *fc_node = nullptr; - for (auto node : loco::active_nodes(loco::output_nodes(graph.get()))) + void init(const ShapeU32 shape, const ShapeU32 wshape) { - auto fc = dynamic_cast<luci::CircleFullyConnected *>(node); - if (not fc) - continue; + TestIGraphlet::init(g(), shape); + TestOGraphlet::init(g(), shape); + FCGraphlet::init(g(), wshape); + + // connect graph + _fc->input(input()); + _fc->weights(_weights); - fc_node = fc; - break; + output()->from(_fc); } +}; + +} // namespace + +TEST(ShuffleWeightTo16x1Float32PassTest, name) +{ + luci::ShuffleWeightTo16x1Float32Pass pass; + auto const name = pass.name(); + ASSERT_NE(nullptr, name); +} + +const uint32_t ROW = 16; +const uint32_t COL = 2; + +TEST(ShuffleWeightTo16x1Float32PassTest, SimpleTest1) +{ + FCGraph g; + + g.init({ROW, COL}, {ROW, COL}); + + auto fc_node = luci::test::first_node<luci::CircleFullyConnected>(g.g()); ASSERT_NE(fc_node, nullptr); auto weights = loco::must_cast<luci::CircleConst *>(fc_node->weights()); // before @@ -94,7 +119,7 @@ TEST(ShuffleWeightTo16x1Float32PassTest, SimpleTest1) ASSERT_EQ(15, weights->at<loco::DataType::FLOAT32>(15)); luci::ShuffleWeightTo16x1Float32Pass pass; - while (pass.run(graph.get())) + while (pass.run(g.g())) ; weights = loco::must_cast<luci::CircleConst *>(fc_node->weights()); @@ -116,3 +141,33 @@ TEST(ShuffleWeightTo16x1Float32PassTest, SimpleTest1) ASSERT_EQ(28, weights->at<loco::DataType::FLOAT32>(14)); ASSERT_EQ(30, weights->at<loco::DataType::FLOAT32>(15)); } + +TEST(ShuffleWeightTo16x1Float32PassTest, invalid_weight_shape_NEG) +{ + FCGraph g; + + g.init({ROW, COL}, {1, ROW, COL, 1}); + + auto fc_node = luci::test::first_node<luci::CircleFullyConnected>(g.g()); + ASSERT_NE(fc_node, nullptr); + + luci::ShuffleWeightTo16x1Float32Pass pass; + auto ret = pass.run(g.g()); + + ASSERT_FALSE(ret); +} + +TEST(ShuffleWeightTo16x1Float32PassTest, invalid_weight_row16_NEG) +{ + FCGraph g; + + g.init({COL, ROW}, {COL, ROW}); + + auto fc_node = luci::test::first_node<luci::CircleFullyConnected>(g.g()); + ASSERT_NE(fc_node, nullptr); + + luci::ShuffleWeightTo16x1Float32Pass pass; + auto ret = pass.run(g.g()); + + ASSERT_FALSE(ret); +} diff --git a/compiler/luci/pass/src/Sparsifier.cpp b/compiler/luci/pass/src/Sparsifier.cpp index 210c1a34c..18ab45f98 100644 --- a/compiler/luci/pass/src/Sparsifier.cpp +++ b/compiler/luci/pass/src/Sparsifier.cpp @@ -26,8 +26,8 @@ Sparsifier<T>::Sparsifier(const std::vector<int32_t> &shape, const std::vector<DimensionType> &format, const std::vector<int32_t> &block_size, const std::vector<int32_t> &block_map) - : _dense_shape(shape), _traversal_order(traversal_order), _block_size(block_size), - _block_map(block_map) + : _dense_shape(shape), _traversal_order(traversal_order), _block_size(block_size), + _block_map(block_map) { _dense_size = 1; int32_t block_dim = 0; diff --git a/compiler/luci/pass/src/Sparsifier.test.cpp b/compiler/luci/pass/src/Sparsifier.test.cpp index 272e0e934..14e24aad7 100644 --- a/compiler/luci/pass/src/Sparsifier.test.cpp +++ b/compiler/luci/pass/src/Sparsifier.test.cpp @@ -190,6 +190,6 @@ TEST(SparsifierTest, WrongFormatRank_NEG) const std::vector<int32_t> block_size = {4, 1}; const std::vector<int32_t> block_map = {0, 1}; EXPECT_THROW( - luci::Sparsifier<int32_t>(dense_shape, traversal_order, format, block_size, block_map), - std::out_of_range); + luci::Sparsifier<int32_t>(dense_shape, traversal_order, format, block_size, block_map), + std::out_of_range); } diff --git a/compiler/luci/pass/src/SparsifyTensorPass.cpp b/compiler/luci/pass/src/SparsifyTensorPass.cpp index 2f1a36e77..1a75bfb0c 100644 --- a/compiler/luci/pass/src/SparsifyTensorPass.cpp +++ b/compiler/luci/pass/src/SparsifyTensorPass.cpp @@ -69,11 +69,11 @@ template <loco::DataType DT> void SparsifyTensorPass::sparsify_tensor(luci::Circ else if (_format.at(idx) == DimensionType::SPARSE_CSR) { sparsityparam->dim_metadata.emplace_back( - DimensionType::SPARSE_CSR, /* dense size */ 0, - /* array_segments */ SparseIndexVector{SparseIndexVectorType::U16, - dim_metadata.at(idx * 2)}, - /* array_indices */ SparseIndexVector{SparseIndexVectorType::U16, - dim_metadata.at(idx * 2 + 1)}); + DimensionType::SPARSE_CSR, /* dense size */ 0, + /* array_segments */ + SparseIndexVector{SparseIndexVectorType::U16, dim_metadata.at(idx * 2)}, + /* array_indices */ + SparseIndexVector{SparseIndexVectorType::U16, dim_metadata.at(idx * 2 + 1)}); } } for (uint32_t i = 0; i < _block_size.size(); i++) diff --git a/compiler/luci/pass/src/SparsifyTensorPass.test.cpp b/compiler/luci/pass/src/SparsifyTensorPass.test.cpp new file mode 100644 index 000000000..372e8e5ca --- /dev/null +++ b/compiler/luci/pass/src/SparsifyTensorPass.test.cpp @@ -0,0 +1,30 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Pass/SparsifyTensorPass.h" + +#include <gtest/gtest.h> + +TEST(SparsifyTensorPassTest, name) +{ + std::vector<int32_t> to; + std::vector<luci::DimensionType> vdt; + std::vector<int32_t> bs; + std::vector<int32_t> bm; + luci::SparsifyTensorPass pass("", to, vdt, bs, bm); + auto const name = pass.name(); + ASSERT_NE(nullptr, name); +} diff --git a/compiler/luci/pass/src/SubstitutePackToReshapePass.cpp b/compiler/luci/pass/src/SubstitutePackToReshapePass.cpp index 44e974b91..d8676cd62 100644 --- a/compiler/luci/pass/src/SubstitutePackToReshapePass.cpp +++ b/compiler/luci/pass/src/SubstitutePackToReshapePass.cpp @@ -17,10 +17,22 @@ #include "luci/Pass/SubstitutePackToReshapePass.h" #include <luci/IR/CircleNodes.h> +#include <luci/Profile/CircleNodeOrigin.h> namespace { +int32_t unknown_dim_count(luci::CircleNode *node) +{ + int32_t count = 0; + + for (uint32_t i = 0; i < node->rank(); ++i) + if (!node->dim(i).known()) + ++count; + + return count; +} + bool substitute_pack_to_reshape(luci::CircleNode *node) { auto target_node = dynamic_cast<luci::CirclePack *>(node); @@ -35,9 +47,14 @@ bool substitute_pack_to_reshape(luci::CircleNode *node) if (axis < 0) axis = axis + static_cast<int32_t>(value_node->rank()) + 1; + auto name = node->name(); + assert(name.length() > 0); + auto graph = target_node->graph(); auto reshape_node = graph->nodes()->create<luci::CircleReshape>(); reshape_node->tensor(value_node); + reshape_node->name(name + "/Reshape"); + luci::add_origin(reshape_node, luci::get_origin(node)); auto const_node = graph->nodes()->create<luci::CircleConst>(); const_node->dtype(loco::DataType::S32); @@ -53,13 +70,16 @@ bool substitute_pack_to_reshape(luci::CircleNode *node) } else if (i < axis) { - const_node->at<loco::DataType::S32>(i) = value_node->dim(i).value(); + const_node->at<loco::DataType::S32>(i) = + value_node->dim(i).known() ? value_node->dim(i).value() : -1; } else { - const_node->at<loco::DataType::S32>(i) = value_node->dim(i - 1).value(); + const_node->at<loco::DataType::S32>(i) = + value_node->dim(i - 1).known() ? value_node->dim(i - 1).value() : -1; } } + const_node->name(name + "/Reshape/shape"); reshape_node->shape(const_node); replace(target_node).with(reshape_node); return true; @@ -71,24 +91,23 @@ namespace luci { /** - * BEFORE - * | - * [CircleNode] - * | - * [CirclePack] - * | - * [CircleNode] - * | + * BEFORE + * | + * [CircleNode] + * | + * [CirclePack] + * | + * [CircleNode] + * | * - * AFTER - * | - * [CircleNode] [CircleConst] - * \ / - * [CircleReshape] + * AFTER * | - * [CircleNode] - * | - * + * [CircleNode] [CircleConst] + * | \ / + * [CirclePack] [CircleReshape] + * | + * [CircleNode] + * | */ bool SubstitutePackToReshapePass::run(loco::Graph *g) { @@ -96,7 +115,7 @@ bool SubstitutePackToReshapePass::run(loco::Graph *g) for (auto node : loco::active_nodes(loco::output_nodes(g))) { auto circle_node = loco::must_cast<luci::CircleNode *>(node); - if (substitute_pack_to_reshape(circle_node)) + if (unknown_dim_count(circle_node) <= 1 && substitute_pack_to_reshape(circle_node)) { changed = true; } diff --git a/compiler/luci/pass/src/SubstitutePackToReshapePass.test.cpp b/compiler/luci/pass/src/SubstitutePackToReshapePass.test.cpp index 143b88896..3b5d4ea2c 100644 --- a/compiler/luci/pass/src/SubstitutePackToReshapePass.test.cpp +++ b/compiler/luci/pass/src/SubstitutePackToReshapePass.test.cpp @@ -22,26 +22,6 @@ namespace { -/** - * BEFORE - * | - * [CircleNode] - * | - * [CirclePack] - * | - * [CircleNode] - * | - * - * AFTER - * | - * [CircleNode] [CircleConst] - * \ / - * [CircleReshape] - * | - * [CircleNode] - * | - * - */ void create_substitute_pack_to_reshape(loco::Graph *g, const std::initializer_list<uint32_t> shape, int32_t axis) { @@ -54,23 +34,33 @@ void create_substitute_pack_to_reshape(loco::Graph *g, const std::initializer_li input->shape_status(luci::ShapeStatus::VALID); input->rank(shape.size()); input->shape(shape); + input->name("input"); // Pack Node create. auto pack = g->nodes()->create<luci::CirclePack>(1); pack->values(0, input); pack->axis(axis); + pack->name("pack"); // Output Connect. auto output = g->nodes()->create<luci::CircleOutput>(); output->from(pack); auto graph_output = g->outputs()->create(); output->index(graph_output->index()); + output->name("output"); return; } } // namespace +TEST(SubstitutePackToReshapePassTest, name) +{ + luci::SubstitutePackToReshapePass pass; + auto const name = pass.name(); + ASSERT_NE(nullptr, name); +} + TEST(SubstitutePackToReshapePass, simple_case) { auto graph = loco::make_graph(); diff --git a/compiler/luci/pass/src/SubstituteSqueezeToReshapePass.cpp b/compiler/luci/pass/src/SubstituteSqueezeToReshapePass.cpp new file mode 100644 index 000000000..74be86a4c --- /dev/null +++ b/compiler/luci/pass/src/SubstituteSqueezeToReshapePass.cpp @@ -0,0 +1,183 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Pass/SubstituteSqueezeToReshapePass.h" + +#include <luci/IR/CircleNodes.h> +#include <luci/Profile/CircleNodeOrigin.h> + +namespace +{ + +/** + * @brief return TRUE if all dim is known + * @note This pass can be applied even some of dimensions are unknown. + For now, do not consider about it and update logic later. + */ +bool can_squeeze_shape(const luci::CircleNode *node) +{ + for (uint32_t r = 0; r < node->rank(); ++r) + { + if (not node->dim(r).known()) + return false; + } + return true; +} + +/** + * @brief return valid unsigned dim value from 0 ~ (rank-1) + * @note dim can be -rank to (rank-1) + */ +uint32_t valid_unsigned_dim(uint32_t rank, int32_t dim) +{ + int32_t irank = static_cast<int32_t>(rank); + return dim >= 0 ? static_cast<uint32_t>(dim) : static_cast<uint32_t>(irank + dim); +} + +/** + * @brief return TRUE if input dim is 1 for squeeze_dims values + */ +bool is_valid_input(const luci::CircleNode *node, const std::vector<int32_t> &squeeze_dims) +{ + auto rank = node->rank(); + for (auto dim : squeeze_dims) + { + auto udim = valid_unsigned_dim(rank, dim); + if (node->dim(udim).value() != 1) + return false; + } + return true; +} + +/** + * @brief return shape vector from input + */ +std::vector<uint32_t> node_shape(const luci::CircleNode *input) +{ + std::vector<uint32_t> shape; + uint32_t rank = input->rank(); + for (uint32_t r = 0; r < rank; ++r) + shape.push_back(input->dim(r).value()); + + return shape; +} + +/** + * @brief return CircleConst ptr with values of new_shape + */ +luci::CircleConst *create_shape_const(loco::Graph *graph, const std::vector<uint32_t> &new_shape) +{ + // NOTE dim_size can be 0 + uint32_t dim_size = static_cast<uint32_t>(new_shape.size()); + + auto shape_const = graph->nodes()->create<luci::CircleConst>(); + + // const shape/dtype + shape_const->dtype(loco::DataType::S32); + if (dim_size > 0) + { + shape_const->rank(1); + shape_const->dim(0).set(dim_size); + } + else + shape_const->rank(0); + shape_const->shape_status(luci::ShapeStatus::VALID); + + // constant values + shape_const->size<loco::DataType::S32>(dim_size); + for (uint32_t i = 0; i < dim_size; ++i) + shape_const->at<loco::DataType::S32>(i) = new_shape.at(i); + + return shape_const; +} + +bool substitute_squeeze_to_reshape(luci::CircleSqueeze *squeeze) +{ + assert(squeeze != nullptr); + + auto input = loco::must_cast<luci::CircleNode *>(squeeze->input()); + // we need input node shape and all dim should be known + if (input->shape_status() != luci::ShapeStatus::VALID) + return false; + if (not can_squeeze_shape(input)) + return false; + + // we will use squeeze shape for new shape + if (squeeze->shape_status() != luci::ShapeStatus::VALID) + return false; + + auto squeeze_dims = squeeze->squeeze_dims(); + if (not is_valid_input(input, squeeze_dims)) + throw std::runtime_error("Invalid values in squeeze_dims: " + squeeze->name()); + + auto name = squeeze->name(); + assert(name.length() > 0); + + auto reshape_shape = node_shape(squeeze); + auto graph = squeeze->graph(); + auto reshape = graph->nodes()->create<luci::CircleReshape>(); + auto shape_const = create_shape_const(graph, reshape_shape); + reshape->name(name + "/Reshape"); + luci::add_origin(reshape, luci::get_origin(squeeze)); + shape_const->name(name + "/Reshape/shape"); + + // graph connection + reshape->tensor(input); + reshape->shape(shape_const); + replace(squeeze).with(reshape); + + return true; +} + +} // namespace + +namespace luci +{ + +/** + * BEFORE + * | + * [CircleNode] + * | + * [CircleSqueeze] + * | + * [CircleNode] + * | + * + * AFTER + * | + * [CircleNode] [CircleConst] + * | \ / + * [CircleSqueeze] [CircleReshape] + * | + * [CircleNode] + * | + */ +bool SubstituteSqueezeToReshapePass::run(loco::Graph *g) +{ + bool changed = false; + for (auto node : loco::active_nodes(loco::output_nodes(g))) + { + if (auto squeeze = dynamic_cast<luci::CircleSqueeze *>(node)) + { + if (substitute_squeeze_to_reshape(squeeze)) + changed = true; + } + } + return changed; +} + +} // namespace luci diff --git a/compiler/luci/pass/src/SubstituteSqueezeToReshapePass.test.cpp b/compiler/luci/pass/src/SubstituteSqueezeToReshapePass.test.cpp new file mode 100644 index 000000000..d917af678 --- /dev/null +++ b/compiler/luci/pass/src/SubstituteSqueezeToReshapePass.test.cpp @@ -0,0 +1,208 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "luci/Pass/SubstituteSqueezeToReshapePass.h" +#include "luci/Pass/CircleShapeInferencePass.h" + +#include <luci/IR/CircleNodes.h> + +#include <gtest/gtest.h> + +namespace +{ + +using uilist = std::initializer_list<uint32_t>; +using ilist = std::initializer_list<int32_t>; + +class PassTestGraph +{ +public: + PassTestGraph() = default; + +public: + void init(const uilist shape_in, const uilist shape_out) + { + _graph_input = _g.inputs()->create(); + _graph_output = _g.outputs()->create(); + + _input = _g.nodes()->create<luci::CircleInput>(); + _input->shape(shape_in); + _input->shape_status(luci::ShapeStatus::VALID); + _input->name("input"); + + _output = _g.nodes()->create<luci::CircleOutput>(); + _output->shape(shape_out); + _output->shape_status(luci::ShapeStatus::VALID); + _output->name("output"); + + _input->index(_graph_input->index()); + _output->index(_graph_output->index()); + + auto input_shape = std::make_unique<loco::TensorShape>(); + set(input_shape.get(), shape_in); + _graph_input->shape(std::move(input_shape)); + + auto output_shape = std::make_unique<loco::TensorShape>(); + set(output_shape.get(), shape_out); + _graph_output->shape(std::move(output_shape)); + } + +protected: + void set(loco::TensorShape *shape, const uilist &values) + { + uint32_t r = 0; + shape->rank(values.size()); + for (auto v : values) + shape->dim(r++).set(v); + } + +public: + loco::Graph *g(void) { return &_g; } + luci::CircleOutput *output(void) { return _output; } + +protected: + loco::Graph _g; + loco::GraphInput *_graph_input = nullptr; + loco::GraphOutput *_graph_output = nullptr; + luci::CircleInput *_input = nullptr; + luci::CircleOutput *_output = nullptr; +}; + +class SubstituteSqueezeToReshapeGraph : public PassTestGraph +{ +public: + SubstituteSqueezeToReshapeGraph() = default; + +public: + void init(const uilist shape_in, const uilist shape_out, const ilist squeeze_dims) + { + PassTestGraph::init(shape_in, shape_out); + + _squeeze = _g.nodes()->create<luci::CircleSqueeze>(); + _squeeze->input(_input); + _squeeze->squeeze_dims(squeeze_dims); + _squeeze->name("squeeze"); + + _output->from(_squeeze); + } + +protected: + luci::CircleSqueeze *_squeeze = nullptr; +}; + +class SubstituteSqueezeToReshapeTest : public ::testing::Test +{ +public: + SubstituteSqueezeToReshapeTest() = default; + + void run_pass(void) + { + while (_shapeinf.run(_graph.g()) || _pass.run(_graph.g())) + ; + } + +protected: + SubstituteSqueezeToReshapeGraph _graph; + luci::SubstituteSqueezeToReshapePass _pass; + luci::CircleShapeInferencePass _shapeinf; +}; + +} // namespace + +TEST(SubstituteSqueezeToReshapePassTest, name) +{ + luci::SubstituteSqueezeToReshapePass pass; + auto const name = pass.name(); + ASSERT_NE(nullptr, name); +} + +TEST_F(SubstituteSqueezeToReshapeTest, simple_with_squeeze_dims) +{ + _graph.init({1, 16, 1, 1}, {1, 16}, {2, 3}); + + run_pass(); + + auto reshape = dynamic_cast<luci::CircleReshape *>(_graph.output()->from()); + auto squeeze = dynamic_cast<luci::CircleSqueeze *>(_graph.output()->from()); + ASSERT_NE(nullptr, reshape); + ASSERT_EQ(nullptr, squeeze); + auto reshape_shape = loco::must_cast<luci::CircleConst *>(reshape->shape()); + ASSERT_EQ(2, reshape_shape->size<loco::DataType::S32>()); + ASSERT_EQ(1, reshape_shape->at<loco::DataType::S32>(0)); + ASSERT_EQ(16, reshape_shape->at<loco::DataType::S32>(1)); +} + +TEST_F(SubstituteSqueezeToReshapeTest, simple_without_squeeze_dims) +{ + _graph.init({1, 16, 1, 1}, {16}, {}); + + run_pass(); + + auto reshape = dynamic_cast<luci::CircleReshape *>(_graph.output()->from()); + auto squeeze = dynamic_cast<luci::CircleSqueeze *>(_graph.output()->from()); + ASSERT_NE(nullptr, reshape); + ASSERT_EQ(nullptr, squeeze); + auto reshape_shape = loco::must_cast<luci::CircleConst *>(reshape->shape()); + ASSERT_EQ(1, reshape_shape->size<loco::DataType::S32>()); + ASSERT_EQ(16, reshape_shape->at<loco::DataType::S32>(0)); +} + +TEST_F(SubstituteSqueezeToReshapeTest, input_with_0_dims) +{ + _graph.init({1, 16, 0, 1}, {16, 0}, {}); + + run_pass(); + + auto reshape = dynamic_cast<luci::CircleReshape *>(_graph.output()->from()); + auto squeeze = dynamic_cast<luci::CircleSqueeze *>(_graph.output()->from()); + ASSERT_NE(nullptr, reshape); + ASSERT_EQ(nullptr, squeeze); + auto reshape_shape = loco::must_cast<luci::CircleConst *>(reshape->shape()); + ASSERT_EQ(2, reshape_shape->size<loco::DataType::S32>()); + ASSERT_EQ(16, reshape_shape->at<loco::DataType::S32>(0)); + ASSERT_EQ(0, reshape_shape->at<loco::DataType::S32>(1)); +} + +TEST_F(SubstituteSqueezeToReshapeTest, nothing_to_squeeze) +{ + _graph.init({2, 16, 16, 3}, {2, 16, 16, 3}, {}); + + run_pass(); + + auto reshape = dynamic_cast<luci::CircleReshape *>(_graph.output()->from()); + auto squeeze = dynamic_cast<luci::CircleSqueeze *>(_graph.output()->from()); + ASSERT_NE(nullptr, reshape); + ASSERT_EQ(nullptr, squeeze); +} + +TEST_F(SubstituteSqueezeToReshapeTest, all_to_squeeze) +{ + _graph.init({1, 1}, {}, {}); + + run_pass(); + + auto reshape = dynamic_cast<luci::CircleReshape *>(_graph.output()->from()); + auto squeeze = dynamic_cast<luci::CircleSqueeze *>(_graph.output()->from()); + ASSERT_NE(nullptr, reshape); + ASSERT_EQ(nullptr, squeeze); +} + +TEST_F(SubstituteSqueezeToReshapeTest, wrong_squeeze_dims_NEG) +{ + _graph.init({1, 16, 1, 1}, {1, 16, 1, 1}, {1}); + + // shape inference will throw for invalid squeeze_dims + EXPECT_THROW(run_pass(), std::exception); +} diff --git a/compiler/luci/pass/src/SubstituteTransposeToReshapePass.cpp b/compiler/luci/pass/src/SubstituteTransposeToReshapePass.cpp new file mode 100644 index 000000000..dfd5e6cf2 --- /dev/null +++ b/compiler/luci/pass/src/SubstituteTransposeToReshapePass.cpp @@ -0,0 +1,137 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Pass/SubstituteTransposeToReshapePass.h" + +#include <luci/IR/CircleNodes.h> +#include <luci/Profile/CircleNodeOrigin.h> + +namespace +{ + +/** + * @brief Convert transpose op in a certain condition to reshape op + * @details Convert transpose op if it have condition below + * 1. have a CircleConst perm value. + * 2. input have an unknown dimension less then 2 + * 3. the order of shape that except dim value 1 remains same on input and output + * eg) input shape = (126, 201, 1, 1) => (126, 201) + * output shape = (1, 126, 1, 201) => (126, 201) + */ +bool substitute_transpose_to_reshape(luci::CircleTranspose *node) +{ + auto perm_const = dynamic_cast<luci::CircleConst *>(node->perm()); + if (perm_const == nullptr) + return false; + + assert(perm_const->dtype() == loco::DataType::S32); + + auto input_node = loco::must_cast<luci::CircleNode *>(node->a()); + if (perm_const->dim(0).value() != input_node->rank()) + return false; + + // If input have more than 2 unknown dimension, transpose will not be changed. + int count = 0; + for (uint32_t i = 0; i < input_node->rank(); i++) + if (!input_node->dim(i).known()) + count++; + if (count > 1) + return false; + + uint32_t idx = 0; + auto size_items = perm_const->size<loco::DataType::S32>(); + for (uint32_t i = 0; i < size_items; i++) + { + assert(perm_const->at<loco::DataType::S32>(i) >= 0 && + perm_const->at<loco::DataType::S32>(i) < static_cast<int32_t>(input_node->rank())); + const auto perm_value = static_cast<uint32_t>(perm_const->at<loco::DataType::S32>(i)); + if (input_node->dim(perm_value).known() && input_node->dim(perm_value).value() == 1) + continue; + // To check idx values are increasing + if (idx > perm_value) + return false; + idx = perm_value; + } + + auto name = node->name(); + assert(name.length() > 0); + + auto new_const_node = node->graph()->nodes()->create<luci::CircleConst>(); + new_const_node->dtype(loco::DataType::S32); + new_const_node->size<loco::DataType::S32>(size_items); + new_const_node->shape_status(luci::ShapeStatus::VALID); + new_const_node->rank(1); + new_const_node->dim(0).set(size_items); + for (uint32_t i = 0; i < size_items; i++) + { + if (input_node->dim(static_cast<uint32_t>(perm_const->at<loco::DataType::S32>(i))).known()) + new_const_node->at<loco::DataType::S32>(i) = static_cast<int32_t>( + input_node->dim(static_cast<uint32_t>(perm_const->at<loco::DataType::S32>(i))).value()); + else + new_const_node->at<loco::DataType::S32>(i) = -1; + } + + auto new_reshape_node = node->graph()->nodes()->create<luci::CircleReshape>(); + new_reshape_node->tensor(input_node); + new_reshape_node->shape(new_const_node); + new_reshape_node->name(name + "/Reshape"); + luci::add_origin(new_reshape_node, luci::get_origin(node)); + new_const_node->name(name + "/Reshape/shape"); + + replace(node).with(new_reshape_node); + return true; +} + +} // namespace + +namespace luci +{ + +/** + * BEFORE + * + * [CircleNode] [CircleConst] + * \ / + * [CircleTranspose] + * | + * [CircleNode] + * + * AFTER + * + * [CircleNode] [CircleConst] + * \ / + * [CircleReshape] + * | + * [CircleNode] + * + */ +bool SubstituteTransposeToReshapePass::run(loco::Graph *g) +{ + bool changed = false; + for (auto node : loco::active_nodes(loco::output_nodes(g))) + { + if (auto circle_node = dynamic_cast<luci::CircleTranspose *>(node)) + { + if (substitute_transpose_to_reshape(circle_node)) + { + changed = true; + } + } + } + return changed; +} + +} // namespace luci diff --git a/compiler/luci/pass/src/SubstituteTransposeToReshapePass.test.cpp b/compiler/luci/pass/src/SubstituteTransposeToReshapePass.test.cpp new file mode 100644 index 000000000..f81f7e615 --- /dev/null +++ b/compiler/luci/pass/src/SubstituteTransposeToReshapePass.test.cpp @@ -0,0 +1,120 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "luci/Pass/SubstituteTransposeToReshapePass.h" + +#include <luci/IR/CircleNodes.h> + +#include <gtest/gtest.h> + +namespace +{ + +class SubstituteTransposeToReshapeTest : public ::testing::Test +{ +public: + SubstituteTransposeToReshapeTest() {} + + void buildGraph(const std::initializer_list<uint32_t> shape, const std::vector<int32_t> perm) + { + // Input Create. + input = g.nodes()->create<luci::CircleInput>(); + auto graph_input = g.inputs()->create(); + input->index(graph_input->index()); + input->shape_status(luci::ShapeStatus::VALID); + input->rank(shape.size()); + input->shape(shape); + input->name("input"); + + // Permutation Create. + auto perm_const = g.nodes()->create<luci::CircleConst>(); + perm_const->dtype(loco::DataType::S32); + perm_const->size<loco::DataType::S32>(perm.size()); + perm_const->shape_status(luci::ShapeStatus::VALID); + perm_const->rank(1); + perm_const->dim(0).set(perm.size()); + for (uint32_t i = 0; i < static_cast<uint32_t>(perm.size()); i++) + { + perm_const->at<loco::DataType::S32>(i) = perm.at(i); + } + perm_const->name("perm_const"); + + // Transpose Create. + auto transpose_node = g.nodes()->create<luci::CircleTranspose>(); + transpose_node->a(input); + transpose_node->perm(perm_const); + transpose_node->name("transpose_node"); + + // Output Connect. + output = g.nodes()->create<luci::CircleOutput>(); + output->from(transpose_node); + auto graph_output = g.outputs()->create(); + output->index(graph_output->index()); + output->name("output"); + } + +public: + loco::Graph g; + luci::CircleInput *input = nullptr; + luci::CircleOutput *output = nullptr; +}; + +} // namespace + +TEST(SubstituteTransposeToReshapePassTest, name) +{ + luci::SubstituteTransposeToReshapePass pass; + auto const name = pass.name(); + ASSERT_NE(nullptr, name); +} + +TEST_F(SubstituteTransposeToReshapeTest, simple_case) +{ + // Create graph that tranpose input {126, 201, 1, 1} with permutation {2, 0, 3, 1} + buildGraph({126, 201, 1, 1}, std::vector<int32_t>({2, 0, 3, 1})); + // With this input shape and permutation values, output shape will be [1, 126, 1, 201]. + // The order of non-one values is unchanged (126, 201). + // So this Transpose op can be converted to Reshape op. + luci::SubstituteTransposeToReshapePass pass; + while (pass.run(&g)) + ; + + auto reshape_node = dynamic_cast<luci::CircleReshape *>(output->from()); + auto transpose_node = dynamic_cast<luci::CircleTranspose *>(output->from()); + ASSERT_NE(nullptr, reshape_node); + ASSERT_EQ(nullptr, transpose_node); + auto new_shape = loco::must_cast<luci::CircleConst *>(reshape_node->shape()); + ASSERT_EQ(1, new_shape->at<loco::DataType::S32>(0)); + ASSERT_EQ(126, new_shape->at<loco::DataType::S32>(1)); + ASSERT_EQ(1, new_shape->at<loco::DataType::S32>(2)); + ASSERT_EQ(201, new_shape->at<loco::DataType::S32>(3)); +} + +TEST_F(SubstituteTransposeToReshapeTest, failed_to_substitute_NEG) +{ + // Create graph that tranpose input {126, 201, 1, 1} with permutation {2, 1, 3, 0} + buildGraph({126, 201, 1, 1}, std::vector<int32_t>({2, 1, 3, 0})); + // With this input shape and permutation values, output shape will be [1, 201, 1, 126]. + // The order of non-one values is changed (126, 201) -> (201, 126). + // So this Transpose op cannot be converted to Reshape op. + luci::SubstituteTransposeToReshapePass pass; + while (pass.run(&g)) + ; + + auto reshape_node = dynamic_cast<luci::CircleReshape *>(output->from()); + auto transpose_node = dynamic_cast<luci::CircleTranspose *>(output->from()); + ASSERT_EQ(nullptr, reshape_node); + ASSERT_NE(nullptr, transpose_node); +} diff --git a/compiler/luci/pass/src/TransformMinMaxToRelu6Pass.cpp b/compiler/luci/pass/src/TransformMinMaxToRelu6Pass.cpp new file mode 100644 index 000000000..c15a3b676 --- /dev/null +++ b/compiler/luci/pass/src/TransformMinMaxToRelu6Pass.cpp @@ -0,0 +1,134 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Pass/TransformMinMaxToRelu6Pass.h" + +#include "helpers/NodeFiller.h" +#include "helpers/TypeMapper.h" + +#include <luci/IR/CircleNodes.h> +#include <luci/Profile/CircleNodeOrigin.h> + +namespace +{ + +template <loco::DataType DT> +bool is_scalar_with_value(luci::CircleConst *node, typename loco::DataTypeImpl<DT>::Type val) +{ + if (node->dtype() != DT) + return false; + if (node->rank() != 0) + return false; + if (node->size<DT>() != 1) + return false; + if (node->at<DT>(0) != static_cast<typename loco::DataTypeImpl<DT>::Type>(val)) + return false; + + return true; +} + +/** + * BEFORE + * [CircleNode] + * | + * [CircleMinimum] + * | + * [CircleMaximum] + * | + * [CircleNode] + * + * AFTER + * + * [CircleNode] + * | + * [CircleRelu6] + * | + * [CircleNode] + * + * NOTE Only max(min(input, 6), 0) pattern will be transformed. + */ +template <loco::DataType DT> bool transform_min_max_pattern(luci::CircleMaximum *maxi) +{ + if (not maxi) + return false; + + if (maxi->dtype() != DT) + return false; + + luci::CircleConst *maxi_const = nullptr; + luci::CircleMinimum *mini = nullptr; + + // There are two ways Maximum takes inputs. + // 1. Maximum(x = CircleConst, y = CircleMinimum) + // 2. Maximum(x = CircleMinimum, y = CircleConst) + if (not luci::fill(&maxi_const, &mini).with_commutative_args_of(maxi)) + return false; + + // Maximum constant should be scalar whose value is 0. + if (not is_scalar_with_value<DT>(maxi_const, + static_cast<typename loco::DataTypeImpl<DT>::Type>(0))) + return false; + + luci::CircleConst *mini_const = nullptr; + loco::Node *mini_input = nullptr; + + // There are two ways Miminum takes inputs. + // 1. Miminum(x = CircleNode, y = CircleMinimum) + // 2. Miminum(x = CircleMinimum, y = CircleNode) + if (not luci::fill(&mini_const, &mini_input).with_commutative_args_of(mini)) + return false; + + // Miminum constant should be scalar whose value is 6. + if (not is_scalar_with_value<DT>(mini_const, + static_cast<typename loco::DataTypeImpl<DT>::Type>(6))) + return false; + + auto name = maxi->name(); + assert(name.length() > 0); + + // Create Relu6 op + auto relu6 = mini->graph()->nodes()->create<luci::CircleRelu6>(); + relu6->features(mini_input); + relu6->name(name + "/Relu6"); + luci::add_origin(relu6, luci::composite_origin({luci::get_origin(maxi), luci::get_origin(mini)})); + + replace(maxi).with(relu6); + + return true; +} + +} // namespace + +namespace luci +{ + +bool TransformMinMaxToRelu6Pass::run(loco::Graph *g) +{ + bool changed = false; + + for (auto node : loco::active_nodes(loco::output_nodes(g))) + { + if (auto maxi = dynamic_cast<luci::CircleMaximum *>(node)) + { + if (transform_min_max_pattern<loco::DataType::FLOAT32>(maxi)) + changed = true; + } + } + + return changed; +} + +} // namespace luci diff --git a/compiler/luci/pass/src/TransformMinMaxToRelu6Pass.test.cpp b/compiler/luci/pass/src/TransformMinMaxToRelu6Pass.test.cpp new file mode 100644 index 000000000..9755a70cf --- /dev/null +++ b/compiler/luci/pass/src/TransformMinMaxToRelu6Pass.test.cpp @@ -0,0 +1,151 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Pass/TransformMinMaxToRelu6Pass.h" + +#include <luci/IR/CircleNodes.h> + +#include <gtest/gtest.h> + +namespace +{ + +/** + * Minimum-Maximum pattern graph + * + * [CircleInput] [CircleConst] + * \ / + * [CircleMinimum] [CircleConst] + * | / + * [CircleMaximum] + * | + * [CircleOutput] + */ +struct MinMaxGraph +{ + loco::Graph _g; + luci::CircleInput *_input = nullptr; + luci::CircleMinimum *_mini = nullptr; + luci::CircleConst *_mini_const = nullptr; + luci::CircleMaximum *_maxi = nullptr; + luci::CircleConst *_maxi_const = nullptr; + luci::CircleOutput *_output = nullptr; +}; + +class TransformMinMaxToRelu6PassTest : public ::testing::Test +{ +protected: + virtual void SetUp() + { + const int N = 1; + const int H = 4; + const int W = 4; + const int C = 3; + + // graph input and output + auto graph_input = _min_max_g._g.inputs()->create(); + auto graph_output = _min_max_g._g.outputs()->create(); + + // CircleInput + _min_max_g._input = _min_max_g._g.nodes()->create<luci::CircleInput>(); + _min_max_g._input->index(graph_input->index()); + _min_max_g._input->shape({N, H, W, C}); + _min_max_g._input->dtype(loco::DataType::FLOAT32); + _min_max_g._input->name("input"); + + // CircleConst + _min_max_g._mini_const = _min_max_g._g.nodes()->create<luci::CircleConst>(); + _min_max_g._mini_const->shape({}); // scalar + _min_max_g._mini_const->dtype(loco::DataType::FLOAT32); + _min_max_g._mini_const->size<loco::DataType::FLOAT32>(1); + _min_max_g._mini_const->at<loco::DataType::FLOAT32>(0) = 6.; + _min_max_g._mini_const->name("mini_const"); + + // CircleMinimum + _min_max_g._mini = _min_max_g._g.nodes()->create<luci::CircleMinimum>(); + _min_max_g._mini->x(_min_max_g._input); + _min_max_g._mini->y(_min_max_g._mini_const); + _min_max_g._mini->shape({N, H, W, C}); + _min_max_g._mini->dtype(loco::DataType::FLOAT32); + _min_max_g._mini->name("mini"); + + // CircleConst + _min_max_g._maxi_const = _min_max_g._g.nodes()->create<luci::CircleConst>(); + _min_max_g._mini_const->shape({}); // scalar + _min_max_g._maxi_const->dtype(loco::DataType::FLOAT32); + _min_max_g._maxi_const->size<loco::DataType::FLOAT32>(1); + _min_max_g._maxi_const->at<loco::DataType::FLOAT32>(0) = 0.; + _min_max_g._maxi_const->name("maxi_const"); + + // CircleMaximum + _min_max_g._maxi = _min_max_g._g.nodes()->create<luci::CircleMaximum>(); + _min_max_g._maxi->x(_min_max_g._mini); + _min_max_g._maxi->y(_min_max_g._maxi_const); + _min_max_g._maxi->shape({N, H, W, C}); + _min_max_g._maxi->dtype(loco::DataType::FLOAT32); + _min_max_g._maxi->name("maxi"); + + // CircleOutput + _min_max_g._output = _min_max_g._g.nodes()->create<luci::CircleOutput>(); + _min_max_g._output->index(graph_output->index()); + _min_max_g._output->from(_min_max_g._maxi); + _min_max_g._output->shape({N, H, W, C}); + _min_max_g._output->dtype(loco::DataType::FLOAT32); + _min_max_g._output->name("output"); + } + +protected: + luci::TransformMinMaxToRelu6Pass _pass; + MinMaxGraph _min_max_g; +}; + +} // namespace + +TEST_F(TransformMinMaxToRelu6PassTest, name) +{ + auto const name = _pass.name(); + ASSERT_NE(nullptr, name); +} + +/** + * Optimized graph looks like below. + * + * [CircleInput] + * | + * [CircleRelu6] + * | + * [CircleOutput] + */ +TEST_F(TransformMinMaxToRelu6PassTest, simple_test) +{ + auto ret = _pass.run(&_min_max_g._g); + EXPECT_TRUE(ret); + + auto relu6 = dynamic_cast<luci::CircleRelu6 *>(_min_max_g._output->from()); + EXPECT_NE(nullptr, relu6); + + auto input = dynamic_cast<luci::CircleInput *>(relu6->features()); + EXPECT_NE(nullptr, input); +} + +TEST_F(TransformMinMaxToRelu6PassTest, wrong_condition_NEG) +{ + _min_max_g._maxi_const->at<loco::DataType::FLOAT32>(0) = 2.; + + auto ret = _pass.run(&_min_max_g._g); + + EXPECT_FALSE(ret); +} diff --git a/compiler/luci/pass/src/TypeInferencePass.cpp b/compiler/luci/pass/src/TypeInferencePass.cpp deleted file mode 100644 index 63744045c..000000000 --- a/compiler/luci/pass/src/TypeInferencePass.cpp +++ /dev/null @@ -1,55 +0,0 @@ -/* - * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "luci/Pass/TypeInferencePass.h" - -#include <luci/IR/CircleDialect.h> -#include <luci/Service/CircleTypeInferenceRule.h> - -#include <loco.h> -#include <loco/IR/CanonicalDialect.h> -#include <loco/Service/TypeInference.h> - -namespace luci -{ - -bool TypeInferencePass::run(luci::Module *m) -{ - bool changed = false; - - for (size_t g = 0; g < m->size(); ++g) - { - if (run(m->graph(g))) - changed = true; - } - - return changed; -} - -bool TypeInferencePass::run(loco::Graph *g) -{ - loco::CanonicalTypeInferenceRule canonical_rule; - luci::CircleTypeInferenceRule circle_rule; - - loco::MultiDialectTypeInferenceRule rules; - - rules.bind(loco::CanonicalDialect::get(), &canonical_rule) - .bind(luci::CircleDialect::get(), &circle_rule); - - return loco::apply(&rules).to(g); -} - -} // namespace luci diff --git a/compiler/luci/pass/src/VerifyQuantizedNodeChannelWiseGranularity.h b/compiler/luci/pass/src/VerifyQuantizedNodeChannelWiseGranularity.h new file mode 100644 index 000000000..32f0d1a34 --- /dev/null +++ b/compiler/luci/pass/src/VerifyQuantizedNodeChannelWiseGranularity.h @@ -0,0 +1,401 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_VERIFY_QUANTIZED_NODE_CHANNELWISE_GRANULARITY_H__ +#define __LUCI_VERIFY_QUANTIZED_NODE_CHANNELWISE_GRANULARITY_H__ + +#include <luci/IR/CircleNodes.h> +#include <luci/IR/CircleNodeVisitor.h> +#include <luci/Pass/QuantizationParameters.h> + +using Granularity = luci::QuantizationGranularity; + +// This macro is undef at the end of the file +#define RETURN_FALSE_UNLESS(ARG) \ + if (not(ARG)) \ + { \ + return false; \ + } + +namespace luci +{ + +/** + * @brief Verify the granualrity of channel-wise quantized node + * @details + * + * Targets to verify + * - node's output (i.e., node itself) + * - node's inputs + */ +struct VerifyQuantizedNodeChannelWiseGranularity final : public luci::CircleNodeVisitor<bool> +{ +private: + bool is_lwq(const loco::Node *node) + { + auto circle_node = loco::must_cast<const luci::CircleNode *>(node); + + if (circle_node->quantparam() == nullptr) + return false; + + if (circle_node->quantparam()->scale.size() != 1) + return false; + + if (circle_node->quantparam()->zerop.size() != 1) + return false; + + return true; + } + + uint32_t rank(const loco::Node *node) + { + auto circle_node = loco::must_cast<const luci::CircleNode *>(node); + return circle_node->rank(); + } + + bool is_cwq_const(const loco::Node *node, uint32_t channel_dim) + { + auto circle_node = loco::must_cast<const luci::CircleConst *>(node); + + assert(channel_dim < circle_node->rank()); // FIX_CALLER_UNLESS + auto channel_size = circle_node->dim(channel_dim).value(); + + if (circle_node->quantparam() == nullptr) + return false; + + if (circle_node->quantparam()->quantized_dimension != static_cast<int32_t>(channel_dim)) + return false; + + if (circle_node->quantparam()->scale.size() != channel_size) + return false; + + if (circle_node->quantparam()->zerop.size() != channel_size) + return false; + + return true; + } + +private: + bool visit(const luci::CircleConv2D *node) + { + RETURN_FALSE_UNLESS(is_lwq(node)) + RETURN_FALSE_UNLESS(is_lwq(node->input())) + RETURN_FALSE_UNLESS(is_cwq_const(node->filter(), 0)) + RETURN_FALSE_UNLESS(is_cwq_const(node->bias(), rank(node->bias()) - 1)) + return true; + } + + bool visit(const luci::CircleConcatenation *node) + { + RETURN_FALSE_UNLESS(is_lwq(node)) + for (uint32_t i = 0; i < node->numValues(); i++) + { + RETURN_FALSE_UNLESS(is_lwq(node->values(i))); + } + return true; + } + + bool visit(const luci::CircleDepthToSpace *node) + { + RETURN_FALSE_UNLESS(is_lwq(node)) + RETURN_FALSE_UNLESS(is_lwq(node->input())) + return true; + } + + bool visit(const luci::CircleDepthwiseConv2D *node) + { + RETURN_FALSE_UNLESS(is_lwq(node)) + RETURN_FALSE_UNLESS(is_lwq(node->input())) + RETURN_FALSE_UNLESS(is_cwq_const(node->filter(), 3)) + RETURN_FALSE_UNLESS(is_cwq_const(node->bias(), rank(node->bias()) - 1)) + return true; + } + + bool visit(const luci::CircleInstanceNorm *node) + { + RETURN_FALSE_UNLESS(is_lwq(node)) + RETURN_FALSE_UNLESS(is_lwq(node->input())) + RETURN_FALSE_UNLESS(is_cwq_const(node->gamma(), rank(node->gamma()) - 1)) + RETURN_FALSE_UNLESS(is_cwq_const(node->beta(), rank(node->beta()) - 1)) + return true; + } + + bool visit(const luci::CirclePad *node) + { + RETURN_FALSE_UNLESS(is_lwq(node)) + RETURN_FALSE_UNLESS(is_lwq(node->input())) + return true; + } + + bool visit(const luci::CirclePRelu *node) + { + RETURN_FALSE_UNLESS(is_lwq(node)) + RETURN_FALSE_UNLESS(is_lwq(node->input())) + RETURN_FALSE_UNLESS(is_cwq_const(node->alpha(), rank(node->alpha()) - 1)) + return true; + } + + bool visit(const luci::CircleTransposeConv *node) + { + RETURN_FALSE_UNLESS(is_lwq(node)) + RETURN_FALSE_UNLESS(is_lwq(node->outBackprop())) + RETURN_FALSE_UNLESS(is_cwq_const(node->filter(), 0)) + luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias()); + if (bias != nullptr) + RETURN_FALSE_UNLESS(is_cwq_const(node->bias(), rank(node->bias()) - 1)) + + return true; + } + + bool visit(const luci::CircleFullyConnected *node) + { + RETURN_FALSE_UNLESS(is_lwq(node)) + RETURN_FALSE_UNLESS(is_lwq(node->input())) + RETURN_FALSE_UNLESS(is_cwq_const(node->weights(), 0)) + RETURN_FALSE_UNLESS(is_cwq_const(node->bias(), rank(node->bias()) - 1)) + return true; + } + + bool visit(const luci::CircleAdd *node) + { + RETURN_FALSE_UNLESS(is_lwq(node)); + RETURN_FALSE_UNLESS(is_lwq(node->x())); + RETURN_FALSE_UNLESS(is_lwq(node->y())); + return true; + } + + bool visit(const luci::CircleAveragePool2D *node) + { + RETURN_FALSE_UNLESS(is_lwq(node)); + RETURN_FALSE_UNLESS(is_lwq(node->value())); + return true; + } + + bool visit(const luci::CircleLogicalOr *) + { + // Logical OR has bool-type inputs and output + // Nothing to be checked + return true; + } + + bool visit(const luci::CircleMaxPool2D *node) + { + RETURN_FALSE_UNLESS(is_lwq(node)); + RETURN_FALSE_UNLESS(is_lwq(node->value())); + return true; + } + + bool visit(const luci::CircleMean *node) + { + RETURN_FALSE_UNLESS(is_lwq(node)); + RETURN_FALSE_UNLESS(is_lwq(node->input())); + return true; + } + + bool visit(const luci::CircleMul *node) + { + RETURN_FALSE_UNLESS(is_lwq(node)); + RETURN_FALSE_UNLESS(is_lwq(node->x())); + RETURN_FALSE_UNLESS(is_lwq(node->y())); + return true; + } + + bool visit(const luci::CircleNotEqual *node) + { + RETURN_FALSE_UNLESS(is_lwq(node->x())); + RETURN_FALSE_UNLESS(is_lwq(node->y())); + return true; + } + + bool visit(const luci::CircleRelu *node) + { + RETURN_FALSE_UNLESS(is_lwq(node)); + RETURN_FALSE_UNLESS(is_lwq(node->features())); + return true; + } + + bool visit(const luci::CircleReshape *node) + { + RETURN_FALSE_UNLESS(is_lwq(node)) + RETURN_FALSE_UNLESS(is_lwq(node->tensor())); + return true; + } + + bool visit(const luci::CircleLogistic *node) + { + RETURN_FALSE_UNLESS(is_lwq(node)); + RETURN_FALSE_UNLESS(is_lwq(node->x())); + return true; + } + + bool visit(const luci::CircleSoftmax *node) + { + RETURN_FALSE_UNLESS(is_lwq(node)); + RETURN_FALSE_UNLESS(is_lwq(node->logits())); + return true; + } + + bool visit(const luci::CircleSpaceToBatchND *node) + { + RETURN_FALSE_UNLESS(is_lwq(node)); + RETURN_FALSE_UNLESS(is_lwq(node->input())); + return true; + } + + bool visit(const luci::CircleSpaceToDepth *node) + { + RETURN_FALSE_UNLESS(is_lwq(node)); + RETURN_FALSE_UNLESS(is_lwq(node->input())); + return true; + } + + bool visit(const luci::CircleSlice *node) + { + RETURN_FALSE_UNLESS(is_lwq(node)); + RETURN_FALSE_UNLESS(is_lwq(node->input())); + return true; + } + + bool visit(const luci::CircleSplit *node) + { + // node's output is the input of CircleSplitOut, thus not quantized + RETURN_FALSE_UNLESS(is_lwq(node->input())); + return true; + } + + bool visit(const luci::CircleSplitOut *node) + { + RETURN_FALSE_UNLESS(is_lwq(node)); + return true; + } + + bool visit(const luci::CircleStridedSlice *node) + { + RETURN_FALSE_UNLESS(is_lwq(node)); + RETURN_FALSE_UNLESS(is_lwq(node->input())); + return true; + } + + bool visit(const luci::CircleArgMax *node) + { + // node's output is index, thus not quantized + RETURN_FALSE_UNLESS(is_lwq(node->input())); + return true; + } + + bool visit(const luci::CircleBatchToSpaceND *node) + { + RETURN_FALSE_UNLESS(is_lwq(node)); + RETURN_FALSE_UNLESS(is_lwq(node->input())); + return true; + } + + bool visit(const luci::CircleTanh *node) + { + RETURN_FALSE_UNLESS(is_lwq(node)); + RETURN_FALSE_UNLESS(is_lwq(node->x())); + return true; + } + + bool visit(const luci::CircleTranspose *node) + { + RETURN_FALSE_UNLESS(is_lwq(node)); + RETURN_FALSE_UNLESS(is_lwq(node->a())); + return true; + } + + bool visit(const luci::CircleFloor *node) + { + RETURN_FALSE_UNLESS(is_lwq(node)); + RETURN_FALSE_UNLESS(is_lwq(node->x())); + return true; + } + + bool visit(const luci::CircleGreater *node) + { + RETURN_FALSE_UNLESS(is_lwq(node->x())); + RETURN_FALSE_UNLESS(is_lwq(node->y())); + return true; + } + + bool visit(const luci::CircleGreaterEqual *node) + { + RETURN_FALSE_UNLESS(is_lwq(node->x())); + RETURN_FALSE_UNLESS(is_lwq(node->y())); + return true; + } + + bool visit(const luci::CircleDiv *node) + { + RETURN_FALSE_UNLESS(is_lwq(node)); + RETURN_FALSE_UNLESS(is_lwq(node->x())); + RETURN_FALSE_UNLESS(is_lwq(node->y())); + return true; + } + + bool visit(const luci::CircleFloorDiv *node) + { + RETURN_FALSE_UNLESS(is_lwq(node)); + RETURN_FALSE_UNLESS(is_lwq(node->x())); + RETURN_FALSE_UNLESS(is_lwq(node->y())); + return true; + } + + bool visit(const luci::CircleRsqrt *node) + { + RETURN_FALSE_UNLESS(is_lwq(node)); + RETURN_FALSE_UNLESS(is_lwq(node->x())); + return true; + } + + bool visit(const luci::CircleSqrt *node) + { + RETURN_FALSE_UNLESS(is_lwq(node)); + RETURN_FALSE_UNLESS(is_lwq(node->x())); + return true; + } + + bool visit(const luci::CircleElu *node) + { + RETURN_FALSE_UNLESS(is_lwq(node)); + RETURN_FALSE_UNLESS(is_lwq(node->features())); + return true; + } + + bool visit(const luci::CirclePow *node) + { + RETURN_FALSE_UNLESS(is_lwq(node)); + RETURN_FALSE_UNLESS(is_lwq(node->x())); + RETURN_FALSE_UNLESS(is_lwq(node->y())); + return true; + } + + bool visit(const luci::CircleResizeBilinear *node) + { + RETURN_FALSE_UNLESS(is_lwq(node)); + RETURN_FALSE_UNLESS(is_lwq(node->input())); + return true; + } + + // TODO: Implement more Ops + + bool visit(const luci::CircleNode *) { return true; } +}; + +} // namespace luci + +#undef RETURN_FALSE_UNLESS + +#endif // __LUCI_VERIFY_QUANTIZED_NODE_CHANNELWISE_GRANULARITY_H__ diff --git a/compiler/luci/pass/src/VerifyQuantizedNodeLayerWiseGranularity.h b/compiler/luci/pass/src/VerifyQuantizedNodeLayerWiseGranularity.h new file mode 100644 index 000000000..1e6fd53c0 --- /dev/null +++ b/compiler/luci/pass/src/VerifyQuantizedNodeLayerWiseGranularity.h @@ -0,0 +1,388 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_VERIFY_QUANTIZED_NODE_LAYERWISE_GRANULARITY_H__ +#define __LUCI_VERIFY_QUANTIZED_NODE_LAYERWISE_GRANULARITY_H__ + +#include <luci/IR/CircleNodes.h> +#include <luci/IR/CircleNodeVisitor.h> +#include <luci/Pass/QuantizationParameters.h> + +using Granularity = luci::QuantizationGranularity; + +// This macro is undef at the end of the file +#define RETURN_FALSE_UNLESS(ARG) \ + if (not(ARG)) \ + { \ + return false; \ + } + +namespace luci +{ + +/** + * @brief Verify the granualrity of layer-wise quantized node + * @details + * + * Targets to verify + * - node's output (i.e., node itself) + * - node's inputs + */ +struct VerifyQuantizedNodeLayerWiseGranularity final : public luci::CircleNodeVisitor<bool> +{ +private: + bool is_lwq(const loco::Node *node) + { + auto circle_node = loco::must_cast<const luci::CircleNode *>(node); + + if (circle_node->quantparam() == nullptr) + return false; + + if (circle_node->quantparam()->scale.size() != 1) + return false; + + if (circle_node->quantparam()->zerop.size() != 1) + return false; + + return true; + } + + bool is_lwq_const(const loco::Node *node) + { + auto circle_node = loco::must_cast<const luci::CircleConst *>(node); + + if (circle_node->quantparam() == nullptr) + return false; + + if (circle_node->quantparam()->scale.size() != 1) + return false; + + if (circle_node->quantparam()->zerop.size() != 1) + return false; + + return true; + } + +private: + bool visit(const luci::CircleConv2D *node) + { + RETURN_FALSE_UNLESS(is_lwq(node)) + RETURN_FALSE_UNLESS(is_lwq(node->input())) + RETURN_FALSE_UNLESS(is_lwq_const(node->filter())) + RETURN_FALSE_UNLESS(is_lwq_const(node->bias())) + return true; + } + + bool visit(const luci::CircleConcatenation *node) + { + RETURN_FALSE_UNLESS(is_lwq(node)) + for (uint32_t i = 0; i < node->numValues(); i++) + { + RETURN_FALSE_UNLESS(is_lwq(node->values(i))); + } + return true; + } + + bool visit(const luci::CircleDepthToSpace *node) + { + RETURN_FALSE_UNLESS(is_lwq(node)) + RETURN_FALSE_UNLESS(is_lwq(node->input())) + return true; + } + + bool visit(const luci::CircleDepthwiseConv2D *node) + { + RETURN_FALSE_UNLESS(is_lwq(node)) + RETURN_FALSE_UNLESS(is_lwq(node->input())) + RETURN_FALSE_UNLESS(is_lwq_const(node->filter())) + RETURN_FALSE_UNLESS(is_lwq_const(node->bias())) + return true; + } + + bool visit(const luci::CircleInstanceNorm *node) + { + RETURN_FALSE_UNLESS(is_lwq(node)) + RETURN_FALSE_UNLESS(is_lwq(node->input())) + RETURN_FALSE_UNLESS(is_lwq_const(node->gamma())) + RETURN_FALSE_UNLESS(is_lwq_const(node->beta())) + return true; + } + + bool visit(const luci::CirclePad *node) + { + RETURN_FALSE_UNLESS(is_lwq(node)) + RETURN_FALSE_UNLESS(is_lwq(node->input())) + return true; + } + + bool visit(const luci::CirclePRelu *node) + { + RETURN_FALSE_UNLESS(is_lwq(node)) + RETURN_FALSE_UNLESS(is_lwq(node->input())) + RETURN_FALSE_UNLESS(is_lwq_const(node->alpha())) + return true; + } + + bool visit(const luci::CircleTransposeConv *node) + { + RETURN_FALSE_UNLESS(is_lwq(node)) + RETURN_FALSE_UNLESS(is_lwq(node->outBackprop())) + RETURN_FALSE_UNLESS(is_lwq_const(node->filter())) + luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias()); + if (bias != nullptr) + RETURN_FALSE_UNLESS(is_lwq_const(node->bias())) + return true; + } + + bool visit(const luci::CircleFullyConnected *node) + { + RETURN_FALSE_UNLESS(is_lwq(node)) + RETURN_FALSE_UNLESS(is_lwq(node->input())) + RETURN_FALSE_UNLESS(is_lwq_const(node->weights())) + RETURN_FALSE_UNLESS(is_lwq_const(node->bias())) + return true; + } + + bool visit(const luci::CircleAdd *node) + { + RETURN_FALSE_UNLESS(is_lwq(node)) + RETURN_FALSE_UNLESS(is_lwq(node->x())); + RETURN_FALSE_UNLESS(is_lwq(node->y())); + return true; + } + + bool visit(const luci::CircleAveragePool2D *node) + { + RETURN_FALSE_UNLESS(is_lwq(node)) + RETURN_FALSE_UNLESS(is_lwq(node->value())); + return true; + } + + bool visit(const luci::CircleLogicalOr *) + { + // Logical OR has bool-type inputs and output + // Nothing to be checked + return true; + } + + bool visit(const luci::CircleMaxPool2D *node) + { + RETURN_FALSE_UNLESS(is_lwq(node)) + RETURN_FALSE_UNLESS(is_lwq(node->value())); + return true; + } + + bool visit(const luci::CircleMean *node) + { + RETURN_FALSE_UNLESS(is_lwq(node)) + RETURN_FALSE_UNLESS(is_lwq(node->input())); + return true; + } + + bool visit(const luci::CircleMul *node) + { + RETURN_FALSE_UNLESS(is_lwq(node)) + RETURN_FALSE_UNLESS(is_lwq(node->x())); + RETURN_FALSE_UNLESS(is_lwq(node->y())); + return true; + } + + bool visit(const luci::CircleNotEqual *node) + { + RETURN_FALSE_UNLESS(is_lwq(node->x())); + RETURN_FALSE_UNLESS(is_lwq(node->y())); + return true; + } + + bool visit(const luci::CircleRelu *node) + { + RETURN_FALSE_UNLESS(is_lwq(node)) + RETURN_FALSE_UNLESS(is_lwq(node->features())); + return true; + } + + bool visit(const luci::CircleReshape *node) + { + RETURN_FALSE_UNLESS(is_lwq(node)) + RETURN_FALSE_UNLESS(is_lwq(node->tensor())); + return true; + } + + bool visit(const luci::CircleLogistic *node) + { + RETURN_FALSE_UNLESS(is_lwq(node)); + RETURN_FALSE_UNLESS(is_lwq(node->x())); + return true; + } + + bool visit(const luci::CircleSoftmax *node) + { + RETURN_FALSE_UNLESS(is_lwq(node)); + RETURN_FALSE_UNLESS(is_lwq(node->logits())); + return true; + } + + bool visit(const luci::CircleSpaceToBatchND *node) + { + RETURN_FALSE_UNLESS(is_lwq(node)); + RETURN_FALSE_UNLESS(is_lwq(node->input())); + return true; + } + + bool visit(const luci::CircleSpaceToDepth *node) + { + RETURN_FALSE_UNLESS(is_lwq(node)); + RETURN_FALSE_UNLESS(is_lwq(node->input())); + return true; + } + + bool visit(const luci::CircleSlice *node) + { + RETURN_FALSE_UNLESS(is_lwq(node)); + RETURN_FALSE_UNLESS(is_lwq(node->input())); + return true; + } + + bool visit(const luci::CircleSplit *node) + { + // node's output is the input of CircleSplitOut, thus not quantized + RETURN_FALSE_UNLESS(is_lwq(node->input())); + return true; + } + + bool visit(const luci::CircleSplitOut *node) + { + RETURN_FALSE_UNLESS(is_lwq(node)); + return true; + } + + bool visit(const luci::CircleStridedSlice *node) + { + RETURN_FALSE_UNLESS(is_lwq(node)); + RETURN_FALSE_UNLESS(is_lwq(node->input())); + return true; + } + + bool visit(const luci::CircleArgMax *node) + { + // node's output is index, thus not quantized + RETURN_FALSE_UNLESS(is_lwq(node->input())); + return true; + } + + bool visit(const luci::CircleBatchToSpaceND *node) + { + RETURN_FALSE_UNLESS(is_lwq(node)); + RETURN_FALSE_UNLESS(is_lwq(node->input())); + return true; + } + + bool visit(const luci::CircleTanh *node) + { + RETURN_FALSE_UNLESS(is_lwq(node)); + RETURN_FALSE_UNLESS(is_lwq(node->x())); + return true; + } + + bool visit(const luci::CircleTranspose *node) + { + RETURN_FALSE_UNLESS(is_lwq(node)); + RETURN_FALSE_UNLESS(is_lwq(node->a())); + return true; + } + + bool visit(const luci::CircleFloor *node) + { + RETURN_FALSE_UNLESS(is_lwq(node)); + RETURN_FALSE_UNLESS(is_lwq(node->x())); + return true; + } + + bool visit(const luci::CircleGreater *node) + { + RETURN_FALSE_UNLESS(is_lwq(node->x())); + RETURN_FALSE_UNLESS(is_lwq(node->y())); + return true; + } + + bool visit(const luci::CircleGreaterEqual *node) + { + RETURN_FALSE_UNLESS(is_lwq(node->x())); + RETURN_FALSE_UNLESS(is_lwq(node->y())); + return true; + } + + bool visit(const luci::CircleDiv *node) + { + RETURN_FALSE_UNLESS(is_lwq(node)); + RETURN_FALSE_UNLESS(is_lwq(node->x())); + RETURN_FALSE_UNLESS(is_lwq(node->y())); + return true; + } + + bool visit(const luci::CircleFloorDiv *node) + { + RETURN_FALSE_UNLESS(is_lwq(node)); + RETURN_FALSE_UNLESS(is_lwq(node->x())); + RETURN_FALSE_UNLESS(is_lwq(node->y())); + return true; + } + + bool visit(const luci::CircleRsqrt *node) + { + RETURN_FALSE_UNLESS(is_lwq(node)); + RETURN_FALSE_UNLESS(is_lwq(node->x())); + return true; + } + + bool visit(const luci::CircleSqrt *node) + { + RETURN_FALSE_UNLESS(is_lwq(node)); + RETURN_FALSE_UNLESS(is_lwq(node->x())); + return true; + } + + bool visit(const luci::CircleElu *node) + { + RETURN_FALSE_UNLESS(is_lwq(node)); + RETURN_FALSE_UNLESS(is_lwq(node->features())); + return true; + } + + bool visit(const luci::CirclePow *node) + { + RETURN_FALSE_UNLESS(is_lwq(node)); + RETURN_FALSE_UNLESS(is_lwq(node->x())); + RETURN_FALSE_UNLESS(is_lwq(node->y())); + return true; + } + + bool visit(const luci::CircleResizeBilinear *node) + { + RETURN_FALSE_UNLESS(is_lwq(node)); + RETURN_FALSE_UNLESS(is_lwq(node->input())); + return true; + } + + // TODO: Implement more Ops + + bool visit(const luci::CircleNode *) { return true; } +}; + +} // namespace luci + +#undef RETURN_FALSE_UNLESS + +#endif // __LUCI_VERIFY_QUANTIZED_NODE_LAYERWISE_GRANULARITY_H__ diff --git a/compiler/luci/pass/src/VerifyQuantizedNodeS16Type.h b/compiler/luci/pass/src/VerifyQuantizedNodeS16Type.h new file mode 100644 index 000000000..e05d8325f --- /dev/null +++ b/compiler/luci/pass/src/VerifyQuantizedNodeS16Type.h @@ -0,0 +1,375 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_VERIFY_QUANTIZED_NODE_S16_TYPE_H__ +#define __LUCI_VERIFY_QUANTIZED_NODE_S16_TYPE_H__ + +#include <luci/IR/CircleNodes.h> +#include <luci/IR/CircleNodeVisitor.h> + +using Type = loco::DataType; + +// This macro is undef at the end of the file +#define RETURN_FALSE_UNLESS(ARG) \ + if (not(ARG)) \ + { \ + return false; \ + } + +namespace luci +{ + +/** + * @brief Verify the data type of INT16 quantized node + * @details + * + * Targets to verify + * - node's output (i.e., node itself) + * - node's inputs + */ +struct VerifyQuantizedNodeS16Type final : public luci::CircleNodeVisitor<bool> +{ +private: + bool has_type(const loco::Node *node, Type dtype) + { + auto circle_node = loco::must_cast<const luci::CircleNode *>(node); + return circle_node->dtype() == dtype; + } + +private: + bool visit(const luci::CircleConv2D *node) + { + RETURN_FALSE_UNLESS(has_type(node, Type::S16)) + RETURN_FALSE_UNLESS(has_type(node->input(), Type::S16)) + RETURN_FALSE_UNLESS(has_type(node->filter(), Type::S16)) + RETURN_FALSE_UNLESS(has_type(node->bias(), Type::S64)) + return true; + } + + bool visit(const luci::CircleConcatenation *node) + { + RETURN_FALSE_UNLESS(has_type(node, Type::S16)) + for (uint32_t i = 0; i < node->numValues(); i++) + { + RETURN_FALSE_UNLESS(has_type(node->values(i), Type::S16)) + } + return true; + } + + bool visit(const luci::CircleDepthToSpace *node) + { + RETURN_FALSE_UNLESS(has_type(node, Type::S16)) + RETURN_FALSE_UNLESS(has_type(node->input(), Type::S16)) + return true; + } + + bool visit(const luci::CircleDepthwiseConv2D *node) + { + RETURN_FALSE_UNLESS(has_type(node, Type::S16)) + RETURN_FALSE_UNLESS(has_type(node->input(), Type::S16)) + RETURN_FALSE_UNLESS(has_type(node->filter(), Type::S16)) + RETURN_FALSE_UNLESS(has_type(node->bias(), Type::S64)) + return true; + } + + bool visit(const luci::CircleInstanceNorm *node) + { + RETURN_FALSE_UNLESS(has_type(node, Type::S16)) + RETURN_FALSE_UNLESS(has_type(node->input(), Type::S16)) + RETURN_FALSE_UNLESS(has_type(node->gamma(), Type::S16)) + RETURN_FALSE_UNLESS(has_type(node->beta(), Type::S16)) + return true; + } + + bool visit(const luci::CirclePad *node) + { + RETURN_FALSE_UNLESS(has_type(node, Type::S16)) + RETURN_FALSE_UNLESS(has_type(node->input(), Type::S16)) + RETURN_FALSE_UNLESS(has_type(node->paddings(), Type::S32)) + return true; + } + + bool visit(const luci::CirclePRelu *node) + { + RETURN_FALSE_UNLESS(has_type(node, Type::S16)) + RETURN_FALSE_UNLESS(has_type(node->input(), Type::S16)) + RETURN_FALSE_UNLESS(has_type(node->alpha(), Type::S16)) + return true; + } + + bool visit(const luci::CircleTransposeConv *node) + { + RETURN_FALSE_UNLESS(has_type(node, Type::S16)) + RETURN_FALSE_UNLESS(has_type(node->outBackprop(), Type::S16)) + RETURN_FALSE_UNLESS(has_type(node->filter(), Type::S16)) + luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias()); + if (bias != nullptr) + RETURN_FALSE_UNLESS(has_type(bias, Type::S64)) + return true; + } + + bool visit(const luci::CircleFullyConnected *node) + { + RETURN_FALSE_UNLESS(has_type(node, Type::S16)) + RETURN_FALSE_UNLESS(has_type(node->input(), Type::S16)) + RETURN_FALSE_UNLESS(has_type(node->weights(), Type::S16)) + RETURN_FALSE_UNLESS(has_type(node->bias(), Type::S64)) + return true; + } + + bool visit(const luci::CircleAdd *node) + { + RETURN_FALSE_UNLESS(has_type(node, Type::S16)) + RETURN_FALSE_UNLESS(has_type(node->x(), Type::S16)) + RETURN_FALSE_UNLESS(has_type(node->y(), Type::S16)) + return true; + } + + bool visit(const luci::CircleAveragePool2D *node) + { + RETURN_FALSE_UNLESS(has_type(node, Type::S16)) + RETURN_FALSE_UNLESS(has_type(node->value(), Type::S16)) + return true; + } + + bool visit(const luci::CircleLogicalOr *node) + { + RETURN_FALSE_UNLESS(has_type(node, Type::BOOL)) + RETURN_FALSE_UNLESS(has_type(node->x(), Type::BOOL)) + RETURN_FALSE_UNLESS(has_type(node->y(), Type::BOOL)) + return true; + } + + bool visit(const luci::CircleMaxPool2D *node) + { + RETURN_FALSE_UNLESS(has_type(node, Type::S16)) + RETURN_FALSE_UNLESS(has_type(node->value(), Type::S16)) + return true; + } + + bool visit(const luci::CircleMean *node) + { + RETURN_FALSE_UNLESS(has_type(node, Type::S16)) + RETURN_FALSE_UNLESS(has_type(node->input(), Type::S16)) + RETURN_FALSE_UNLESS(has_type(node->reduction_indices(), Type::S32)) + return true; + } + + bool visit(const luci::CircleMul *node) + { + RETURN_FALSE_UNLESS(has_type(node, Type::S16)) + RETURN_FALSE_UNLESS(has_type(node->x(), Type::S16)) + RETURN_FALSE_UNLESS(has_type(node->y(), Type::S16)) + return true; + } + + bool visit(const luci::CircleNotEqual *node) + { + RETURN_FALSE_UNLESS(has_type(node, Type::BOOL)) + RETURN_FALSE_UNLESS(has_type(node->x(), Type::S16)) + RETURN_FALSE_UNLESS(has_type(node->y(), Type::S16)) + return true; + } + + bool visit(const luci::CircleRelu *node) + { + RETURN_FALSE_UNLESS(has_type(node, Type::S16)) + RETURN_FALSE_UNLESS(has_type(node->features(), Type::S16)) + return true; + } + + bool visit(const luci::CircleReshape *node) + { + RETURN_FALSE_UNLESS(has_type(node, Type::S16)) + RETURN_FALSE_UNLESS(has_type(node->tensor(), Type::S16)) + luci::CircleConst *shape = dynamic_cast<luci::CircleConst *>(node->shape()); + if (shape != nullptr) + RETURN_FALSE_UNLESS(has_type(shape, Type::S32)) + return true; + } + + bool visit(const luci::CircleLogistic *node) + { + RETURN_FALSE_UNLESS(has_type(node, Type::S16)) + RETURN_FALSE_UNLESS(has_type(node->x(), Type::S16)) + return true; + } + + bool visit(const luci::CircleSoftmax *node) + { + RETURN_FALSE_UNLESS(has_type(node, Type::S16)) + RETURN_FALSE_UNLESS(has_type(node->logits(), Type::S16)) + return true; + } + + bool visit(const luci::CircleSpaceToBatchND *node) + { + RETURN_FALSE_UNLESS(has_type(node, Type::S16)) + RETURN_FALSE_UNLESS(has_type(node->input(), Type::S16)) + return true; + } + + bool visit(const luci::CircleSpaceToDepth *node) + { + RETURN_FALSE_UNLESS(has_type(node, Type::S16)) + RETURN_FALSE_UNLESS(has_type(node->input(), Type::S16)) + return true; + } + + bool visit(const luci::CircleSlice *node) + { + RETURN_FALSE_UNLESS(has_type(node, Type::S16)) + RETURN_FALSE_UNLESS(has_type(node->input(), Type::S16)) + RETURN_FALSE_UNLESS(has_type(node->begin(), Type::S32) || has_type(node->begin(), Type::S64)) + RETURN_FALSE_UNLESS(has_type(node->size(), Type::S32) || has_type(node->size(), Type::S64)) + return true; + } + + bool visit(const luci::CircleSplit *node) + { + // node's output is the input of CircleSplitOut, thus not quantized + RETURN_FALSE_UNLESS(has_type(node->input(), Type::S16)) + return true; + } + + bool visit(const luci::CircleSplitOut *node) + { + RETURN_FALSE_UNLESS(has_type(node, Type::S16)) + return true; + } + + bool visit(const luci::CircleStridedSlice *node) + { + RETURN_FALSE_UNLESS(has_type(node, Type::S16)) + RETURN_FALSE_UNLESS(has_type(node->input(), Type::S16)) + return true; + } + + bool visit(const luci::CircleArgMax *node) + { + RETURN_FALSE_UNLESS(has_type(node, node->output_type())) + RETURN_FALSE_UNLESS(has_type(node->input(), Type::S16)) + RETURN_FALSE_UNLESS(has_type(node->dimension(), Type::S32) || + has_type(node->dimension(), Type::S64)) + return true; + } + + bool visit(const luci::CircleBatchToSpaceND *node) + { + RETURN_FALSE_UNLESS(has_type(node, Type::S16)) + RETURN_FALSE_UNLESS(has_type(node->input(), Type::S16)) + return true; + } + + bool visit(const luci::CircleTanh *node) + { + RETURN_FALSE_UNLESS(has_type(node, Type::S16)) + RETURN_FALSE_UNLESS(has_type(node->x(), Type::S16)) + return true; + } + + bool visit(const luci::CircleTranspose *node) + { + RETURN_FALSE_UNLESS(has_type(node, Type::S16)) + RETURN_FALSE_UNLESS(has_type(node->a(), Type::S16)) + RETURN_FALSE_UNLESS(has_type(node->perm(), Type::S32)) + return true; + } + + bool visit(const luci::CircleFloor *node) + { + RETURN_FALSE_UNLESS(has_type(node, Type::S16)) + RETURN_FALSE_UNLESS(has_type(node->x(), Type::S16)) + return true; + } + + bool visit(const luci::CircleGreater *node) + { + RETURN_FALSE_UNLESS(has_type(node, Type::BOOL)) + RETURN_FALSE_UNLESS(has_type(node->x(), Type::S16)) + RETURN_FALSE_UNLESS(has_type(node->y(), Type::S16)) + return true; + } + + bool visit(const luci::CircleGreaterEqual *node) + { + RETURN_FALSE_UNLESS(has_type(node, Type::BOOL)) + RETURN_FALSE_UNLESS(has_type(node->x(), Type::S16)) + RETURN_FALSE_UNLESS(has_type(node->y(), Type::S16)) + return true; + } + + bool visit(const luci::CircleDiv *node) + { + RETURN_FALSE_UNLESS(has_type(node, Type::S16)) + RETURN_FALSE_UNLESS(has_type(node->x(), Type::S16)) + RETURN_FALSE_UNLESS(has_type(node->y(), Type::S16)) + return true; + } + + bool visit(const luci::CircleFloorDiv *node) + { + RETURN_FALSE_UNLESS(has_type(node, Type::S16)) + RETURN_FALSE_UNLESS(has_type(node->x(), Type::S16)) + RETURN_FALSE_UNLESS(has_type(node->y(), Type::S16)) + return true; + } + + bool visit(const luci::CircleRsqrt *node) + { + RETURN_FALSE_UNLESS(has_type(node, Type::S16)) + RETURN_FALSE_UNLESS(has_type(node->x(), Type::S16)) + return true; + } + + bool visit(const luci::CircleSqrt *node) + { + RETURN_FALSE_UNLESS(has_type(node, Type::S16)) + RETURN_FALSE_UNLESS(has_type(node->x(), Type::S16)) + return true; + } + + bool visit(const luci::CircleElu *node) + { + RETURN_FALSE_UNLESS(has_type(node, Type::S16)) + RETURN_FALSE_UNLESS(has_type(node->features(), Type::S16)) + return true; + } + + bool visit(const luci::CirclePow *node) + { + RETURN_FALSE_UNLESS(has_type(node, Type::S16)) + RETURN_FALSE_UNLESS(has_type(node->x(), Type::S16)) + RETURN_FALSE_UNLESS(has_type(node->y(), Type::S16)) + return true; + } + + bool visit(const luci::CircleResizeBilinear *node) + { + RETURN_FALSE_UNLESS(has_type(node, Type::S16)) + RETURN_FALSE_UNLESS(has_type(node->input(), Type::S16)) + return true; + } + + // TODO: Implement more Ops + + bool visit(const luci::CircleNode *) { return true; } +}; + +} // namespace luci + +#undef RETURN_FALSE_UNLESS + +#endif // __LUCI_VERIFY_QUNTIZED_NODE_S16_TYPE_H__ diff --git a/compiler/luci/pass/src/VerifyQuantizedNodeU8Type.h b/compiler/luci/pass/src/VerifyQuantizedNodeU8Type.h new file mode 100644 index 000000000..72ce5b8f8 --- /dev/null +++ b/compiler/luci/pass/src/VerifyQuantizedNodeU8Type.h @@ -0,0 +1,375 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_VERIFY_QUANTIZED_NODE_U8_TYPE_H__ +#define __LUCI_VERIFY_QUANTIZED_NODE_U8_TYPE_H__ + +#include <luci/IR/CircleNodes.h> +#include <luci/IR/CircleNodeVisitor.h> + +using Type = loco::DataType; + +// This macro is undef at the end of the file +#define RETURN_FALSE_UNLESS(ARG) \ + if (not(ARG)) \ + { \ + return false; \ + } + +namespace luci +{ + +/** + * @brief Verify the data type of UINT8 quantized node + * @details + * + * Targets to verify + * - node's output (i.e., node itself) + * - node's inputs + */ +struct VerifyQuantizedNodeU8Type final : public luci::CircleNodeVisitor<bool> +{ +private: + bool has_type(const loco::Node *node, Type dtype) + { + auto circle_node = loco::must_cast<const luci::CircleNode *>(node); + return circle_node->dtype() == dtype; + } + +private: + bool visit(const luci::CircleConv2D *node) + { + RETURN_FALSE_UNLESS(has_type(node, Type::U8)) + RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8)) + RETURN_FALSE_UNLESS(has_type(node->filter(), Type::U8)) + RETURN_FALSE_UNLESS(has_type(node->bias(), Type::S32)) + return true; + } + + bool visit(const luci::CircleConcatenation *node) + { + RETURN_FALSE_UNLESS(has_type(node, Type::U8)) + for (uint32_t i = 0; i < node->numValues(); i++) + { + RETURN_FALSE_UNLESS(has_type(node->values(i), Type::U8)) + } + return true; + } + + bool visit(const luci::CircleDepthToSpace *node) + { + RETURN_FALSE_UNLESS(has_type(node, Type::U8)) + RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8)) + return true; + } + + bool visit(const luci::CircleDepthwiseConv2D *node) + { + RETURN_FALSE_UNLESS(has_type(node, Type::U8)) + RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8)) + RETURN_FALSE_UNLESS(has_type(node->filter(), Type::U8)) + RETURN_FALSE_UNLESS(has_type(node->bias(), Type::S32)) + return true; + } + + bool visit(const luci::CircleInstanceNorm *node) + { + RETURN_FALSE_UNLESS(has_type(node, Type::U8)) + RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8)) + RETURN_FALSE_UNLESS(has_type(node->gamma(), Type::U8)) + RETURN_FALSE_UNLESS(has_type(node->beta(), Type::U8)) + return true; + } + + bool visit(const luci::CirclePad *node) + { + RETURN_FALSE_UNLESS(has_type(node, Type::U8)) + RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8)) + RETURN_FALSE_UNLESS(has_type(node->paddings(), Type::S32)) + return true; + } + + bool visit(const luci::CirclePRelu *node) + { + RETURN_FALSE_UNLESS(has_type(node, Type::U8)) + RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8)) + RETURN_FALSE_UNLESS(has_type(node->alpha(), Type::U8)) + return true; + } + + bool visit(const luci::CircleTransposeConv *node) + { + RETURN_FALSE_UNLESS(has_type(node, Type::U8)) + RETURN_FALSE_UNLESS(has_type(node->outBackprop(), Type::U8)) + RETURN_FALSE_UNLESS(has_type(node->filter(), Type::U8)) + luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias()); + if (bias != nullptr) + RETURN_FALSE_UNLESS(has_type(bias, Type::S32)) + return true; + } + + bool visit(const luci::CircleFullyConnected *node) + { + RETURN_FALSE_UNLESS(has_type(node, Type::U8)) + RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8)) + RETURN_FALSE_UNLESS(has_type(node->weights(), Type::U8)) + RETURN_FALSE_UNLESS(has_type(node->bias(), Type::S32)) + return true; + } + + bool visit(const luci::CircleAdd *node) + { + RETURN_FALSE_UNLESS(has_type(node, Type::U8)) + RETURN_FALSE_UNLESS(has_type(node->x(), Type::U8)) + RETURN_FALSE_UNLESS(has_type(node->y(), Type::U8)) + return true; + } + + bool visit(const luci::CircleAveragePool2D *node) + { + RETURN_FALSE_UNLESS(has_type(node, Type::U8)) + RETURN_FALSE_UNLESS(has_type(node->value(), Type::U8)) + return true; + } + + bool visit(const luci::CircleBatchToSpaceND *node) + { + RETURN_FALSE_UNLESS(has_type(node, Type::U8)) + RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8)) + return true; + } + + bool visit(const luci::CircleLogicalOr *node) + { + RETURN_FALSE_UNLESS(has_type(node, Type::BOOL)) + RETURN_FALSE_UNLESS(has_type(node->x(), Type::BOOL)) + RETURN_FALSE_UNLESS(has_type(node->y(), Type::BOOL)) + return true; + } + + bool visit(const luci::CircleMaxPool2D *node) + { + RETURN_FALSE_UNLESS(has_type(node, Type::U8)) + RETURN_FALSE_UNLESS(has_type(node->value(), Type::U8)) + return true; + } + + bool visit(const luci::CircleMean *node) + { + RETURN_FALSE_UNLESS(has_type(node, Type::U8)) + RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8)) + RETURN_FALSE_UNLESS(has_type(node->reduction_indices(), Type::S32)) + return true; + } + + bool visit(const luci::CircleMul *node) + { + RETURN_FALSE_UNLESS(has_type(node, Type::U8)) + RETURN_FALSE_UNLESS(has_type(node->x(), Type::U8)) + RETURN_FALSE_UNLESS(has_type(node->y(), Type::U8)) + return true; + } + + bool visit(const luci::CircleNotEqual *node) + { + RETURN_FALSE_UNLESS(has_type(node, Type::BOOL)) + RETURN_FALSE_UNLESS(has_type(node->x(), Type::U8)) + RETURN_FALSE_UNLESS(has_type(node->y(), Type::U8)) + return true; + } + + bool visit(const luci::CircleRelu *node) + { + RETURN_FALSE_UNLESS(has_type(node, Type::U8)) + RETURN_FALSE_UNLESS(has_type(node->features(), Type::U8)) + return true; + } + + bool visit(const luci::CircleReshape *node) + { + RETURN_FALSE_UNLESS(has_type(node, Type::U8)) + RETURN_FALSE_UNLESS(has_type(node->tensor(), Type::U8)) + luci::CircleConst *shape = dynamic_cast<luci::CircleConst *>(node->shape()); + if (shape != nullptr) + RETURN_FALSE_UNLESS(has_type(shape, Type::S32)) + return true; + } + + bool visit(const luci::CircleLogistic *node) + { + RETURN_FALSE_UNLESS(has_type(node, Type::U8)) + RETURN_FALSE_UNLESS(has_type(node->x(), Type::U8)) + return true; + } + + bool visit(const luci::CircleSoftmax *node) + { + RETURN_FALSE_UNLESS(has_type(node, Type::U8)) + RETURN_FALSE_UNLESS(has_type(node->logits(), Type::U8)) + return true; + } + + bool visit(const luci::CircleSpaceToBatchND *node) + { + RETURN_FALSE_UNLESS(has_type(node, Type::U8)) + RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8)) + return true; + } + + bool visit(const luci::CircleSpaceToDepth *node) + { + RETURN_FALSE_UNLESS(has_type(node, Type::U8)) + RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8)) + return true; + } + + bool visit(const luci::CircleSlice *node) + { + RETURN_FALSE_UNLESS(has_type(node, Type::U8)) + RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8)) + RETURN_FALSE_UNLESS(has_type(node->begin(), Type::S32) || has_type(node->begin(), Type::S64)) + RETURN_FALSE_UNLESS(has_type(node->size(), Type::S32) || has_type(node->size(), Type::S64)) + return true; + } + + bool visit(const luci::CircleSplit *node) + { + // node's output is the input of CircleSplitOut, thus not quantized + RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8)) + return true; + } + + bool visit(const luci::CircleSplitOut *node) + { + RETURN_FALSE_UNLESS(has_type(node, Type::U8)) + return true; + } + + bool visit(const luci::CircleStridedSlice *node) + { + RETURN_FALSE_UNLESS(has_type(node, Type::U8)) + RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8)) + return true; + } + + bool visit(const luci::CircleArgMax *node) + { + RETURN_FALSE_UNLESS(has_type(node, node->output_type())) + RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8)) + RETURN_FALSE_UNLESS(has_type(node->dimension(), Type::S32) || + has_type(node->dimension(), Type::S64)) + return true; + } + + bool visit(const luci::CircleTanh *node) + { + RETURN_FALSE_UNLESS(has_type(node, Type::U8)) + RETURN_FALSE_UNLESS(has_type(node->x(), Type::U8)) + return true; + } + + bool visit(const luci::CircleTranspose *node) + { + RETURN_FALSE_UNLESS(has_type(node, Type::U8)) + RETURN_FALSE_UNLESS(has_type(node->a(), Type::U8)) + RETURN_FALSE_UNLESS(has_type(node->perm(), Type::S32)) + return true; + } + + bool visit(const luci::CircleFloor *node) + { + RETURN_FALSE_UNLESS(has_type(node, Type::U8)) + RETURN_FALSE_UNLESS(has_type(node->x(), Type::U8)) + return true; + } + + bool visit(const luci::CircleGreater *node) + { + RETURN_FALSE_UNLESS(has_type(node, Type::BOOL)) + RETURN_FALSE_UNLESS(has_type(node->x(), Type::U8)) + RETURN_FALSE_UNLESS(has_type(node->y(), Type::U8)) + return true; + } + + bool visit(const luci::CircleGreaterEqual *node) + { + RETURN_FALSE_UNLESS(has_type(node, Type::BOOL)) + RETURN_FALSE_UNLESS(has_type(node->x(), Type::U8)) + RETURN_FALSE_UNLESS(has_type(node->y(), Type::U8)) + return true; + } + + bool visit(const luci::CircleDiv *node) + { + RETURN_FALSE_UNLESS(has_type(node, Type::U8)) + RETURN_FALSE_UNLESS(has_type(node->x(), Type::U8)) + RETURN_FALSE_UNLESS(has_type(node->y(), Type::U8)) + return true; + } + + bool visit(const luci::CircleFloorDiv *node) + { + RETURN_FALSE_UNLESS(has_type(node, Type::U8)) + RETURN_FALSE_UNLESS(has_type(node->x(), Type::U8)) + RETURN_FALSE_UNLESS(has_type(node->y(), Type::U8)) + return true; + } + + bool visit(const luci::CircleRsqrt *node) + { + RETURN_FALSE_UNLESS(has_type(node, Type::U8)) + RETURN_FALSE_UNLESS(has_type(node->x(), Type::U8)) + return true; + } + + bool visit(const luci::CircleSqrt *node) + { + RETURN_FALSE_UNLESS(has_type(node, Type::U8)) + RETURN_FALSE_UNLESS(has_type(node->x(), Type::U8)) + return true; + } + + bool visit(const luci::CircleElu *node) + { + RETURN_FALSE_UNLESS(has_type(node, Type::U8)) + RETURN_FALSE_UNLESS(has_type(node->features(), Type::U8)) + return true; + } + + bool visit(const luci::CirclePow *node) + { + RETURN_FALSE_UNLESS(has_type(node, Type::U8)) + RETURN_FALSE_UNLESS(has_type(node->x(), Type::U8)) + RETURN_FALSE_UNLESS(has_type(node->y(), Type::U8)) + return true; + } + + bool visit(const luci::CircleResizeBilinear *node) + { + RETURN_FALSE_UNLESS(has_type(node, Type::U8)) + RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8)) + return true; + } + + // TODO: Implement more Ops + + bool visit(const luci::CircleNode *) { return true; } +}; + +} // namespace luci + +#undef RETURN_FALSE_UNLESS + +#endif // __LUCI_VERIFY_QUNTIZED_NODE_U8_TYPE_H__ diff --git a/compiler/luci/pass/src/helpers/InferenceCandidates.cpp b/compiler/luci/pass/src/helpers/InferenceCandidates.cpp new file mode 100644 index 000000000..2c8565932 --- /dev/null +++ b/compiler/luci/pass/src/helpers/InferenceCandidates.cpp @@ -0,0 +1,45 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "InferenceCandidates.h" + +#include <luci/IR/DeadNodeQueryService.h> + +namespace luci +{ + +std::vector<loco::Node *> inference_candidates(loco::Graph *g) +{ + auto candidates = loco::postorder_traversal(loco::output_nodes(g)); + + for (auto node : loco::all_nodes(g)) + { + // already included as candidate + if (std::find(candidates.begin(), candidates.end(), node) != candidates.end()) + continue; + + // As the node is not used for both graph output and multiple output operation, + // it cannot be candidate. + if (node->dialect()->service<DeadNodeQueryServiceImpl>()->isDeadNode(node)) + continue; + + candidates.emplace_back(node); + } + + return candidates; +} + +} // namespace luci diff --git a/compiler/luci/pass/src/helpers/InferenceCandidates.h b/compiler/luci/pass/src/helpers/InferenceCandidates.h new file mode 100644 index 000000000..f27e4fe60 --- /dev/null +++ b/compiler/luci/pass/src/helpers/InferenceCandidates.h @@ -0,0 +1,34 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_INFERENCE_CANDIDATES_H__ +#define __LUCI_INFERENCE_CANDIDATES_H__ + +#include <loco.h> + +#include <vector> + +namespace luci +{ + +/** + * @brief Enumerate all the nodes whose shape/dtype should be inferenced to export graph. + */ +std::vector<loco::Node *> inference_candidates(loco::Graph *g); + +} // namespace luci + +#endif // __LUCI_INFERENCE_CANDIDATES_H__ diff --git a/compiler/luci/pass/src/helpers/InferenceCandidates.test.cpp b/compiler/luci/pass/src/helpers/InferenceCandidates.test.cpp new file mode 100644 index 000000000..e34421f5e --- /dev/null +++ b/compiler/luci/pass/src/helpers/InferenceCandidates.test.cpp @@ -0,0 +1,122 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "InferenceCandidates.h" +#include "luci/IR/CircleNode.h" + +#include <algorithm> + +#include <gtest/gtest.h> + +namespace +{ + +bool contains(const std::vector<loco::Node *> &vec, loco::Node *val) +{ + return std::any_of(vec.begin(), vec.end(), [val](loco::Node *node) { return node == val; }); +} + +} // namespace + +TEST(LuciPassHelpersInferenceCandidates, inference_candidates) +{ + auto g = loco::make_graph(); + + // Create nodes + auto input = g->nodes()->create<luci::CircleInput>(); + auto split = g->nodes()->create<luci::CircleSplit>(); + auto split_out1 = g->nodes()->create<luci::CircleSplitOut>(); + auto split_out2 = g->nodes()->create<luci::CircleSplitOut>(); + auto split_dim = g->nodes()->create<luci::CircleConst>(); + auto output = g->nodes()->create<luci::CircleOutput>(); + + // Build up initial graph + auto graph_input1 = g->inputs()->create(); + input->index(graph_input1->index()); + + split->split_dim(split_dim); + split->input(input); + split->num_split(2); + + split_out1->input(split); + split_out1->index(0); + + split_out2->input(split); + split_out2->index(1); + + auto graph_output = g->outputs()->create(); + output->from(split_out1); + output->index(graph_output->index()); + + auto s = luci::inference_candidates(g.get()); + + ASSERT_EQ(6, s.size()); + ASSERT_TRUE(contains(s, input)); + ASSERT_TRUE(contains(s, split)); + ASSERT_TRUE(contains(s, split_out1)); + ASSERT_TRUE(contains(s, split_out2)); + ASSERT_TRUE(contains(s, split_dim)); + ASSERT_TRUE(contains(s, output)); +} + +TEST(LuciPassHelpersInferenceCandidates, inference_candidates_NEG) +{ + auto g = loco::make_graph(); + + // Create nodes + auto input = g->nodes()->create<luci::CircleInput>(); + auto split = g->nodes()->create<luci::CircleSplit>(); + auto split_out1 = g->nodes()->create<luci::CircleSplitOut>(); + auto split_out2 = g->nodes()->create<luci::CircleSplitOut>(); + auto split_dim = g->nodes()->create<luci::CircleConst>(); + auto relu1 = g->nodes()->create<luci::CircleRelu>(); + auto relu2 = g->nodes()->create<luci::CircleRelu>(); + auto output = g->nodes()->create<luci::CircleOutput>(); + + // Build up initial graph + auto graph_input1 = g->inputs()->create(); + input->index(graph_input1->index()); + + split->split_dim(split_dim); + split->input(input); + split->num_split(2); + + split_out1->input(split); + split_out1->index(0); + + split_out2->input(split); + split_out2->index(1); + + relu1->features(split_out2); + + relu2->features(input); + + auto graph_output = g->outputs()->create(); + output->from(split_out1); + output->index(graph_output->index()); + + auto s = luci::inference_candidates(g.get()); + + ASSERT_EQ(6, s.size()); + ASSERT_TRUE(contains(s, input)); + ASSERT_TRUE(contains(s, split)); + ASSERT_TRUE(contains(s, split_out1)); + ASSERT_TRUE(contains(s, split_out2)); + ASSERT_TRUE(contains(s, split_dim)); + ASSERT_TRUE(contains(s, output)); + ASSERT_FALSE(contains(s, relu1)); + ASSERT_FALSE(contains(s, relu2)); +} diff --git a/compiler/luci/pass/src/helpers/NodeFiller.cpp b/compiler/luci/pass/src/helpers/NodeFiller.cpp new file mode 100644 index 000000000..b1416655d --- /dev/null +++ b/compiler/luci/pass/src/helpers/NodeFiller.cpp @@ -0,0 +1,20 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "NodeFiller.h" + +// NOTE Do NOT delete this file; this file enforces compiler to check whether 'NodeFiller.h' is +// complete. diff --git a/compiler/luci/pass/src/helpers/NodeFiller.h b/compiler/luci/pass/src/helpers/NodeFiller.h new file mode 100644 index 000000000..b80f085b0 --- /dev/null +++ b/compiler/luci/pass/src/helpers/NodeFiller.h @@ -0,0 +1,104 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +namespace luci +{ + +/** + * INTRODUCTION + * Binary operation f(x,y) is 'commutative' when + * f(x,y) == f(y,x) holds for all x, y. + * For examples, ADD, MUL and SQUARED_DIFFERENCE are commutative. + * These helpers make it easy to find commutative arguments of commutative node. + * + * HOW TO USE + * COMM_NODE *node; + * ARG_TYPE_1 *arg1; + * ARG_TYPE_2 *arg2; + * + * bool ok = fill(&arg1, &arg2).with_commutative_args_of(node); + * + * Result + * If 'node's commutative argument types are actually {ARG_TYPE_1, ARG_TYPE_2} + * (as a set), 'arg1' and 'arg2' set as actual 'node's arguments with matching + * type, and return value 'ok' is true. + * Otherwise, 'arg1' and 'arg2' not changed, 'ok' is false. + */ + +template <class ARG_TYPE_1, class ARG_TYPE_2> class NodeFiller final +{ +public: + NodeFiller(ARG_TYPE_1 **arg_1, ARG_TYPE_2 **arg_2) : _arg_1(arg_1), _arg_2(arg_2) + { + // DO NOTHING + } + + /** + * @return true When 'node's argument types are 'ARG_TYPE_1' and 'ARG_TYPE_2' + * In such case, it assign '_arg_1' and '_arg_2' to actual arguments + * + * @return false When 'node's argument types are NOT matched with 'ARG_TYPE_*' + * In such case, it does not amend '_arg_1' and '_arg_2' + * + * @require COMM_NODE has member x() and y() + */ + template <class COMM_NODE> bool with_commutative_args_of(const COMM_NODE *node); + +private: + ARG_TYPE_1 **_arg_1; + ARG_TYPE_2 **_arg_2; +}; + +template <class ARG_TYPE_1, class ARG_TYPE_2> +inline NodeFiller<ARG_TYPE_1, ARG_TYPE_2> fill(ARG_TYPE_1 **arg_1, ARG_TYPE_2 **arg_2) +{ + return NodeFiller<ARG_TYPE_1, ARG_TYPE_2>{arg_1, arg_2}; +} + +template <class ARG_TYPE_1, class ARG_TYPE_2> +template <class COMM_NODE> +bool NodeFiller<ARG_TYPE_1, ARG_TYPE_2>::with_commutative_args_of(const COMM_NODE *node) +{ + // Case 1) X == ARG_TYPE_1 / Y == ARG_TYPE_2 + { + auto x = dynamic_cast<ARG_TYPE_1 *>(node->x()); + auto y = dynamic_cast<ARG_TYPE_2 *>(node->y()); + + if (x && y) + { + *_arg_1 = x; + *_arg_2 = y; + return true; + } + } + + // Case 2) X == ARG_TYPE_2 / Y == ARG_TYPE_1 + { + auto x = dynamic_cast<ARG_TYPE_2 *>(node->x()); + auto y = dynamic_cast<ARG_TYPE_1 *>(node->y()); + + if (x && y) + { + *_arg_1 = y; + *_arg_2 = x; + return true; + } + } + + return false; +} + +} // namespace luci diff --git a/compiler/luci/pass/src/helpers/NodeFiller.test.cpp b/compiler/luci/pass/src/helpers/NodeFiller.test.cpp new file mode 100644 index 000000000..9bbc7f264 --- /dev/null +++ b/compiler/luci/pass/src/helpers/NodeFiller.test.cpp @@ -0,0 +1,59 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <luci/IR/CircleNodes.h> + +#include <gtest/gtest.h> + +#include "NodeFiller.h" + +TEST(NodeFillerTest, simple_test) +{ + luci::CircleConst maxi_const; + luci::CircleMinimum mini; + luci::CircleMaximum maxi; + maxi.x(&maxi_const); + maxi.y(&mini); + + luci::CircleConst *x = nullptr; + luci::CircleMinimum *y = nullptr; + + EXPECT_TRUE(luci::fill(&x, &y).with_commutative_args_of(&maxi)); + EXPECT_TRUE(x == &maxi_const); + EXPECT_TRUE(y == &mini); + + x = nullptr; + y = nullptr; + + EXPECT_TRUE(luci::fill(&y, &x).with_commutative_args_of(&maxi)); + EXPECT_TRUE(x == &maxi_const); + EXPECT_TRUE(y == &mini); +} + +TEST(NodeFillerTest, wrong_condition_NEG) +{ + luci::CircleConst add_const; + luci::CircleMinimum mini; + luci::CircleAdd add; + add.x(&add_const); + add.y(&mini); + + luci::CircleMul *x = nullptr; + luci::CircleMinimum *y = nullptr; + + EXPECT_FALSE(luci::fill(&x, &y).with_commutative_args_of(&add)); + EXPECT_FALSE(luci::fill(&y, &x).with_commutative_args_of(&add)); +} diff --git a/compiler/luci/pass/src/helpers/Strings.cpp b/compiler/luci/pass/src/helpers/Strings.cpp new file mode 100644 index 000000000..d020f6ddc --- /dev/null +++ b/compiler/luci/pass/src/helpers/Strings.cpp @@ -0,0 +1,91 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "Strings.h" + +#include <algorithm> + +namespace luci +{ + +bool in_array(const std::string &str, const std::vector<std::string> &array) +{ + return std::find(array.begin(), array.end(), str) != array.end(); +} + +std::string to_string(const std::vector<std::string> &strings) +{ + assert(!strings.empty()); + + std::string res; + for (unsigned int i = 0; i < strings.size() - 1; i++) + res += strings[i] + ", "; + + res += strings[strings.size() - 1]; + return res; +} + +std::string to_lower_case(std::string s) +{ + std::transform(s.begin(), s.end(), s.begin(), [](unsigned char c) { return std::tolower(c); }); + return s; +} + +loco::DataType str_to_dtype(const std::string &str) +{ + if (to_lower_case(str).compare("uint8") == 0) + return loco::DataType::U8; + if (to_lower_case(str).compare("uint16") == 0) + return loco::DataType::U16; + if (to_lower_case(str).compare("uint32") == 0) + return loco::DataType::U32; + if (to_lower_case(str).compare("uint64") == 0) + return loco::DataType::U64; + + if (to_lower_case(str).compare("int8") == 0) + return loco::DataType::S8; + if (to_lower_case(str).compare("int16") == 0) + return loco::DataType::S16; + if (to_lower_case(str).compare("int32") == 0) + return loco::DataType::S32; + if (to_lower_case(str).compare("int64") == 0) + return loco::DataType::S64; + + if (to_lower_case(str).compare("float16") == 0) + return loco::DataType::FLOAT16; + if (to_lower_case(str).compare("float32") == 0) + return loco::DataType::FLOAT32; + if (to_lower_case(str).compare("float64") == 0) + return loco::DataType::FLOAT64; + + if (to_lower_case(str).compare("bool") == 0) + return loco::DataType::BOOL; + + return loco::DataType::Unknown; +} + +QuantizationGranularity str_to_granularity(const std::string &str) +{ + if (to_lower_case(str).compare("layer") == 0) + return QuantizationGranularity::LayerWise; + + if (to_lower_case(str).compare("channel") == 0) + return QuantizationGranularity::ChannelWise; + + throw std::runtime_error("Quantization granularity must be either 'layer' or 'channel'"); +} + +} // namespace luci diff --git a/compiler/luci/pass/src/helpers/Strings.h b/compiler/luci/pass/src/helpers/Strings.h new file mode 100644 index 000000000..793d137fb --- /dev/null +++ b/compiler/luci/pass/src/helpers/Strings.h @@ -0,0 +1,57 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_PASS_HELPERS_STRINGS_H__ +#define __LUCI_PASS_HELPERS_STRINGS_H__ + +#include "luci/Pass/QuantizationParameters.h" + +#include <loco.h> + +#include <vector> +#include <sstream> +#include <string> + +namespace luci +{ + +bool in_array(const std::string &, const std::vector<std::string> &); + +std::string to_string(const std::vector<std::string> &); + +std::string to_lower_case(std::string); + +loco::DataType str_to_dtype(const std::string &); + +QuantizationGranularity str_to_granularity(const std::string &); + +template <typename T> std::vector<T> csv_to_vector(const std::string &str) +{ + std::vector<T> ret; + std::istringstream is(str); + for (T i; is >> i;) + { + assert(i != ','); + ret.push_back(i); + if (is.peek() == ',') + is.ignore(); + } + return ret; +} + +} // namespace luci + +#endif // __LUCI_PASS_HELPERS_STRINGS_H__ diff --git a/compiler/luci/pass/src/helpers/Strings.test.cpp b/compiler/luci/pass/src/helpers/Strings.test.cpp new file mode 100644 index 000000000..f6bb48951 --- /dev/null +++ b/compiler/luci/pass/src/helpers/Strings.test.cpp @@ -0,0 +1,58 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "Strings.h" + +#include "luci/Pass/QuantizationParameters.h" + +#include <gtest/gtest.h> + +TEST(StringsTest, str_to_dtype) +{ + ASSERT_EQ(loco::DataType::U8, luci::str_to_dtype("uint8")); + ASSERT_EQ(loco::DataType::U16, luci::str_to_dtype("uint16")); + ASSERT_EQ(loco::DataType::U32, luci::str_to_dtype("uint32")); + ASSERT_EQ(loco::DataType::U64, luci::str_to_dtype("uint64")); + + ASSERT_EQ(loco::DataType::S8, luci::str_to_dtype("int8")); + ASSERT_EQ(loco::DataType::S16, luci::str_to_dtype("int16")); + ASSERT_EQ(loco::DataType::S32, luci::str_to_dtype("int32")); + ASSERT_EQ(loco::DataType::S64, luci::str_to_dtype("int64")); + + ASSERT_EQ(loco::DataType::FLOAT16, luci::str_to_dtype("float16")); + ASSERT_EQ(loco::DataType::FLOAT32, luci::str_to_dtype("float32")); + ASSERT_EQ(loco::DataType::FLOAT64, luci::str_to_dtype("float64")); + + ASSERT_EQ(loco::DataType::BOOL, luci::str_to_dtype("bool")); + + ASSERT_EQ(loco::DataType::Unknown, luci::str_to_dtype("foo")); +} + +TEST(StringsTest, str_to_granularity) +{ + ASSERT_EQ(luci::QuantizationGranularity::LayerWise, luci::str_to_granularity("layer")); + ASSERT_EQ(luci::QuantizationGranularity::ChannelWise, luci::str_to_granularity("channel")); + + EXPECT_THROW(luci::str_to_granularity("foo"), std::runtime_error); +} + +TEST(StringsTest, csv_to_vector_int32) +{ + auto ret = luci::csv_to_vector<int32_t>("1,2,3"); + ASSERT_EQ(3, ret.size()); + ASSERT_EQ(1, ret.at(0)); + ASSERT_EQ(3, ret.at(2)); +} diff --git a/compiler/luci/pass/src/helpers/TypeMapper.cpp b/compiler/luci/pass/src/helpers/TypeMapper.cpp new file mode 100644 index 000000000..ffa0159dd --- /dev/null +++ b/compiler/luci/pass/src/helpers/TypeMapper.cpp @@ -0,0 +1,20 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "TypeMapper.h" + +// NOTE Do NOT delete this file; this file enforces compiler to check whether 'TypeMapper.h' is +// complete. diff --git a/compiler/luci/pass/src/helpers/TypeMapper.h b/compiler/luci/pass/src/helpers/TypeMapper.h new file mode 100644 index 000000000..90760e95b --- /dev/null +++ b/compiler/luci/pass/src/helpers/TypeMapper.h @@ -0,0 +1,77 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <loco/IR/DataType.h> + +#include <cstdint> + +namespace luci +{ + +/** + * @brief TypeMapper maps between c++ primitive data type and loco::DataType. + */ +template <typename T> struct TypeMapper +{ + static constexpr loco::DataType get() { return loco::DataType::Unknown; } +}; + +template <> struct TypeMapper<float> +{ + static constexpr loco::DataType get() { return loco::DataType::FLOAT32; } +}; + +template <> struct TypeMapper<uint8_t> +{ + static constexpr loco::DataType get() { return loco::DataType::U8; } +}; + +template <> struct TypeMapper<uint16_t> +{ + static constexpr loco::DataType get() { return loco::DataType::U16; } +}; + +template <> struct TypeMapper<uint32_t> +{ + static constexpr loco::DataType get() { return loco::DataType::U32; } +}; + +template <> struct TypeMapper<uint64_t> +{ + static constexpr loco::DataType get() { return loco::DataType::U64; } +}; + +template <> struct TypeMapper<int8_t> +{ + static constexpr loco::DataType get() { return loco::DataType::S8; } +}; + +template <> struct TypeMapper<int16_t> +{ + static constexpr loco::DataType get() { return loco::DataType::S16; } +}; + +template <> struct TypeMapper<int32_t> +{ + static constexpr loco::DataType get() { return loco::DataType::S32; } +}; + +template <> struct TypeMapper<int64_t> +{ + static constexpr loco::DataType get() { return loco::DataType::S64; } +}; + +} // namespace luci diff --git a/compiler/luci/pass/src/helpers/TypeMapper.test.cpp b/compiler/luci/pass/src/helpers/TypeMapper.test.cpp new file mode 100644 index 000000000..a7ac08a63 --- /dev/null +++ b/compiler/luci/pass/src/helpers/TypeMapper.test.cpp @@ -0,0 +1,93 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <luci/IR/CircleNodes.h> + +#include <gtest/gtest.h> + +#include "TypeMapper.h" + +#include <vector> + +namespace +{ + +template <typename T> bool fill_const_node(luci::CircleConst *node, std::vector<T> &data) +{ + if (node->dtype() != luci::TypeMapper<T>::get()) + return false; + + node->size<luci::TypeMapper<T>::get()>(data.size()); + for (uint32_t i = 0; i < data.size(); i++) + { + node->at<luci::TypeMapper<T>::get()>(i) = data.at(i); + } + + return true; +} + +class STRANGER +{ +}; + +} // namespace + +TEST(TypeMapperTest, simple_test) +{ + EXPECT_EQ(loco::DataType::FLOAT32, luci::TypeMapper<float>::get()); + EXPECT_EQ(loco::DataType::U8, luci::TypeMapper<uint8_t>::get()); + EXPECT_EQ(loco::DataType::U16, luci::TypeMapper<uint16_t>::get()); + EXPECT_EQ(loco::DataType::U32, luci::TypeMapper<uint32_t>::get()); + EXPECT_EQ(loco::DataType::U64, luci::TypeMapper<uint64_t>::get()); + EXPECT_EQ(loco::DataType::S8, luci::TypeMapper<int8_t>::get()); + EXPECT_EQ(loco::DataType::S16, luci::TypeMapper<int16_t>::get()); + EXPECT_EQ(loco::DataType::S32, luci::TypeMapper<int32_t>::get()); + EXPECT_EQ(loco::DataType::S64, luci::TypeMapper<int64_t>::get()); +} + +TEST(TypeMapperTest, with_template_test) +{ + std::vector<int32_t> int32_vec{0, 1, 2, 3, 4, 5, 6, 7}; + luci::CircleConst const_node; + const_node.dtype(loco::DataType::S32); + EXPECT_TRUE(fill_const_node(&const_node, int32_vec)); + EXPECT_EQ(8, const_node.size<loco::DataType::S32>()); + EXPECT_EQ(0, const_node.at<loco::DataType::S32>(0)); + EXPECT_EQ(1, const_node.at<loco::DataType::S32>(1)); + EXPECT_EQ(2, const_node.at<loco::DataType::S32>(2)); + EXPECT_EQ(3, const_node.at<loco::DataType::S32>(3)); + EXPECT_EQ(4, const_node.at<loco::DataType::S32>(4)); + EXPECT_EQ(5, const_node.at<loco::DataType::S32>(5)); + EXPECT_EQ(6, const_node.at<loco::DataType::S32>(6)); + EXPECT_EQ(7, const_node.at<loco::DataType::S32>(7)); + + std::vector<float> f32_vec{0.0, 1.1, 2.2, 3.3, 4.4, 5.5}; + const_node.dtype(loco::DataType::FLOAT32); + EXPECT_FALSE(fill_const_node(&const_node, int32_vec)); + EXPECT_TRUE(fill_const_node(&const_node, f32_vec)); + EXPECT_EQ(6, const_node.size<loco::DataType::FLOAT32>()); + EXPECT_FLOAT_EQ(0.0, const_node.at<loco::DataType::FLOAT32>(0)); + EXPECT_FLOAT_EQ(1.1, const_node.at<loco::DataType::FLOAT32>(1)); + EXPECT_FLOAT_EQ(2.2, const_node.at<loco::DataType::FLOAT32>(2)); + EXPECT_FLOAT_EQ(3.3, const_node.at<loco::DataType::FLOAT32>(3)); + EXPECT_FLOAT_EQ(4.4, const_node.at<loco::DataType::FLOAT32>(4)); + EXPECT_FLOAT_EQ(5.5, const_node.at<loco::DataType::FLOAT32>(5)); +} + +TEST(TypeMapperTest, wrong_condition_NEG) +{ + EXPECT_EQ(loco::DataType::Unknown, luci::TypeMapper<STRANGER>::get()); +} diff --git a/compiler/luci/pass/src/test/TestFirstNode.h b/compiler/luci/pass/src/test/TestFirstNode.h new file mode 100644 index 000000000..21f859fcd --- /dev/null +++ b/compiler/luci/pass/src/test/TestFirstNode.h @@ -0,0 +1,43 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_PASS_TEST_FIRST_NODE_H__ +#define __LUCI_PASS_TEST_FIRST_NODE_H__ + +#include <luci/IR/CircleNodes.h> + +#include <loco.h> + +namespace luci +{ +namespace test +{ + +template <class T> T *first_node(loco::Graph *g) +{ + for (auto node : loco::active_nodes(loco::output_nodes(g))) + { + auto target_node = dynamic_cast<T *>(node); + if (target_node != nullptr) + return target_node; + } + return nullptr; +} + +} // namespace test +} // namespace luci + +#endif // __LUCI_PASS_TEST_FIRST_NODE_H__ diff --git a/compiler/luci/pass/src/test/TestFirstNode.test.cpp b/compiler/luci/pass/src/test/TestFirstNode.test.cpp new file mode 100644 index 000000000..b07ac6199 --- /dev/null +++ b/compiler/luci/pass/src/test/TestFirstNode.test.cpp @@ -0,0 +1,19 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "TestFirstNode.h" + +// This file validates "TestFirstNode.h". Pleaes DO NOT remove this file. diff --git a/compiler/luci/pass/src/test/TestIOGraph.h b/compiler/luci/pass/src/test/TestIOGraph.h new file mode 100644 index 000000000..b1fc41f90 --- /dev/null +++ b/compiler/luci/pass/src/test/TestIOGraph.h @@ -0,0 +1,161 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_PASS_TEST_IO_GRAPH_H__ +#define __LUCI_PASS_TEST_IO_GRAPH_H__ + +#include "TestShape.h" + +#include <luci/IR/CircleNodes.h> + +namespace luci +{ +namespace test +{ + +/** + * @brief Graphlet with Inputs and loco::Graph for multiple inputs + * @note Every Graph will have Input(s) and Output(s) + * We put loco::Graph only in IsGraphlet not to declare separate + * class for loco::Graph + */ +template <unsigned N> class TestIsGraphlet +{ +public: + TestIsGraphlet() + { + for (uint32_t n = 0; n < N; ++n) + { + _graph_inputs[n] = nullptr; + _inputs[n] = nullptr; + } + } + +public: + virtual void init(loco::Graph *g, const ShapeU32 shape_in) + { + for (uint32_t n = 0; n < N; ++n) + { + _graph_inputs[n] = g->inputs()->create(); + + _inputs[n] = g->nodes()->create<luci::CircleInput>(); + _inputs[n]->shape(shape_in); + _inputs[n]->shape_status(luci::ShapeStatus::VALID); + _inputs[n]->dtype(loco::DataType::FLOAT32); + _inputs[n]->name("input_" + std::to_string(n)); + + _inputs[n]->index(_graph_inputs[n]->index()); + + auto input_shape = std::make_unique<loco::TensorShape>(); + set_shape_vector(input_shape.get(), shape_in); + _graph_inputs[n]->shape(std::move(input_shape)); + _graph_inputs[n]->dtype(loco::DataType::FLOAT32); + } + } + +public: + loco::Graph *g(void) { return &_g; } + luci::CircleInput *input(int idx) { return _inputs[idx]; } + +protected: + loco::Graph _g; + std::array<loco::GraphInput *, N> _graph_inputs; + std::array<luci::CircleInput *, N> _inputs; +}; + +/** + * @brief Graphlet with one Input + */ +class TestIGraphlet : public TestIsGraphlet<1> +{ +public: + luci::CircleInput *input() { return _inputs[0]; } +}; + +/** + * @brief Graphlet with Outputs for multiple outputs + */ +template <unsigned N> class TestOsGraphlet +{ +public: + TestOsGraphlet() + { + for (uint32_t n = 0; n < N; ++n) + { + _graph_outputs[n] = nullptr; + _outputs[n] = nullptr; + } + } + +public: + virtual void init(loco::Graph *g, const ShapeU32 shape_out) + { + for (uint32_t n = 0; n < N; ++n) + { + _graph_outputs[n] = g->outputs()->create(); + + _outputs[n] = g->nodes()->create<luci::CircleOutput>(); + _outputs[n]->shape(shape_out); + _outputs[n]->shape_status(luci::ShapeStatus::VALID); + _outputs[n]->dtype(loco::DataType::FLOAT32); + _outputs[n]->name("output_" + std::to_string(n)); + + _outputs[n]->index(_graph_outputs[n]->index()); + + auto output_shape = std::make_unique<loco::TensorShape>(); + set_shape_vector(output_shape.get(), shape_out); + _graph_outputs[n]->shape(std::move(output_shape)); + _graph_outputs[n]->dtype(loco::DataType::FLOAT32); + } + } + +public: + luci::CircleOutput *output(int idx) { return _outputs[idx]; } + +protected: + std::array<loco::GraphOutput *, N> _graph_outputs; + std::array<luci::CircleOutput *, N> _outputs; +}; + +/** + * @brief Graphlet with one Output + */ +class TestOGraphlet : public TestOsGraphlet<1> +{ +public: + luci::CircleOutput *output() { return _outputs[0]; } +}; + +/** + * @brief Graph with Input and Output + */ +class TestIOGraph : public TestIGraphlet, public TestOGraphlet +{ +public: + TestIOGraph() = default; + +public: + virtual void init(const ShapeU32 shape_in, const ShapeU32 shape_out) + { + TestIsGraphlet<1>::init(g(), shape_in); + TestOsGraphlet<1>::init(g(), shape_out); + } +}; + +} // namespace test +} // namespace luci + +#endif // __LUCI_PASS_TEST_IO_GRAPH_H__ diff --git a/compiler/luci/pass/src/test/TestIOGraph.test.cpp b/compiler/luci/pass/src/test/TestIOGraph.test.cpp new file mode 100644 index 000000000..e58a13f2b --- /dev/null +++ b/compiler/luci/pass/src/test/TestIOGraph.test.cpp @@ -0,0 +1,19 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "TestIOGraph.h" + +// This file validates "TestIOGraph.h". Pleaes DO NOT remove this file. diff --git a/compiler/luci/export/src/TypeBridge.h b/compiler/luci/pass/src/test/TestShape.h index a63fbce54..ccc55c9da 100644 --- a/compiler/luci/export/src/TypeBridge.h +++ b/compiler/luci/pass/src/test/TestShape.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,31 +14,27 @@ * limitations under the License. */ -#ifndef __TYPE_BRIDGE_H__ -#define __TYPE_BRIDGE_H__ +#ifndef __LUCI_PASS_TEST_SHAPE_H__ +#define __LUCI_PASS_TEST_SHAPE_H__ #include <luci/IR/CircleNode.h> -#include <loco.h> +#include <initializer_list> namespace luci { +namespace test +{ -/** - * @brief node_shape() will return loco::TensorShape of CircleNode - */ -loco::TensorShape node_shape(CircleNode *node); +using ShapeU32 = std::initializer_list<uint32_t>; +using ShapeI32 = std::initializer_list<int32_t>; -/** - * @brief node_dtype() will return loco::DataType of CircleNode - */ -loco::DataType node_dtype(CircleNode *node); +void set_shape_vector(loco::TensorShape *shape, const ShapeU32 &values); +void set_shape_vector(luci::CircleConst *const_node, const ShapeI32 &values); -/** - * @brief copy_shape_dtype() will copy shape and dtype inference data to CircleNode - */ -void copy_shape_dtype(loco::Graph *graph); +uint32_t num_elements(const ShapeU32 shape); +} // namespace test } // namespace luci -#endif // __TYPE_BRIDGE_H__ +#endif // __LUCI_PASS_TEST_SHAPE_H__ diff --git a/compiler/luci/pass/src/test/TestShape.test.cpp b/compiler/luci/pass/src/test/TestShape.test.cpp new file mode 100644 index 000000000..39790c614 --- /dev/null +++ b/compiler/luci/pass/src/test/TestShape.test.cpp @@ -0,0 +1,57 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "TestShape.h" + +/** + * @note This file does not hold any test cases but provides methods for tests + */ + +namespace luci +{ +namespace test +{ + +void set_shape_vector(loco::TensorShape *shape, const ShapeU32 &values) +{ + uint32_t r = 0; + shape->rank(values.size()); + for (auto v : values) + shape->dim(r++).set(v); +} + +void set_shape_vector(luci::CircleConst *const_node, const ShapeI32 &values) +{ + const_node->rank(1); + const_node->dim(0).set(values.size()); + const_node->shape_status(luci::ShapeStatus::VALID); + const_node->dtype(loco::DataType::S32); + const_node->size<loco::DataType::S32>(values.size()); + uint32_t idx = 0; + for (auto val : values) + const_node->at<loco::DataType::S32>(idx++) = val; +} + +uint32_t num_elements(const ShapeU32 shape) +{ + uint32_t result = 1; + for (auto val : shape) + result = result * val; + return result; +} + +} // namespace test +} // namespace luci diff --git a/compiler/luci/profile/CMakeLists.txt b/compiler/luci/profile/CMakeLists.txt new file mode 100644 index 000000000..f2c6665da --- /dev/null +++ b/compiler/luci/profile/CMakeLists.txt @@ -0,0 +1,22 @@ +file(GLOB_RECURSE SOURCES "src/*.cpp") +file(GLOB_RECURSE TESTS "src/*.test.cpp") +list(REMOVE_ITEM SOURCES ${TESTS}) + +add_library(luci_profile SHARED ${SOURCES}) +target_include_directories(luci_profile PRIVATE src) +target_include_directories(luci_profile PUBLIC include) +target_link_libraries(luci_profile PUBLIC loco) +target_link_libraries(luci_profile PUBLIC luci_lang) + +install(TARGETS luci_profile DESTINATION lib) + +if(NOT ENABLE_TEST) + return() +endif(NOT ENABLE_TEST) + +nnas_find_package(GTest REQUIRED) + +GTest_AddTest(luci_profile_test ${TESTS}) +target_include_directories(luci_profile_test PRIVATE src) +target_link_libraries(luci_profile_test luci_lang) +target_link_libraries(luci_profile_test luci_profile) diff --git a/compiler/luci/profile/README.md b/compiler/luci/profile/README.md new file mode 100644 index 000000000..577e60a7c --- /dev/null +++ b/compiler/luci/profile/README.md @@ -0,0 +1,119 @@ +# luci-profile + +`luci-profile` provides profiling related items. + +## CircleNodeOrigin + +`CircleNodeOrigin` allow us know where some node is originated from. + +Let's assume following graph transformations are done. + +``` + | | | + [node1] --------+ | | +(id = 1) | | | + | +--------> [node5] ----------------> [node6] + | | (origin = [1,2]) (origin = [1,2]) + [node2] --------+ | | +(id = 2) | | + | | | + [node3] -----------------> [node3] --------+-------> [node3] +(id = 3) (origin = [3]) | (origin = [3,4]) + | | | | + [node4] -----------------> [node4] --------+ | +(id = 4) (origin = [4]) | + | | | + +<Circle1> -- optimizer --> <circle2> -- quantizer --> <circle3> +``` + +The most important purpose of using `CircleNodeOrigin` is preserving origin information. +Following changes show how origin information is preserved even after graph is transformed. + +- `node3` + - `node4` is absorbed to **existing** `node3`. + - origin of `node4` is absorbed to origin of `node3`. +- `node5` + - `node1` and `node2` are fused to **newly created** `node5`. + - origin of `node1` and `node2` are inherited to origin of `node4`. +- `node6` + - `node5` is **replaced with newly created** `node6`. + - origin of `node5` is copied to origin of `node6`. + +**Therefore, when using `CircleNodeOrigin`, please aware of the most important principle. "Preserve origin information"** + +Next items are about implementation details to store the origin information. + +### Source Table + +Source table includes a set of id and name of origin node. + +#### Binary format + +``` +[ entry_number : uint32_t ] +[ id : uint32_t ][ length : uint32_t ][ data : char * length ] * entry_number +``` +- entry_number : The number of entries + - Each entry consists of id, length, and data. +- id : ID of origin node +- length : Length of data +- data : Name of origin node **(null-terminated string)** + +#### In-memory format +```cpp +// size = entry_number +std::map<uint32_t /* id */, std::string /* name */> +``` + +#### Example + +Following example means "Name of origin 1 is node1". + +``` +[Binary Format] + 0x01 00 00 00 0x01 00 00 00 0x06 00 00 00 0x6e 0x6f 0x64 0x65 0x31 00 + ------------- ------------- ------------- ---- ---- ---- ---- ---- ---- +entry_number=1 id=1 length=6 'n' 'o' 'd' 'e' '1' '\0' +``` +```cpp +[In-memory Format] +std::map<uint32_t, std::string>({1, "node1"}); +``` + +### Op Table + +Op table includes a set of id of operation and id(s) of operation's origin nodes. + +#### Binary format + +Op table is stored in circle file as binary with following format. +``` +[ entry_number : uint32_t ] +[ id : uint32_t ][ node_num : uint32_t ][ node_ids : uint32_t * node_num ] * entry_number +``` +- entry_number : The number of entries + - Each entry consists of id, node_num, and node_ids. +- id : ID of operation in circle model file +- node_num : The number of operation's origin nodes +- node_ids : Set of IDs of origin nodes + +#### In-memory format +```cpp +std::map<uint32_t /* id */, std::set<uint32_t> /* node_ids */> +``` + +#### Example + +Following example means "Operation 5 is originated from origin 1 and origin 2". + +``` +[Binary Format] + 0x01 00 00 00 0x05 00 00 00 0x02 00 00 00 0x01 00 00 00 0x02 00 00 00 + ------------- ------------- ------------- --------------------------- +entry_number=1 id=5 node_num=2 node_ids : 1, 2 +``` +```cpp +[In-memory Format] +std::map<uint32_t, std::set<uint32_t>>({5, std::set{1, 2}}); +``` diff --git a/compiler/luci/pass/src/FuseActivationFunctionPassInternal.h b/compiler/luci/profile/include/luci/Profile/CircleNodeID.h index 0cfb9d507..165866bcf 100644 --- a/compiler/luci/pass/src/FuseActivationFunctionPassInternal.h +++ b/compiler/luci/profile/include/luci/Profile/CircleNodeID.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,18 +14,22 @@ * limitations under the License. */ -#ifndef __LUCI_CIRCLE_FUSE_ACTIVATION_FUNCTION_PASS_INTERNAL_H__ -#define __LUCI_CIRCLE_FUSE_ACTIVATION_FUNCTION_PASS_INTERNAL_H__ +#ifndef __LUCI_PROFILE_CIRCLE_NODE_ID_H__ +#define __LUCI_PROFILE_CIRCLE_NODE_ID_H__ -#include <luci/IR/CircleNodes.h> +#include <luci/IR/CircleNode.h> namespace luci { -// Fuse activation function with preceding Op -/// @return true if success -bool fuse_activation_function(luci::CircleNode *node); +using CircleNodeID = uint32_t; + +bool has_node_id(const luci::CircleNode *circle_node); + +void set_node_id(luci::CircleNode *circle_node, CircleNodeID id); + +CircleNodeID get_node_id(const luci::CircleNode *circle_node); } // namespace luci -#endif // __LUCI_CIRCLE_FUSE_ACTIVATION_FUNCTION_PASS_INTERNAL_H__ +#endif // __LUCI_PROFILE_CIRCLE_NODE_ID_H__ diff --git a/compiler/luci/profile/include/luci/Profile/CircleNodeOrigin.h b/compiler/luci/profile/include/luci/Profile/CircleNodeOrigin.h new file mode 100644 index 000000000..2d6558c92 --- /dev/null +++ b/compiler/luci/profile/include/luci/Profile/CircleNodeOrigin.h @@ -0,0 +1,72 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_PROFILE_CIRCLE_NODE_ORIGIN_H__ +#define __LUCI_PROFILE_CIRCLE_NODE_ORIGIN_H__ + +#include "CircleNodeID.h" + +#include <luci/IR/CircleNode.h> + +#include <set> + +namespace luci +{ + +class CircleNodeOrigin +{ +protected: + struct Source + { + public: + std::string name(void) const { return _name; } + void name(const std::string &name) { _name = name; } + + uint32_t id(void) const { return _id; } + void id(const uint32_t id) { _id = id; } + + private: + std::string _name; + uint32_t _id = 0; + }; + +public: + virtual std::set<const Source *> sources(void) const = 0; +}; + +std::shared_ptr<CircleNodeOrigin> single_origin(uint32_t id, const std::string &name); + +std::shared_ptr<CircleNodeOrigin> +composite_origin(const std::initializer_list<std::shared_ptr<CircleNodeOrigin>> origins); + +std::shared_ptr<CircleNodeOrigin> +composite_origin(const std::vector<std::shared_ptr<CircleNodeOrigin>> &origins); + +} // namespace luci + +namespace luci +{ + +bool has_origin(const luci::CircleNode *circle_node); + +void add_origin(luci::CircleNode *circle_node, const std::shared_ptr<CircleNodeOrigin> origin); + +// NOTE When circle_node does not have origin, nullptr is returned +const std::shared_ptr<luci::CircleNodeOrigin> get_origin(const luci::CircleNode *circle_node); + +} // namespace luci + +#endif // __LUCI_PROFILE_CIRCLE_NODE_ORIGIN_H__ diff --git a/compiler/luci/profile/src/CircleNodeID.cpp b/compiler/luci/profile/src/CircleNodeID.cpp new file mode 100644 index 000000000..750b36cae --- /dev/null +++ b/compiler/luci/profile/src/CircleNodeID.cpp @@ -0,0 +1,73 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Profile/CircleNodeID.h" + +#include <loco.h> + +#include <stdexcept> + +namespace +{ + +/** + * @brief Set annotation for circle node id + * @note Once CircleNodeID is annotated, it should not be changed. + * If CircleNodeID is needed to be changed, create new CircleNodeID. + */ +class CircleNodeIDAnnotation final : public loco::NodeAnnotation +{ +public: + CircleNodeIDAnnotation() = delete; + + CircleNodeIDAnnotation(luci::CircleNodeID node_id) : _node_id{node_id} + { + // Do nothing + } + +public: + luci::CircleNodeID node_id(void) const { return _node_id; } + // No setter + +private: + luci::CircleNodeID _node_id; +}; + +} // namespace + +namespace luci +{ + +bool has_node_id(const luci::CircleNode *circle_node) +{ + return circle_node->annot<CircleNodeIDAnnotation>() != nullptr; +} + +void set_node_id(luci::CircleNode *circle_node, luci::CircleNodeID id) +{ + circle_node->annot<CircleNodeIDAnnotation>(nullptr); + circle_node->annot(std::make_unique<CircleNodeIDAnnotation>(id)); +} + +luci::CircleNodeID get_node_id(const luci::CircleNode *circle_node) +{ + if (!has_node_id(circle_node)) + throw std::runtime_error("Cannot find CircleNodeID"); + + return circle_node->annot<CircleNodeIDAnnotation>()->node_id(); +} + +} // namespace luci diff --git a/compiler/luci/profile/src/CircleNodeID.test.cpp b/compiler/luci/profile/src/CircleNodeID.test.cpp new file mode 100644 index 000000000..d80c09b2c --- /dev/null +++ b/compiler/luci/profile/src/CircleNodeID.test.cpp @@ -0,0 +1,44 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Profile/CircleNodeID.h" + +#include <luci/IR/CircleNodes.h> + +#include <gtest/gtest.h> + +TEST(LuciCircleNodeID, simple_circle_node_id) +{ + auto g = loco::make_graph(); + auto add = g->nodes()->create<luci::CircleAdd>(); + + ASSERT_FALSE(has_node_id(add)); + + set_node_id(add, 3); + + ASSERT_TRUE(has_node_id(add)); + ASSERT_EQ(3, get_node_id(add)); +} + +TEST(LuciCircleNodeID, simple_circle_node_id_NEG) +{ + auto g = loco::make_graph(); + auto add = g->nodes()->create<luci::CircleAdd>(); + + ASSERT_FALSE(has_node_id(add)); + + ASSERT_ANY_THROW(get_node_id(add)); +} diff --git a/compiler/luci/profile/src/CircleNodeOrigin.cpp b/compiler/luci/profile/src/CircleNodeOrigin.cpp new file mode 100644 index 000000000..0a731a9ad --- /dev/null +++ b/compiler/luci/profile/src/CircleNodeOrigin.cpp @@ -0,0 +1,168 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Profile/CircleNodeOrigin.h" + +#include <loco.h> + +#include <cassert> +#include <vector> + +namespace +{ + +/** + * @brief Set annotation for recording origin information + * @note Once CircleNodeOrigin is annotated, it should not be changed. + * If CircleNodeOrigin is needed to be changed, create new CircleNodeOrigin. + */ +class CircleNodeOriginAnnotation final : public loco::NodeAnnotation +{ +public: + CircleNodeOriginAnnotation() = delete; + + CircleNodeOriginAnnotation(const std::shared_ptr<luci::CircleNodeOrigin> origin) : _origin(origin) + { + // Do nothing + } + +public: + const std::shared_ptr<luci::CircleNodeOrigin> origin(void) const { return _origin; } + // No setter + +private: + const std::shared_ptr<luci::CircleNodeOrigin> _origin; +}; + +} // namespace + +namespace +{ + +class SingleOrigin final : public luci::CircleNodeOrigin +{ +public: + SingleOrigin() = delete; + + SingleOrigin(uint32_t id, const std::string &name) + { + _source.id(id); + _source.name(name); + } + +public: + std::set<const Source *> sources(void) const final + { + std::set<const Source *> res; + res.emplace(&_source); + return res; + } + +private: + Source _source; +}; + +class CompositeOrigin final : public luci::CircleNodeOrigin +{ +public: + CompositeOrigin() = delete; + + template <typename T> CompositeOrigin(T origins) + { + if (origins.size() == 0) + throw std::invalid_argument("No origins provided"); + + for (auto &origin : origins) + { + if (origin != nullptr) + _origins.emplace_back(origin); + } + } + +public: + std::set<const Source *> sources(void) const final + { + std::set<const Source *> res; + + for (auto &origin : _origins) + { + for (auto source : origin->sources()) + { + res.emplace(source); + } + } + + return res; + } + +private: + std::vector<std::shared_ptr<CircleNodeOrigin>> _origins; +}; + +} // namespace + +namespace luci +{ + +std::shared_ptr<CircleNodeOrigin> single_origin(uint32_t id, const std::string &name) +{ + return std::make_shared<SingleOrigin>(id, name); +} + +std::shared_ptr<CircleNodeOrigin> +composite_origin(const std::initializer_list<std::shared_ptr<CircleNodeOrigin>> origins) +{ + return std::make_shared<CompositeOrigin>(origins); +} + +std::shared_ptr<CircleNodeOrigin> +composite_origin(const std::vector<std::shared_ptr<CircleNodeOrigin>> &origins) +{ + return std::make_shared<CompositeOrigin>(origins); +} + +} // namespace luci + +namespace luci +{ + +bool has_origin(const luci::CircleNode *circle_node) +{ + return circle_node->annot<CircleNodeOriginAnnotation>() != nullptr; +} + +/** + * @brief 'origin' is added to the existing origin of circle_node. + * @note If 'origin' is nullptr, nothing is changed. + * For more detail, please refer to CompositeOrigin constructor. + */ +void add_origin(luci::CircleNode *circle_node, const std::shared_ptr<CircleNodeOrigin> origin) +{ + auto new_origin = composite_origin({get_origin(circle_node), origin}); + circle_node->annot<CircleNodeOriginAnnotation>(nullptr); + circle_node->annot(std::make_unique<CircleNodeOriginAnnotation>(new_origin)); +} + +const std::shared_ptr<luci::CircleNodeOrigin> get_origin(const luci::CircleNode *circle_node) +{ + if (!has_origin(circle_node)) + return nullptr; + + assert(circle_node->annot<CircleNodeOriginAnnotation>()->origin() != nullptr); + return circle_node->annot<CircleNodeOriginAnnotation>()->origin(); +} + +} // namespace luci diff --git a/compiler/luci/profile/src/CircleNodeOrigin.test.cpp b/compiler/luci/profile/src/CircleNodeOrigin.test.cpp new file mode 100644 index 000000000..34618e1ab --- /dev/null +++ b/compiler/luci/profile/src/CircleNodeOrigin.test.cpp @@ -0,0 +1,108 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Profile/CircleNodeID.h" +#include "luci/Profile/CircleNodeOrigin.h" + +#include <luci/IR/CircleNodes.h> + +#include <gtest/gtest.h> + +TEST(LuciCircleNodeOrigin, simple_single_origin) +{ + auto g = loco::make_graph(); + auto add = g->nodes()->create<luci::CircleAdd>(); + + ASSERT_FALSE(has_origin(add)); + + auto origin = luci::single_origin(3, "add"); + add_origin(add, origin); + + ASSERT_TRUE(has_origin(add)); + + auto sources = get_origin(add)->sources(); + ASSERT_EQ(1, sources.size()); + for (auto source : sources) + { + ASSERT_EQ(3, source->id()); + ASSERT_EQ(0, source->name().compare("add")); + } +} + +TEST(LuciCircleNodeOrigin, simple_composite_origin_with_initializer) +{ + auto g = loco::make_graph(); + auto mul = g->nodes()->create<luci::CircleMul>(); + + ASSERT_FALSE(has_origin(mul)); + + auto origin = + luci::composite_origin({luci::single_origin(3, "add"), luci::single_origin(7, "sub")}); + add_origin(mul, origin); + + ASSERT_TRUE(has_origin(mul)); + + bool add_origin_passed = false; + bool sub_origin_passed = false; + auto sources = get_origin(mul)->sources(); + ASSERT_EQ(2, sources.size()); + for (auto source : sources) + { + if (source->id() == 3 && source->name().compare("add") == 0) + add_origin_passed = true; + if (source->id() == 7 && source->name().compare("sub") == 0) + sub_origin_passed = true; + } + + ASSERT_EQ(true, add_origin_passed); + ASSERT_EQ(true, sub_origin_passed); +} + +TEST(LuciCircleNodeOrigin, simple_composite_origin_with_vector) +{ + auto g = loco::make_graph(); + auto mul = g->nodes()->create<luci::CircleMul>(); + + ASSERT_FALSE(has_origin(mul)); + + std::vector<std::shared_ptr<luci::CircleNodeOrigin>> vec; + vec.push_back(luci::single_origin(3, "add")); + vec.push_back(luci::single_origin(7, "sub")); + auto origin = luci::composite_origin(vec); + add_origin(mul, origin); + + ASSERT_TRUE(has_origin(mul)); + + bool add_origin_passed = false; + bool sub_origin_passed = false; + auto sources = get_origin(mul)->sources(); + ASSERT_EQ(2, sources.size()); + for (auto source : sources) + { + if (source->id() == 3 && source->name().compare("add") == 0) + add_origin_passed = true; + if (source->id() == 7 && source->name().compare("sub") == 0) + sub_origin_passed = true; + } + + ASSERT_EQ(true, add_origin_passed); + ASSERT_EQ(true, sub_origin_passed); +} + +TEST(LuciCircleNodeOrigin, composite_origin_empty_ctor_NEG) +{ + ASSERT_ANY_THROW(luci::composite_origin({})); +} diff --git a/compiler/luci/service/CMakeLists.txt b/compiler/luci/service/CMakeLists.txt index 9f50c9c4f..1c78031ab 100644 --- a/compiler/luci/service/CMakeLists.txt +++ b/compiler/luci/service/CMakeLists.txt @@ -22,4 +22,5 @@ nnas_find_package(GTest REQUIRED) GTest_AddTest(luci_service_test ${TESTS}) target_include_directories(luci_service_test PRIVATE src) target_link_libraries(luci_service_test luci_service) +target_link_libraries(luci_service_test luci_testhelper) target_link_libraries(luci_service_test oops) diff --git a/compiler/luci/service/include/luci/Service/CircleNodeClone.h b/compiler/luci/service/include/luci/Service/CircleNodeClone.h new file mode 100644 index 000000000..2429997cc --- /dev/null +++ b/compiler/luci/service/include/luci/Service/CircleNodeClone.h @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_CIRCLE_NODE_CLONE__ +#define __LUCI_CIRCLE_NODE_CLONE__ + +#include <luci/IR/CircleNodes.h> + +#include <loco/IR/Graph.h> + +namespace luci +{ + +/** + * @brief Copy common attributes of CircleNode from src to dst. + */ +void copy_common_attributes(const luci::CircleNode *src, luci::CircleNode *dst); + +/** + * @brief Return a new cloned CircleNode object with same attributes value of node to graph. + * @note Will return nullptr if clone has failed + */ +CircleNode *clone_node(const CircleNode *node, loco::Graph *graph); + +} // namespace luci + +#endif // __LUCI_CIRCLE_NODE_CLONE__ diff --git a/compiler/luci/service/include/luci/Service/CircleShapeInference.h b/compiler/luci/service/include/luci/Service/CircleShapeInference.h index c301db5f4..60bc16e48 100644 --- a/compiler/luci/service/include/luci/Service/CircleShapeInference.h +++ b/compiler/luci/service/include/luci/Service/CircleShapeInference.h @@ -17,29 +17,15 @@ #ifndef __LUCI_CIRCLE_SHAPE_INFERENCE_H__ #define __LUCI_CIRCLE_SHAPE_INFERENCE_H__ -#include "ShapeDescription.h" - #include <loco/IR/Nodes.h> #include <luci/IR/CircleNodes.h> #include <luci/IR/CircleNodeVisitor.h> -#include <luci/Service/CircleShapeInferenceHelper.h> +#include <luci/Service/CircleShapeInferenceRule.h> namespace luci { -/** - * @brief Get the shape of each node as a node annotation - * - * HOW TO USE - * - * ShapeInference::get(g->nodes()->at(..)); - */ -struct ShapeInference -{ - static ShapeDescription get(loco::Node *node); -}; - namespace sinf // namespace for Shape Inference { @@ -52,7 +38,12 @@ class Algorithm final : public luci::CircleNodeVisitor<loco::TensorShape> { public: // TODO Remove this when all of visit function is implemented - loco::TensorShape visit(const luci::CircleNode *node) final { return sinf::circle_shape(node); } + loco::TensorShape visit(const luci::CircleNode *node) final + { + loco::NodeShape shape; + luci::CircleShapeInferenceRule().infer(node, shape); + return shape.as<loco::TensorShape>(); + } // loco::TensorShape visit(const luci::CircleAbs *node) final; // loco::TensorShape visit(const luci::CircleAdd *node) final; @@ -77,6 +68,7 @@ public: // loco::TensorShape visit(const luci::CircleEqual *node) final; // loco::TensorShape visit(const luci::CircleExp *node) final; // loco::TensorShape visit(const luci::CircleExpandDims *node) final; + // loco::TensorShape visit(const luci::CircleFakeQuant *node) final; // loco::TensorShape visit(const luci::CircleFill *node) final; // loco::TensorShape visit(const luci::CircleFloor *node) final; // loco::TensorShape visit(const luci::CircleFloorDiv *node) final; @@ -106,10 +98,12 @@ public: // loco::TensorShape visit(const luci::CircleMean *node) final; // loco::TensorShape visit(const luci::CircleMinimum *node) final; // loco::TensorShape visit(const luci::CircleMirrorPad *node) final; + // loco::TensorShape visit(const luci::CircleMul *node) final; // loco::TensorShape visit(const luci::CircleNeg *node) final; // loco::TensorShape visit(const luci::CircleNonMaxSuppressionV4 *node) final; // loco::TensorShape visit(const luci::CircleNonMaxSuppressionV5 *node) final; // loco::TensorShape visit(const luci::CircleNotEqual *node) final; + // loco::TensorShape visit(const luci::CircleOneHot *node) final; // loco::TensorShape visit(const luci::CirclePack *node) final; // loco::TensorShape visit(const luci::CirclePad *node) final; // loco::TensorShape visit(const luci::CirclePadV2 *node) final; @@ -117,8 +111,6 @@ public: // loco::TensorShape visit(const luci::CirclePRelu *node) final; // loco::TensorShape visit(const luci::CircleRange *node) final; // loco::TensorShape visit(const luci::CircleRank *node) final; - // loco::TensorShape visit(const luci::CircleMul *node) final; - // loco::TensorShape visit(const luci::CircleOneHot *node) final; // loco::TensorShape visit(const luci::CircleReduceAny *node) final; // loco::TensorShape visit(const luci::CircleReduceMax *node) final; // loco::TensorShape visit(const luci::CircleReduceMin *node) final; @@ -171,14 +163,14 @@ public: // loco::TensorShape visit(const luci::CircleInstanceNorm *node) final; // Virtual + // loco::TensorShape visit(const luci::CircleCustomOut *node) final; + loco::TensorShape visit(const luci::CircleIfOut *node) final; // loco::TensorShape visit(const luci::CircleInput *node) final; + // loco::TensorShape visit(const luci::CircleNonMaxSuppressionV4Out *node) final; + // loco::TensorShape visit(const luci::CircleNonMaxSuppressionV5Out *node) final; // loco::TensorShape visit(const luci::CircleOutput *node) final; // loco::TensorShape visit(const luci::CircleOutputDummy *node) final; // loco::TensorShape visit(const luci::CircleOutputExclude *node) final; - // loco::TensorShape visit(const luci::CircleCustomOut *node) final; - // loco::TensorShape visit(const luci::CircleIfOut *node) final; - // loco::TensorShape visit(const luci::CircleNonMaxSuppressionV4Out *node) final; - // loco::TensorShape visit(const luci::CircleNonMaxSuppressionV5Out *node) final; // loco::TensorShape visit(const luci::CircleSplitOut *node) final; // loco::TensorShape visit(const luci::CircleSplitVOut *node) final; // loco::TensorShape visit(const luci::CircleTopKV2Out *node) final; diff --git a/compiler/luci/service/include/luci/Service/CircleShapeSignatureInference.h b/compiler/luci/service/include/luci/Service/CircleShapeSignatureInference.h deleted file mode 100644 index f7ea89bb8..000000000 --- a/compiler/luci/service/include/luci/Service/CircleShapeSignatureInference.h +++ /dev/null @@ -1,179 +0,0 @@ -/* - * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef __LUCI_CIRCLE_SHAPE_SIGNATURE_INFERENCE_H__ -#define __LUCI_CIRCLE_SHAPE_SIGNATURE_INFERENCE_H__ - -#include <luci/IR/CircleNodes.h> -#include <luci/IR/CircleNodeVisitor.h> -#include <luci/IR/CircleShapeSignature.h> -#include <luci/Service/CircleShapeSignatureInferenceHelper.h> - -namespace luci -{ - -namespace ssinf // namespace for Shape Signature Inference -{ - -struct Rule -{ - bool infer(const luci::CircleNode *, ShapeSignature &) const; -}; - -class Algorithm final : public luci::CircleNodeVisitor<ShapeSignature> -{ -public: - // TODO Remove this when visit function is implemented for all the operations. - ShapeSignature visit(const luci::CircleNode *node) final { return node->shape_signature(); } - - // ShapeSignature visit(const luci::CircleAbs *node) final; - // ShapeSignature visit(const luci::CircleAdd *node) final; - // ShapeSignature visit(const luci::CircleAddN *node) final; - // ShapeSignature visit(const luci::CircleArgMax *node) final; - // ShapeSignature visit(const luci::CircleArgMin *node) final; - // ShapeSignature visit(const luci::CircleAveragePool2D *node) final; - // ShapeSignature visit(const luci::CircleBatchMatMul *node) final; - // ShapeSignature visit(const luci::CircleBatchToSpaceND *node) final; - // ShapeSignature visit(const luci::CircleCast *node) final; - // ShapeSignature visit(const luci::CircleCeil *node) final; - // ShapeSignature visit(const luci::CircleConcatenation *node) final; - // ShapeSignature visit(const luci::CircleConst *node) final; - // ShapeSignature visit(const luci::CircleConv2D *node) final; - // ShapeSignature visit(const luci::CircleCos *node) final; - // ShapeSignature visit(const luci::CircleCustom *node) final; - // ShapeSignature visit(const luci::CircleDepthToSpace *node) final; - // ShapeSignature visit(const luci::CircleDepthwiseConv2D *node) final; - // ShapeSignature visit(const luci::CircleDequantize *node) final; - // ShapeSignature visit(const luci::CircleDiv *node) final; - // ShapeSignature visit(const luci::CircleElu *node) final; - // ShapeSignature visit(const luci::CircleEqual *node) final; - // ShapeSignature visit(const luci::CircleExp *node) final; - // ShapeSignature visit(const luci::CircleExpandDims *node) final; - // ShapeSignature visit(const luci::CircleFill *node) final; - // ShapeSignature visit(const luci::CircleFloor *node) final; - // ShapeSignature visit(const luci::CircleFloorDiv *node) final; - // ShapeSignature visit(const luci::CircleFloorMod *node) final; - // ShapeSignature visit(const luci::CircleFullyConnected *node) final; - // ShapeSignature visit(const luci::CircleGather *node) final; - // ShapeSignature visit(const luci::CircleGatherNd *node) final; - // ShapeSignature visit(const luci::CircleGreater *node) final; - // ShapeSignature visit(const luci::CircleGreaterEqual *node) final; - // ShapeSignature visit(const luci::CircleIf *node) final; - // ShapeSignature visit(const luci::CircleL2Normalize *node) final; - // ShapeSignature visit(const luci::CircleL2Pool2D *node) final; - // ShapeSignature visit(const luci::CircleLeakyRelu *node) final; - // ShapeSignature visit(const luci::CircleLess *node) final; - // ShapeSignature visit(const luci::CircleLessEqual *node) final; - // ShapeSignature visit(const luci::CircleLocalResponseNormalization *node) final; - // ShapeSignature visit(const luci::CircleLog *node) final; - // ShapeSignature visit(const luci::CircleLogicalAnd *node) final; - // ShapeSignature visit(const luci::CircleLogicalNot *node) final; - // ShapeSignature visit(const luci::CircleLogicalOr *node) final; - // ShapeSignature visit(const luci::CircleLogistic *node) final; - // ShapeSignature visit(const luci::CircleLogSoftmax *node) final; - // ShapeSignature visit(const luci::CircleMatrixDiag *node) final; - // ShapeSignature visit(const luci::CircleMatrixSetDiag *node) final; - // ShapeSignature visit(const luci::CircleMaximum *node) final; - // ShapeSignature visit(const luci::CircleMaxPool2D *node) final; - ShapeSignature visit(const luci::CircleMean *node) final; - // ShapeSignature visit(const luci::CircleMinimum *node) final; - // ShapeSignature visit(const luci::CircleMirrorPad *node) final; - // ShapeSignature visit(const luci::CircleNeg *node) final; - // ShapeSignature visit(const luci::CircleNonMaxSuppressionV4 *node) final; - // ShapeSignature visit(const luci::CircleNonMaxSuppressionV5 *node) final; - // ShapeSignature visit(const luci::CircleNotEqual *node) final; - // ShapeSignature visit(const luci::CirclePack *node) final; - // ShapeSignature visit(const luci::CirclePad *node) final; - // ShapeSignature visit(const luci::CirclePadV2 *node) final; - // ShapeSignature visit(const luci::CirclePow *node) final; - // ShapeSignature visit(const luci::CirclePRelu *node) final; - // ShapeSignature visit(const luci::CircleRange *node) final; - // ShapeSignature visit(const luci::CircleRank *node) final; - // ShapeSignature visit(const luci::CircleMul *node) final; - // ShapeSignature visit(const luci::CircleOneHot *node) final; - ShapeSignature visit(const luci::CircleReduceAny *node) final; - ShapeSignature visit(const luci::CircleReduceMax *node) final; - ShapeSignature visit(const luci::CircleReduceMin *node) final; - ShapeSignature visit(const luci::CircleReduceProd *node) final; - ShapeSignature visit(const luci::CircleRelu *node) final; - ShapeSignature visit(const luci::CircleRelu6 *node) final; - ShapeSignature visit(const luci::CircleReluN1To1 *node) final; - // ShapeSignature visit(const luci::CircleReshape *node) final; - // ShapeSignature visit(const luci::CircleResizeBilinear *node) final; - // ShapeSignature visit(const luci::CircleResizeNearestNeighbor *node) final; - // ShapeSignature visit(const luci::CircleReverseSequence *node) final; - // ShapeSignature visit(const luci::CircleReverseV2 *node) final; - // ShapeSignature visit(const luci::CircleRound *node) final; - // ShapeSignature visit(const luci::CircleRsqrt *node) final; - // ShapeSignature visit(const luci::CircleScatterNd *node) final; - // ShapeSignature visit(const luci::CircleSegmentSum *node) final; - // ShapeSignature visit(const luci::CircleSelect *node) final; - // ShapeSignature visit(const luci::CircleSelectV2 *node) final; - // ShapeSignature visit(const luci::CircleShape *node) final; - // ShapeSignature visit(const luci::CircleSin *node) final; - // ShapeSignature visit(const luci::CircleSlice *node) final; - // ShapeSignature visit(const luci::CircleSoftmax *node) final; - // ShapeSignature visit(const luci::CircleSpaceToBatchND *node) final; - // ShapeSignature visit(const luci::CircleSpaceToDepth *node) final; - // ShapeSignature visit(const luci::CircleSparseToDense *node) final; - // ShapeSignature visit(const luci::CircleSplit *node) final; - // ShapeSignature visit(const luci::CircleSplitV *node) final; - // ShapeSignature visit(const luci::CircleSqrt *node) final; - // ShapeSignature visit(const luci::CircleSquare *node) final; - // ShapeSignature visit(const luci::CircleSquaredDifference *node) final; - // ShapeSignature visit(const luci::CircleSqueeze *node) final; - // ShapeSignature visit(const luci::CircleStridedSlice *node) final; - // ShapeSignature visit(const luci::CircleSub *node) final; - ShapeSignature visit(const luci::CircleSum *node) final; - // ShapeSignature visit(const luci::CircleTanh *node) final; - // ShapeSignature visit(const luci::CircleTile *node) final; - // ShapeSignature visit(const luci::CircleTopKV2 *node) final; - // ShapeSignature visit(const luci::CircleTranspose *node) final; - // ShapeSignature visit(const luci::CircleTransposeConv *node) final; - // ShapeSignature visit(const luci::CircleUnidirectionalSequenceLSTM *node) final; - // ShapeSignature visit(const luci::CircleUnique *node) final; - // ShapeSignature visit(const luci::CircleUnpack *node) final; - // ShapeSignature visit(const luci::CircleWhere *node) final ; - // ShapeSignature visit(const luci::CircleWhile *node) final; - // ShapeSignature visit(const luci::CircleZerosLike *node) final; - - // Circle Only - // ShapeSignature visit(const luci::CircleBCQFullyConnected *node) final; - // ShapeSignature visit(const luci::CircleBCQGather *node) final; - // ShapeSignature visit(const luci::CircleInstanceNorm *node) final; - - // Virtual - ShapeSignature visit(const luci::CircleInput *node) final; - ShapeSignature visit(const luci::CircleOutput *node) final; - ShapeSignature visit(const luci::CircleOutputDummy *node) final; - ShapeSignature visit(const luci::CircleOutputExclude *node) final; - // ShapeSignature visit(const luci::CircleCustomOut *node) final; - // ShapeSignature visit(const luci::CircleIfOut *node) final; - // ShapeSignature visit(const luci::CircleNonMaxSuppressionV4Out *node) final; - // ShapeSignature visit(const luci::CircleNonMaxSuppressionV5Out *node) final; - // ShapeSignature visit(const luci::CircleSplitOut *node) final; - // ShapeSignature visit(const luci::CircleSplitVOut *node) final; - // ShapeSignature visit(const luci::CircleTopKV2Out *node) final; - // ShapeSignature visit(const luci::CircleUniqueOut *node) final; - // ShapeSignature visit(const luci::CircleUnpackOut *node) final; - // ShapeSignature visit(const luci::CircleWhileOut *node) final; -}; - -} // namespace ssinf - -} // namespace luci - -#endif // __LUCI_CIRCLE_SHAPE_SIGNATURE_INFERENCE_H__ diff --git a/compiler/luci/service/include/luci/Service/CircleShapeSignatureInferenceHelper.h b/compiler/luci/service/include/luci/Service/CircleShapeSignatureInferenceHelper.h deleted file mode 100644 index fb5b3b302..000000000 --- a/compiler/luci/service/include/luci/Service/CircleShapeSignatureInferenceHelper.h +++ /dev/null @@ -1,45 +0,0 @@ -/* - * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef __LUCI_CIRCLE_SHAPE_SIGNATURE_INFERENCE_HELPER_H__ -#define __LUCI_CIRCLE_SHAPE_SIGNATURE_INFERENCE_HELPER_H__ - -#include <luci/IR/CircleNodes.h> -#include <luci/IR/CircleShapeSignature.h> - -namespace luci -{ - -namespace ssinf // Namespace for Shape Signature Inference -{ - -// Return empty signature if all of dimensions are known. -// If at least one of dimensions is unknown, return signature without change. -ShapeSignature legalized_signature(const luci::ShapeSignature &signature); - -// Return reduced input_signature with indices and keep_dims. -// - indices : reduction index -// - keep_dims : If true, rank is not changed. If false, rank is reduced along indices. -ShapeSignature reduced_signature(const loco::Node *node, const loco::Node *indices, bool keep_dims); - -// Return signature of index-th argument of node. -ShapeSignature input_arg_signature(const luci::CircleNode *node, uint32_t index); - -} // namespace ssinf - -} // namespace luci - -#endif // __LUCI_CIRCLE_SHAPE_SIGNATURE_INFERENCE_HELPER_H__ diff --git a/compiler/luci/service/include/luci/Service/CircleTypeInference.h b/compiler/luci/service/include/luci/Service/CircleTypeInference.h index 342214887..8eef469ac 100644 --- a/compiler/luci/service/include/luci/Service/CircleTypeInference.h +++ b/compiler/luci/service/include/luci/Service/CircleTypeInference.h @@ -23,24 +23,11 @@ #include <luci/IR/CircleNodes.h> #include <luci/IR/CircleNodeVisitor.h> -#include <luci/Service/CircleTypeInferenceHelper.h> +#include <luci/Service/CircleTypeInferenceRule.h> namespace luci { -/** - * @brief Get the type of each node as NodeAnnotation - * - * HOW TO USE - * - * TypeInference::get(g->nodes()->at(0)); - * TypeInference::get(g->nodes()->at(...)); - */ -struct TypeInference -{ - static circle::TensorType get(loco::Node *node); -}; - namespace tinf // namespace for Type Inference { @@ -53,7 +40,12 @@ class Algorithm final : public luci::CircleNodeVisitor<loco::DataType> { public: // TODO Remove this when all of visit function is implemented - loco::DataType visit(const luci::CircleNode *node) final { return node->dtype(); } + loco::DataType visit(const luci::CircleNode *node) final + { + loco::DataType dtype; + luci::CircleTypeInferenceRule().infer(node, dtype); + return dtype; + } // loco::DataType visit(const luci::CircleAbs *node) final; // loco::DataType visit(const luci::CircleAdd *node) final; @@ -78,6 +70,7 @@ public: // loco::DataType visit(const luci::CircleEqual *node) final; // loco::DataType visit(const luci::CircleExp *node) final; // loco::DataType visit(const luci::CircleExpandDims *node) final; + // loco::DataType visit(const luci::CircleFakeQuant *node) final; // loco::DataType visit(const luci::CircleFill *node) final; // loco::DataType visit(const luci::CircleFloor *node) final; // loco::DataType visit(const luci::CircleFloorDiv *node) final; @@ -177,7 +170,7 @@ public: // loco::DataType visit(const luci::CircleOutputDummy *node) final; // loco::DataType visit(const luci::CircleOutputExclude *node) final; // loco::DataType visit(const luci::CircleCustomOut *node) final; - // loco::DataType visit(const luci::CircleIfOut *node) final; + loco::DataType visit(const luci::CircleIfOut *node) final; // loco::DataType visit(const luci::CircleNonMaxSuppressionV4Out *node) final; // loco::DataType visit(const luci::CircleNonMaxSuppressionV5Out *node) final; // loco::DataType visit(const luci::CircleSplitOut *node) final; diff --git a/compiler/luci/service/include/luci/Service/Nodes/CircleConst.h b/compiler/luci/service/include/luci/Service/Nodes/CircleConst.h new file mode 100644 index 000000000..6049b4297 --- /dev/null +++ b/compiler/luci/service/include/luci/Service/Nodes/CircleConst.h @@ -0,0 +1,32 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_SERVICE_CIRCLE_CONST_H__ +#define __LUCI_SERVICE_CIRCLE_CONST_H__ + +#include <luci/IR/Nodes/CircleConst.h> + +namespace luci +{ + +/** + * @brief Return cloned object of CircleConst node + */ +luci::CircleConst *clone(luci::CircleConst *node); + +} // namespace luci + +#endif // __LUCI_SERVICE_CIRCLE_CONST_H__ diff --git a/compiler/luci/service/include/luci/Service/ShapeDescription.h b/compiler/luci/service/include/luci/Service/ShapeDescription.h index 4d92be13f..4671096fd 100644 --- a/compiler/luci/service/include/luci/Service/ShapeDescription.h +++ b/compiler/luci/service/include/luci/Service/ShapeDescription.h @@ -37,10 +37,6 @@ struct ShapeDescription // TODO remove these when CircleDialect is fully functioal ShapeDescription to_shape_description(const luci::CircleNode *node); ShapeDescription to_shape_description(const loco::TensorShape &shape); -ShapeDescription to_shape_description(const loco::FeatureShape &shape); -ShapeDescription to_shape_description(const loco::FilterShape &shape); -ShapeDescription to_shape_description(const loco::BiasShape &shape); -ShapeDescription to_shape_description(const loco::MatrixShape &shape); ShapeDescription to_shape_description(const loco::NodeShape &shape); template <typename Permutation> inline bool isNHWC(Permutation *perm); diff --git a/compiler/luci/service/include/luci/Service/Validate.h b/compiler/luci/service/include/luci/Service/Validate.h index 4b80d1d16..456d6e504 100644 --- a/compiler/luci/service/include/luci/Service/Validate.h +++ b/compiler/luci/service/include/luci/Service/Validate.h @@ -17,6 +17,8 @@ #ifndef __LUCI_SERVICE_VALIDATE_H__ #define __LUCI_SERVICE_VALIDATE_H__ +#include <luci/IR/Module.h> + #include <loco.h> namespace luci @@ -24,6 +26,17 @@ namespace luci bool validate(loco::Graph *); +/** + * @brief Return true if all nodes in graph have non empty name + */ +bool validate_name(loco::Graph *); + +/** + * @brief Return true if all names in the Module are unique + * @note CircleOutput may have duplicate name + */ +bool validate_unique_name(luci::Module *); + } // namespace luci #endif // __LUCI_SERVICE_VALIDATE_H__ diff --git a/compiler/luci/service/src/CircleCloneNode.h b/compiler/luci/service/src/CircleCloneNode.h new file mode 100644 index 000000000..02c7cd256 --- /dev/null +++ b/compiler/luci/service/src/CircleCloneNode.h @@ -0,0 +1,174 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __CIRCLE_CLONE_NODE_H__ +#define __CIRCLE_CLONE_NODE_H__ + +#include <luci/IR/CircleNodes.h> + +#include <luci/IR/CircleNodeVisitor.h> + +namespace luci +{ + +class CloneNode final : public luci::CircleNodeVisitor<luci::CircleNode *> +{ +public: + CloneNode(loco::Graph *graph) : _graph(graph){}; + +public: + luci::CircleNode *visit(const luci::CircleAbs *) final; + luci::CircleNode *visit(const luci::CircleAdd *) final; + luci::CircleNode *visit(const luci::CircleAddN *) final; + luci::CircleNode *visit(const luci::CircleArgMax *) final; + luci::CircleNode *visit(const luci::CircleArgMin *) final; + luci::CircleNode *visit(const luci::CircleAveragePool2D *) final; + luci::CircleNode *visit(const luci::CircleBatchMatMul *) final; + luci::CircleNode *visit(const luci::CircleBatchToSpaceND *) final; + luci::CircleNode *visit(const luci::CircleCast *) final; + luci::CircleNode *visit(const luci::CircleCeil *) final; + luci::CircleNode *visit(const luci::CircleConcatenation *) final; + luci::CircleNode *visit(const luci::CircleConst *) final; + luci::CircleNode *visit(const luci::CircleConv2D *) final; + luci::CircleNode *visit(const luci::CircleCos *) final; + luci::CircleNode *visit(const luci::CircleCustom *) final; + luci::CircleNode *visit(const luci::CircleDepthToSpace *) final; + luci::CircleNode *visit(const luci::CircleDepthwiseConv2D *) final; + luci::CircleNode *visit(const luci::CircleDequantize *) final; + luci::CircleNode *visit(const luci::CircleDiv *) final; + luci::CircleNode *visit(const luci::CircleElu *) final; + luci::CircleNode *visit(const luci::CircleEqual *) final; + luci::CircleNode *visit(const luci::CircleExp *) final; + luci::CircleNode *visit(const luci::CircleExpandDims *) final; + luci::CircleNode *visit(const luci::CircleFakeQuant *) final; + luci::CircleNode *visit(const luci::CircleFill *) final; + luci::CircleNode *visit(const luci::CircleFloor *) final; + luci::CircleNode *visit(const luci::CircleFloorDiv *) final; + luci::CircleNode *visit(const luci::CircleFloorMod *) final; + luci::CircleNode *visit(const luci::CircleFullyConnected *) final; + luci::CircleNode *visit(const luci::CircleGather *) final; + luci::CircleNode *visit(const luci::CircleGatherNd *) final; + luci::CircleNode *visit(const luci::CircleGreater *) final; + luci::CircleNode *visit(const luci::CircleGreaterEqual *) final; + // luci::CircleNode *visit(const luci::CircleIf *) final; + luci::CircleNode *visit(const luci::CircleL2Normalize *) final; + luci::CircleNode *visit(const luci::CircleL2Pool2D *) final; + luci::CircleNode *visit(const luci::CircleLeakyRelu *) final; + luci::CircleNode *visit(const luci::CircleLess *) final; + luci::CircleNode *visit(const luci::CircleLessEqual *) final; + luci::CircleNode *visit(const luci::CircleLocalResponseNormalization *) final; + luci::CircleNode *visit(const luci::CircleLog *) final; + luci::CircleNode *visit(const luci::CircleLogicalAnd *) final; + luci::CircleNode *visit(const luci::CircleLogicalNot *) final; + luci::CircleNode *visit(const luci::CircleLogicalOr *) final; + luci::CircleNode *visit(const luci::CircleLogistic *) final; + luci::CircleNode *visit(const luci::CircleLogSoftmax *) final; + luci::CircleNode *visit(const luci::CircleMatrixDiag *) final; + luci::CircleNode *visit(const luci::CircleMatrixSetDiag *) final; + luci::CircleNode *visit(const luci::CircleMaximum *) final; + luci::CircleNode *visit(const luci::CircleMaxPool2D *) final; + luci::CircleNode *visit(const luci::CircleMean *) final; + luci::CircleNode *visit(const luci::CircleMinimum *) final; + luci::CircleNode *visit(const luci::CircleMirrorPad *) final; + luci::CircleNode *visit(const luci::CircleMul *) final; + luci::CircleNode *visit(const luci::CircleNeg *) final; + luci::CircleNode *visit(const luci::CircleNonMaxSuppressionV4 *) final; + luci::CircleNode *visit(const luci::CircleNonMaxSuppressionV5 *) final; + luci::CircleNode *visit(const luci::CircleNotEqual *) final; + luci::CircleNode *visit(const luci::CircleOneHot *) final; + luci::CircleNode *visit(const luci::CirclePack *) final; + luci::CircleNode *visit(const luci::CirclePad *) final; + luci::CircleNode *visit(const luci::CirclePadV2 *) final; + luci::CircleNode *visit(const luci::CirclePow *) final; + luci::CircleNode *visit(const luci::CirclePRelu *) final; + luci::CircleNode *visit(const luci::CircleRange *) final; + luci::CircleNode *visit(const luci::CircleRank *) final; + luci::CircleNode *visit(const luci::CircleReduceAny *) final; + luci::CircleNode *visit(const luci::CircleReduceMax *) final; + luci::CircleNode *visit(const luci::CircleReduceMin *) final; + luci::CircleNode *visit(const luci::CircleReduceProd *) final; + luci::CircleNode *visit(const luci::CircleRelu *) final; + luci::CircleNode *visit(const luci::CircleRelu6 *) final; + luci::CircleNode *visit(const luci::CircleReluN1To1 *) final; + luci::CircleNode *visit(const luci::CircleReshape *) final; + luci::CircleNode *visit(const luci::CircleResizeBilinear *) final; + luci::CircleNode *visit(const luci::CircleResizeNearestNeighbor *) final; + luci::CircleNode *visit(const luci::CircleReverseSequence *) final; + luci::CircleNode *visit(const luci::CircleReverseV2 *) final; + luci::CircleNode *visit(const luci::CircleRound *) final; + luci::CircleNode *visit(const luci::CircleRsqrt *) final; + luci::CircleNode *visit(const luci::CircleScatterNd *) final; + luci::CircleNode *visit(const luci::CircleSegmentSum *) final; + luci::CircleNode *visit(const luci::CircleSelect *) final; + luci::CircleNode *visit(const luci::CircleSelectV2 *) final; + luci::CircleNode *visit(const luci::CircleShape *) final; + luci::CircleNode *visit(const luci::CircleSin *) final; + luci::CircleNode *visit(const luci::CircleSlice *) final; + luci::CircleNode *visit(const luci::CircleSoftmax *) final; + luci::CircleNode *visit(const luci::CircleSpaceToBatchND *) final; + luci::CircleNode *visit(const luci::CircleSpaceToDepth *) final; + luci::CircleNode *visit(const luci::CircleSparseToDense *) final; + luci::CircleNode *visit(const luci::CircleSplit *) final; + luci::CircleNode *visit(const luci::CircleSplitV *) final; + luci::CircleNode *visit(const luci::CircleSqrt *) final; + luci::CircleNode *visit(const luci::CircleSquare *) final; + luci::CircleNode *visit(const luci::CircleSquaredDifference *) final; + luci::CircleNode *visit(const luci::CircleSqueeze *) final; + luci::CircleNode *visit(const luci::CircleStridedSlice *) final; + luci::CircleNode *visit(const luci::CircleSub *) final; + luci::CircleNode *visit(const luci::CircleSum *) final; + luci::CircleNode *visit(const luci::CircleTanh *) final; + luci::CircleNode *visit(const luci::CircleTile *) final; + luci::CircleNode *visit(const luci::CircleTopKV2 *) final; + luci::CircleNode *visit(const luci::CircleTranspose *) final; + luci::CircleNode *visit(const luci::CircleTransposeConv *) final; + luci::CircleNode *visit(const luci::CircleUnidirectionalSequenceLSTM *) final; + luci::CircleNode *visit(const luci::CircleUnique *) final; + luci::CircleNode *visit(const luci::CircleUnpack *) final; + luci::CircleNode *visit(const luci::CircleWhere *) final; + // luci::CircleNode *visit(const luci::CircleWhile *) final; + luci::CircleNode *visit(const luci::CircleZerosLike *) final; + + // Circle Only + luci::CircleNode *visit(const luci::CircleBCQFullyConnected *) final; + luci::CircleNode *visit(const luci::CircleBCQGather *) final; + luci::CircleNode *visit(const luci::CircleInstanceNorm *) final; + + // Virtual + luci::CircleNode *visit(const luci::CircleCustomOut *) final; + // luci::CircleNode *visit(const luci::CircleIfOut *) final; + // luci::CircleNode *visit(const luci::CircleInput *) final; + luci::CircleNode *visit(const luci::CircleNonMaxSuppressionV4Out *) final; + luci::CircleNode *visit(const luci::CircleNonMaxSuppressionV5Out *) final; + // luci::CircleNode *visit(const luci::CircleOutput *) final; + luci::CircleNode *visit(const luci::CircleOutputDummy *) final; + luci::CircleNode *visit(const luci::CircleOutputExclude *) final; + luci::CircleNode *visit(const luci::CircleSplitOut *) final; + luci::CircleNode *visit(const luci::CircleSplitVOut *) final; + luci::CircleNode *visit(const luci::CircleTopKV2Out *) final; + luci::CircleNode *visit(const luci::CircleUniqueOut *) final; + luci::CircleNode *visit(const luci::CircleUnpackOut *) final; + // luci::CircleNode *visit(const luci::CircleWhileOut *) final; + + // NOTE CircleNodeVisitor will throw if not supported here + +protected: + loco::Graph *_graph = nullptr; +}; + +} // namespace luci + +#endif // __CIRCLE_CLONE_NODE_H__ diff --git a/compiler/luci/service/src/CircleNodeClone.cpp b/compiler/luci/service/src/CircleNodeClone.cpp new file mode 100644 index 000000000..d2033dd0c --- /dev/null +++ b/compiler/luci/service/src/CircleNodeClone.cpp @@ -0,0 +1,92 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include "CircleCloneNode.h" + +#include <oops/UserExn.h> + +#include <cassert> + +namespace luci +{ + +/** + * @note Attributes of specific node type like keep_dims() of CircleSum are + * not copied. + */ +void copy_common_attributes(const luci::CircleNode *src, luci::CircleNode *dst) +{ + assert(src != nullptr); + assert(dst != nullptr); + + dst->name(src->name()); + dst->dtype(src->dtype()); + + dst->rank(src->rank()); + for (uint32_t i = 0; i < src->rank(); i++) + { + dst->dim(i) = src->dim(i); + } + dst->shape_status(src->shape_status()); + + // quantparam + const auto *quantparam = src->quantparam(); + if (quantparam != nullptr) + { + auto qparam = std::make_unique<luci::CircleQuantParam>(); + qparam->scale = quantparam->scale; + qparam->zerop = quantparam->zerop; + qparam->min = quantparam->min; + qparam->max = quantparam->max; + qparam->quantized_dimension = quantparam->quantized_dimension; + + dst->quantparam(std::move(qparam)); + } + + // sparsity + const auto *sparsity = src->sparsityparam(); + if (sparsity != nullptr) + { + auto sparam = std::make_unique<luci::SparsityParam>(); + sparam->traversal_order = sparsity->traversal_order; + sparam->block_map = sparsity->block_map; + sparam->dim_metadata = sparsity->dim_metadata; + + dst->sparsityparam(std::move(sparam)); + } + + // op version + dst->op_version(src->op_version()); +} + +/** + * @note Each visit implementation must copy node specific attributes. + */ +luci::CircleNode *clone_node(const luci::CircleNode *node, loco::Graph *graph) +{ + if (node == nullptr || graph == nullptr) + return nullptr; + + CloneNode cn(graph); + auto cloned = node->accept(&cn); + if (cloned != nullptr) + copy_common_attributes(node, cloned); + return cloned; +} + +} // namespace luci diff --git a/compiler/luci/service/src/CircleNodeClone.test.cpp b/compiler/luci/service/src/CircleNodeClone.test.cpp new file mode 100644 index 000000000..5908eeb82 --- /dev/null +++ b/compiler/luci/service/src/CircleNodeClone.test.cpp @@ -0,0 +1,109 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +// NOTE any node will do for testing +#include <luci/IR/Nodes/CircleAdd.h> + +#include <gtest/gtest.h> + +namespace +{ + +luci::CircleAdd *build_simple_add_graph(loco::Graph *g) +{ + auto node = g->nodes()->create<luci::CircleAdd>(); + + node->name("name"); + node->dtype(loco::DataType::FLOAT32); + node->rank(1); + node->dim(0).set(3); + node->shape_status(luci::ShapeStatus::VALID); + node->fusedActivationFunction(luci::FusedActFunc::NONE); + + auto qparam = std::make_unique<luci::CircleQuantParam>(); + qparam->scale = {1.0}; + qparam->zerop = {0}; + qparam->min = {0.0}; + qparam->max = {1.0}; + qparam->quantized_dimension = 0; + node->quantparam(std::move(qparam)); + + auto sparam = std::make_unique<luci::SparsityParam>(); + sparam->traversal_order = {0}; + sparam->block_map = {0}; + sparam->dim_metadata = {luci::DimMetaData(luci::DimensionType::DENSE, 1)}; + node->sparsityparam(std::move(sparam)); + + node->op_version(2); + + return node; +} + +} // namespace + +TEST(CircleNodeCloneTest, copy_attribites) +{ + auto g = loco::make_graph(); + auto node = build_simple_add_graph(g.get()); + + auto copy = g->nodes()->create<luci::CircleAdd>(); + luci::copy_common_attributes(node, copy); + + ASSERT_EQ(node->name(), copy->name()); + ASSERT_EQ(node->dtype(), copy->dtype()); + ASSERT_EQ(node->rank(), copy->rank()); + ASSERT_EQ(node->shape_status(), copy->shape_status()); + + const auto *qparam_node = node->quantparam(); + const auto *qparam_copy = copy->quantparam(); + ASSERT_EQ(qparam_node->scale, qparam_copy->scale); + + const auto *sparsity_node = node->sparsityparam(); + const auto *sparsity_copy = copy->sparsityparam(); + ASSERT_EQ(sparsity_node->traversal_order, sparsity_copy->traversal_order); + + ASSERT_EQ(node->op_version(), copy->op_version()); +} + +TEST(CircleNodeCloneTest, clone_add_node) +{ + auto g = loco::make_graph(); + auto node = build_simple_add_graph(g.get()); + + auto cg = loco::make_graph(); + auto clone = clone_node(node, cg.get()); + + ASSERT_NE(nullptr, clone); + ASSERT_EQ(cg.get(), clone->graph()); + ASSERT_EQ(node->name(), clone->name()); + ASSERT_EQ(node->dtype(), clone->dtype()); + ASSERT_EQ(node->rank(), clone->rank()); + ASSERT_EQ(node->shape_status(), clone->shape_status()); +} + +TEST(CircleNodeCloneTest, clone_node_NEG) +{ + auto g = loco::make_graph(); + auto node = build_simple_add_graph(g.get()); + + auto cg = loco::make_graph(); + auto clone = luci::clone_node(nullptr, cg.get()); + ASSERT_EQ(nullptr, clone); + auto clone2 = luci::clone_node(node, nullptr); + ASSERT_EQ(nullptr, clone2); +} diff --git a/compiler/luci/service/src/CircleShapeInference.cpp b/compiler/luci/service/src/CircleShapeInference.cpp index db8ffd8ad..73472069b 100644 --- a/compiler/luci/service/src/CircleShapeInference.cpp +++ b/compiler/luci/service/src/CircleShapeInference.cpp @@ -15,27 +15,16 @@ */ #include "luci/Service/CircleShapeInference.h" -#include "luci/Service/ShapeDescription.h" + +#include "CircleShapeInferenceHelper.h" #include <loco.h> -#include <loco/Service/ShapeInference.h> #include <luci/Log.h> #include <cassert> #include <iostream> -namespace luci -{ - -ShapeDescription ShapeInference::get(loco::Node *node) -{ - assert(loco::shape_known(node)); - return to_shape_description(loco::shape_get(node)); -} - -} // namespace luci - namespace { @@ -46,7 +35,11 @@ std::ostream &operator<<(std::ostream &os, const loco::TensorShape &tensor_shape { if (r) os << ","; - os << tensor_shape.dim(r).value(); + + if (tensor_shape.dim(r).known()) + os << tensor_shape.dim(r).value(); + else + os << "?"; } os << "]"; return os; @@ -90,5 +83,5 @@ bool Rule::infer(const luci::CircleNode *circle_node, loco::TensorShape &shape) return true; } -} // namespace ssinf +} // namespace sinf } // namespace luci diff --git a/compiler/luci/service/src/CircleShapeInferenceHelper.cpp b/compiler/luci/service/src/CircleShapeInferenceHelper.cpp index f7eb6c3ec..2009aa59f 100644 --- a/compiler/luci/service/src/CircleShapeInferenceHelper.cpp +++ b/compiler/luci/service/src/CircleShapeInferenceHelper.cpp @@ -14,7 +14,24 @@ * limitations under the License. */ -#include "luci/Service/CircleShapeInferenceHelper.h" +#include "CircleShapeInferenceHelper.h" + +namespace luci +{ + +loco::NodeShape shape_get(const loco::Node *node) +{ + assert(luci::shape_known(node)); + return loco::NodeShape{sinf::circle_shape(loco::must_cast<const luci::CircleNode *>(node))}; +} + +bool shape_known(const loco::Node *node) +{ + return loco::must_cast<const luci::CircleNode *>(node)->shape_status() != + luci::ShapeStatus::UNDEFINED; +} + +} // namespace luci namespace luci { @@ -26,7 +43,7 @@ loco::TensorShape circle_shape(const luci::CircleNode *node) loco::TensorShape shape; shape.rank(node->rank()); for (uint32_t r = 0; r < node->rank(); ++r) - shape.dim(r) = loco::Dimension(node->dim(r).value()); + shape.dim(r) = node->dim(r); return shape; } diff --git a/compiler/luci/service/include/luci/Service/CircleShapeInferenceHelper.h b/compiler/luci/service/src/CircleShapeInferenceHelper.h index dd6a5a454..7c7ea496c 100644 --- a/compiler/luci/service/include/luci/Service/CircleShapeInferenceHelper.h +++ b/compiler/luci/service/src/CircleShapeInferenceHelper.h @@ -17,10 +17,24 @@ #ifndef __LUCI_CIRCLE_SHAPE_INFERENCE_HELPER_H__ #define __LUCI_CIRCLE_SHAPE_INFERENCE_HELPER_H__ +#include <loco/IR/NodeShape.h> #include <loco/IR/TensorShape.h> #include <luci/IR/CircleNodes.h> -#include <luci/IR/CircleShapeSignature.h> + +namespace luci +{ + +// NOTE Functions in this namespace will be removed after new inference +// algorithms are fully implemented. + +// This function is temporary function for deprecating loco::shape_get +loco::NodeShape shape_get(const loco::Node *node); + +// This function is temporary function for deprecating loco::shape_known +bool shape_known(const loco::Node *node); + +} // namespace luci namespace luci { diff --git a/compiler/luci/service/src/CircleShapeInferenceRule.cpp b/compiler/luci/service/src/CircleShapeInferenceRule.cpp index 38ff619ab..c6d8232c3 100644 --- a/compiler/luci/service/src/CircleShapeInferenceRule.cpp +++ b/compiler/luci/service/src/CircleShapeInferenceRule.cpp @@ -17,6 +17,7 @@ #include "luci/Service/CircleShapeInferenceRule.h" #include "Check.h" +#include "CircleShapeInferenceHelper.h" #include "ShapeInfer_StridedSlice.h" #include <luci/IR/CircleNodes.h> @@ -41,7 +42,11 @@ std::ostream &operator<<(std::ostream &os, const loco::TensorShape &tensor_shape { if (r) os << ","; - os << tensor_shape.dim(r).value(); + + if (tensor_shape.dim(r).known()) + os << tensor_shape.dim(r).value(); + else + os << "?"; } os << "]"; return os; @@ -52,7 +57,15 @@ loco::TensorShape own_shape(const luci::CircleNode *node) loco::TensorShape shape; shape.rank(node->rank()); for (uint32_t r = 0; r < node->rank(); ++r) - shape.dim(r) = loco::Dimension(node->dim(r).value()); + { + // Shape inference rules in this file did not consider unknown dimension. + // If some node has unknown dimension, 0 is inserted and wrong shape + // inference was done as a result. + // To fix this, new shape inference algorithm is being implemented. + // Until new inference algorithm is fully implemented, unknown dimension + // would be represented as 1 along with TFLite expression. + shape.dim(r) = node->dim(r).known() ? node->dim(r).value() : 1; + } return shape; } @@ -135,10 +148,8 @@ loco::TensorShape expand_dimension(const loco::TensorShape &x, const loco::Tenso output_shape.rank(rank); for (uint32_t axis = 0; axis < rank; ++axis) { - assert(x.dim(axis).known() && y.dim(axis).known()); - - auto x_dim = x.dim(axis).value(); - auto y_dim = y.dim(axis).value(); + auto x_dim = x.dim(axis).known() ? x.dim(axis).value() : 1; + auto y_dim = y.dim(axis).known() ? y.dim(axis).value() : 1; // each dimension of x and y should be same or one must be 1 if different if (!((x_dim == y_dim) || (x_dim == 1 || y_dim == 1))) @@ -177,23 +188,29 @@ template <loco::DataType T> std::vector<int64_t> vector_from_constant(luci::Circ template <class CIRCLENODE> loco::NodeShape broadcast_xy(const CIRCLENODE *node) { - auto x_shape = loco::shape_get(node->x()).template as<loco::TensorShape>(); - auto y_shape = loco::shape_get(node->y()).template as<loco::TensorShape>(); + auto x_shape = luci::shape_get(node->x()).template as<loco::TensorShape>(); + auto y_shape = luci::shape_get(node->y()).template as<loco::TensorShape>(); auto output_shape = broadcast_shape(x_shape, y_shape); return loco::NodeShape{output_shape}; } +template <class CIRCLENODE> loco::NodeShape use_inputs(const CIRCLENODE *node) +{ + auto inputs_shape = luci::shape_get(node->inputs()).template as<loco::TensorShape>(); + return loco::NodeShape{inputs_shape}; +} + template <class CIRCLENODE> loco::NodeShape use_x(const CIRCLENODE *node) { - auto x_shape = loco::shape_get(node->x()).template as<loco::TensorShape>(); + auto x_shape = luci::shape_get(node->x()).template as<loco::TensorShape>(); return loco::NodeShape{x_shape}; } template <class CIRCLENODE> loco::NodeShape use_logits(const CIRCLENODE *node) { - auto shape = loco::shape_get(node->logits()).template as<loco::TensorShape>(); + auto shape = luci::shape_get(node->logits()).template as<loco::TensorShape>(); return loco::NodeShape{shape}; } @@ -202,7 +219,7 @@ loco::NodeShape use_paddings(const CIRCLENODE *node, const luci::CircleConst *pa { const loco::DataType S32 = loco::DataType::S32; - auto input_shape = loco::shape_get(node->input()).template as<loco::TensorShape>(); + auto input_shape = luci::shape_get(node->input()).template as<loco::TensorShape>(); // TODO support other data type LUCI_ASSERT(paddings->dtype() == S32, "Only support int 32 for now"); @@ -232,11 +249,11 @@ loco::NodeShape use_paddings(const CIRCLENODE *node, const luci::CircleConst *pa loco::NodeShape infer_add_n(const luci::CircleAddN *node) { - auto shape = loco::shape_get(node->inputs(0)).as<loco::TensorShape>(); + auto shape = luci::shape_get(node->inputs(0)).as<loco::TensorShape>(); for (uint32_t idx = 1; idx < node->arity(); ++idx) { - auto shape_idx = loco::shape_get(node->inputs(idx)).as<loco::TensorShape>(); + auto shape_idx = luci::shape_get(node->inputs(idx)).as<loco::TensorShape>(); if (!(shape == shape_idx)) { INTERNAL_EXN_V("ADD_N shape not same as the first input: ", idx); @@ -247,8 +264,8 @@ loco::NodeShape infer_add_n(const luci::CircleAddN *node) loco::NodeShape infer_arg_max(const luci::CircleArgMax *node) { - auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>(); - auto dimension_shape = loco::shape_get(node->dimension()).as<loco::TensorShape>(); + auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>(); + auto dimension_shape = luci::shape_get(node->dimension()).as<loco::TensorShape>(); int64_t select_axis = 0; { @@ -286,8 +303,8 @@ loco::NodeShape infer_arg_max(const luci::CircleArgMax *node) loco::NodeShape infer_arg_min(const luci::CircleArgMin *node) { - auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>(); - auto dimension_shape = loco::shape_get(node->dimension()).as<loco::TensorShape>(); + auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>(); + auto dimension_shape = luci::shape_get(node->dimension()).as<loco::TensorShape>(); int64_t select_axis = 0; { @@ -326,9 +343,7 @@ loco::NodeShape infer_arg_min(const luci::CircleArgMin *node) // Call this for CircleAvgPool2D and CircleMaxPool2D only template <class Pool2DType> loco::NodeShape infer_pool_2d_shape(const Pool2DType *node) { - LUCI_ASSERT(loco::shape_known(node->value()), "Shape must be known"); - - auto ifm_shape = loco::shape_get(node->value()).template as<loco::TensorShape>(); + auto ifm_shape = luci::shape_get(node->value()).template as<loco::TensorShape>(); assert(ifm_shape.rank() == 4); uint32_t input_height = ifm_shape.dim(1).value(); @@ -372,7 +387,7 @@ loco::NodeShape infer_batch_to_space_nd(const luci::CircleBatchToSpaceND *node) { const loco::DataType S32 = loco::DataType::S32; - auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>(); + auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>(); // Support only input rank is 3 and 4 assert(input_shape.rank() == 3 || input_shape.rank() == 4); @@ -384,8 +399,8 @@ loco::NodeShape infer_batch_to_space_nd(const luci::CircleBatchToSpaceND *node) auto const_crops = loco::must_cast<luci::CircleConst *>(node->crops()); LUCI_ASSERT(const_crops->dtype() == loco::DataType::S32, "Only support int32 crops"); - auto const_block_shape_shape = loco::shape_get(const_block_shape).as<loco::TensorShape>(); - auto const_crops_shape = loco::shape_get(const_crops).as<loco::TensorShape>(); + auto const_block_shape_shape = luci::shape_get(const_block_shape).as<loco::TensorShape>(); + auto const_crops_shape = luci::shape_get(const_crops).as<loco::TensorShape>(); assert(const_block_shape_shape.rank() == 1); assert(const_crops_shape.rank() == 2); @@ -423,8 +438,8 @@ struct OutputSize template <class Conv2DType> OutputSize infer_conv2d_type(const Conv2DType *node) { - auto ifm_shape = loco::shape_get(node->input()).template as<loco::TensorShape>(); - auto ker_shape = loco::shape_get(node->filter()).template as<loco::TensorShape>(); + auto ifm_shape = luci::shape_get(node->input()).template as<loco::TensorShape>(); + auto ker_shape = luci::shape_get(node->filter()).template as<loco::TensorShape>(); assert(ifm_shape.rank() == 4); assert(ker_shape.rank() == 4); @@ -496,7 +511,7 @@ loco::NodeShape infer_batchmatmul_shape(const loco::TensorShape &x_shape, loco::Dimension y_lhs = adj_y ? y_shape.dim(y_rank - 1) : y_shape.dim(y_rank - 2); loco::Dimension y_rhs = adj_y ? y_shape.dim(y_rank - 2) : y_shape.dim(y_rank - 1); - if (not(x_rhs == y_lhs)) + if (x_rhs.known() && y_lhs.known() && not(x_rhs == y_lhs)) INTERNAL_EXN("x_rhs and y_lhs should be same"); uint32_t out_rank = output_shape.rank(); @@ -511,7 +526,7 @@ loco::NodeShape infer_concatenation(const luci::CircleConcatenation *node) // TODO Support when CircleConcatenation has 0 input assert(node->numValues() > 0); - auto first_shape = loco::shape_get(node->values(0)).as<loco::TensorShape>(); + auto first_shape = luci::shape_get(node->values(0)).as<loco::TensorShape>(); auto axis = node->axis(); if (axis < 0) axis += first_shape.rank(); @@ -527,14 +542,20 @@ loco::NodeShape infer_concatenation(const luci::CircleConcatenation *node) for (uint32_t i = 1; i < node->numValues(); ++i) { - auto input_shape = loco::shape_get(node->values(i)).as<loco::TensorShape>(); + auto input_shape = luci::shape_get(node->values(i)).as<loco::TensorShape>(); for (uint32_t j = 0; j < output_shape.rank(); ++j) { if (j == static_cast<uint32_t>(axis)) + { + // If dimension is unknown, value() will return 0. + // This is wrong but until new inference algorithm is implemented, + // this code will not be modified to keep compatibility. output_shape.dim(j) = output_shape.dim(j).value() + input_shape.dim(j).value(); + } else - assert(output_shape.dim(j) == input_shape.dim(j)); + assert(!output_shape.dim(j).known() || !input_shape.dim(j).known() || + output_shape.dim(j) == input_shape.dim(j)); } } @@ -545,8 +566,8 @@ loco::NodeShape infer_conv2d(const luci::CircleConv2D *node) { LOGGER(l); - auto ifm_shape = loco::shape_get(node->input()).as<loco::TensorShape>(); // in NHWC - auto ker_shape = loco::shape_get(node->filter()).as<loco::TensorShape>(); // in OHWI + auto ifm_shape = luci::shape_get(node->input()).as<loco::TensorShape>(); // in NHWC + auto ker_shape = luci::shape_get(node->filter()).as<loco::TensorShape>(); // in OHWI INFO(l) << "[luci] CircleConv2D ShapeInf ifm(" << ifm_shape.rank() << ") ker(" << ker_shape.rank() << ")" << std::endl; @@ -569,7 +590,7 @@ loco::NodeShape infer_conv2d(const luci::CircleConv2D *node) loco::NodeShape infer_depth_to_space(const luci::CircleDepthToSpace *node) { - auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>(); + auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>(); LUCI_ASSERT(input_shape.rank() == 4, "Only input rank 4 is supported"); // Only data format NHWC is supported @@ -601,12 +622,13 @@ loco::NodeShape infer_depth_to_space(const luci::CircleDepthToSpace *node) loco::NodeShape infer_depthwise_conv2d(const luci::CircleDepthwiseConv2D *node) { - auto ifm_shape = loco::shape_get(node->input()).as<loco::TensorShape>(); // in NHWC - auto ker_shape = loco::shape_get(node->filter()).as<loco::TensorShape>(); // in 1 H W CM + auto ifm_shape = luci::shape_get(node->input()).as<loco::TensorShape>(); // in NHWC + auto ker_shape = luci::shape_get(node->filter()).as<loco::TensorShape>(); // in 1 H W CM assert(ifm_shape.rank() == 4); assert(ker_shape.rank() == 4); assert(ker_shape.dim(0).value() == 1); + assert(ifm_shape.dim(3).value() * node->depthMultiplier() == ker_shape.dim(3).value()); auto os = infer_conv2d_type(node); @@ -623,7 +645,7 @@ loco::NodeShape infer_depthwise_conv2d(const luci::CircleDepthwiseConv2D *node) loco::NodeShape infer_expand_dims(const luci::CircleExpandDims *node) { const loco::DataType S32 = loco::DataType::S32; - auto x_shape = loco::shape_get(node->input()).as<loco::TensorShape>(); + auto x_shape = luci::shape_get(node->input()).as<loco::TensorShape>(); if (x_shape.rank() == 0) { // This maybe for unknown shape. We use shape from the node itself. @@ -637,7 +659,7 @@ loco::NodeShape infer_expand_dims(const luci::CircleExpandDims *node) } int32_t axis = const_axis->at<S32>(0); LUCI_ASSERT((axis <= static_cast<int32_t>(x_shape.rank())) && - (axis >= -1 - static_cast<int32_t>(x_shape.rank())), + (axis >= -1 - static_cast<int32_t>(x_shape.rank())), "Axis has to be between [-(D+1), D], where D is rank of input."); size_t positive_axis = axis < 0 ? x_shape.rank() + axis + 1 : axis; loco::TensorShape output_shape; @@ -684,8 +706,8 @@ loco::NodeShape infer_fill(const luci::CircleFill *node) loco::NodeShape infer_fully_connected(const luci::CircleFullyConnected *node) { - auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>(); - auto weights_shape = loco::shape_get(node->weights()).as<loco::TensorShape>(); + auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>(); + auto weights_shape = luci::shape_get(node->weights()).as<loco::TensorShape>(); // Checking shape capability for fully connected layer // Input: a tensor of at least rank 2 [D1, D2, ... Dn] @@ -715,8 +737,8 @@ loco::NodeShape infer_gather(const luci::CircleGather *node) { loco::TensorShape output_shape; - const auto input_shape = loco::shape_get(node->params()).as<loco::TensorShape>(); - const auto positions_shape = loco::shape_get(node->indices()).as<loco::TensorShape>(); + const auto input_shape = luci::shape_get(node->params()).as<loco::TensorShape>(); + const auto positions_shape = luci::shape_get(node->indices()).as<loco::TensorShape>(); int32_t axis = node->axis(); // If CircleGather input has a dynamic shape, it can't inference this shape. So, it returns the @@ -743,8 +765,8 @@ loco::NodeShape infer_gather_nd(const luci::CircleGatherNd *node) { loco::TensorShape output_shape; - const auto params_shape = loco::shape_get(node->params()).as<loco::TensorShape>(); - const auto indices_shape = loco::shape_get(node->indices()).as<loco::TensorShape>(); + const auto params_shape = luci::shape_get(node->params()).as<loco::TensorShape>(); + const auto indices_shape = luci::shape_get(node->indices()).as<loco::TensorShape>(); const auto params_rank = params_shape.rank(); const auto indices_rank = indices_shape.rank(); @@ -791,7 +813,7 @@ loco::NodeShape infer_matrix_diag(const luci::CircleMatrixDiag *node) { loco::TensorShape output_shape; - auto diagonal_shape = loco::shape_get(node->diagonal()).as<loco::TensorShape>(); + auto diagonal_shape = luci::shape_get(node->diagonal()).as<loco::TensorShape>(); auto rank = diagonal_shape.rank(); output_shape.rank(rank + 1); @@ -808,8 +830,8 @@ loco::NodeShape infer_matrix_diag(const luci::CircleMatrixDiag *node) loco::NodeShape infer_matrix_set_diag(const luci::CircleMatrixSetDiag *node) { - auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>(); - auto diagonal_shape = loco::shape_get(node->diagonal()).as<loco::TensorShape>(); + auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>(); + auto diagonal_shape = luci::shape_get(node->diagonal()).as<loco::TensorShape>(); auto rank = diagonal_shape.rank(); @@ -831,7 +853,7 @@ loco::TensorShape infer_reducer(const loco::Node *input, const loco::Node *indic { const loco::DataType S32 = loco::DataType::S32; - auto input_shape = loco::shape_get(input).as<loco::TensorShape>(); + auto input_shape = luci::shape_get(input).as<loco::TensorShape>(); auto reduction_indices = loco::must_cast<const luci::CircleConst *>(indices); { // Exceptions @@ -892,7 +914,7 @@ loco::NodeShape infer_mirror_pad(const luci::CircleMirrorPad *node) loco::NodeShape infer_one_hot(const luci::CircleOneHot *node) { const loco::DataType S32 = loco::DataType::S32; - auto indices_shape = loco::shape_get(node->indices()).as<loco::TensorShape>(); + auto indices_shape = luci::shape_get(node->indices()).as<loco::TensorShape>(); // Only support OneHot node's depth() is CircleConst with type S32 // TODO support depth with other types auto depth = loco::must_cast<luci::CircleConst *>(node->depth()); @@ -925,11 +947,11 @@ loco::NodeShape infer_pack(const luci::CirclePack *node) { LUCI_ASSERT(node->values_count() > 0, "Only support one or more inputs"); - auto first_shape = loco::shape_get(node->values(0)).as<loco::TensorShape>(); + auto first_shape = luci::shape_get(node->values(0)).as<loco::TensorShape>(); // Make sure all inputs have the same shape. for (uint32_t i = 1; i < node->values_count(); ++i) { - auto in_shape = loco::shape_get(node->values(i)).as<loco::TensorShape>(); + auto in_shape = luci::shape_get(node->values(i)).as<loco::TensorShape>(); LUCI_ASSERT(loco::NodeShape{first_shape} == loco::NodeShape{in_shape}, "All inputs must have the same shape"); } @@ -985,8 +1007,8 @@ loco::NodeShape infer_pad_v2(const luci::CirclePadV2 *node) loco::NodeShape infer_p_relu(const luci::CirclePRelu *node) { - auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>(); - auto alpha_shape = loco::shape_get(node->alpha()).as<loco::TensorShape>(); + auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>(); + auto alpha_shape = luci::shape_get(node->alpha()).as<loco::TensorShape>(); auto output_shape = broadcast_shape(input_shape, alpha_shape); @@ -1087,10 +1109,12 @@ loco::NodeShape infer_reshape(const luci::CircleReshape *node) loco::TensorShape output_shape = shape_by_input; // One of the dimensions can have special value -1, meaning its actual value should be inferred. - const auto input_shape = loco::shape_get(node->tensor()).as<loco::TensorShape>(); - const uint32_t input_element_count = loco::element_count(&input_shape); + const auto input_shape = luci::shape_get(node->tensor()).as<loco::TensorShape>(); + uint32_t input_element_count = 1; uint32_t output_element_count = 1; uint32_t unknown_dim_index = UINT32_MAX; + for (uint32_t i = 0; i < input_shape.rank(); ++i) + input_element_count *= (input_shape.dim(i).known() ? input_shape.dim(i).value() : 1); for (uint32_t dim_index = 0; dim_index < output_shape.rank(); ++dim_index) { const uint32_t dim_value = output_shape.dim(dim_index).value(); @@ -1114,7 +1138,7 @@ loco::NodeShape infer_reshape(const luci::CircleReshape *node) loco::NodeShape infer_resize_bilinear(const luci::CircleResizeBilinear *node) { - auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>(); + auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>(); if (input_shape.rank() != 4) INTERNAL_EXN("Expected ResizeBilinear input to have rank 4"); @@ -1142,7 +1166,7 @@ loco::NodeShape infer_resize_bilinear(const luci::CircleResizeBilinear *node) loco::NodeShape infer_resize_nearest_neighbor(const luci::CircleResizeNearestNeighbor *node) { - auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>(); + auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>(); if (input_shape.rank() != 4) INTERNAL_EXN("Expected ResizeNearesNeighbor input to have rank 4"); @@ -1195,8 +1219,8 @@ loco::NodeShape infer_scatter_nd(const luci::CircleScatterNd *node) loco::NodeShape infer_segment_sum(const luci::CircleSegmentSum *node) { - auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>(); - auto segment_shape = loco::shape_get(node->segment_ids()).as<loco::TensorShape>(); + auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>(); + auto segment_shape = luci::shape_get(node->segment_ids()).as<loco::TensorShape>(); LUCI_ASSERT(segment_shape.rank() == 1, "segment_ids must be 1-D tensor"); LUCI_ASSERT(segment_shape.dim(0).value() == input_shape.dim(0).value(), @@ -1226,11 +1250,11 @@ loco::NodeShape infer_segment_sum(const luci::CircleSegmentSum *node) loco::NodeShape infer_select(const luci::CircleSelect *node) { - auto t_shape = loco::shape_get(node->t()).as<loco::TensorShape>(); - assert(t_shape == loco::shape_get(node->e()).as<loco::TensorShape>()); + auto t_shape = luci::shape_get(node->t()).as<loco::TensorShape>(); + assert(t_shape == luci::shape_get(node->e()).as<loco::TensorShape>()); // condition shape validation - auto c_shape = loco::shape_get(node->condition()).as<loco::TensorShape>(); + auto c_shape = luci::shape_get(node->condition()).as<loco::TensorShape>(); if (c_shape.rank() != t_shape.rank()) { if (c_shape.rank() != 0 && c_shape.rank() != 1) @@ -1248,9 +1272,9 @@ loco::NodeShape infer_select(const luci::CircleSelect *node) loco::NodeShape infer_select_v2(const luci::CircleSelectV2 *node) { - auto c_shape = loco::shape_get(node->condition()).as<loco::TensorShape>(); - auto t_shape = loco::shape_get(node->t()).as<loco::TensorShape>(); - auto e_shape = loco::shape_get(node->e()).as<loco::TensorShape>(); + auto c_shape = luci::shape_get(node->condition()).as<loco::TensorShape>(); + auto t_shape = luci::shape_get(node->t()).as<loco::TensorShape>(); + auto e_shape = luci::shape_get(node->e()).as<loco::TensorShape>(); // validate ability to broadcast shapes to each other auto b_shape = broadcast_shape(broadcast_shape(c_shape, t_shape), e_shape); @@ -1259,7 +1283,7 @@ loco::NodeShape infer_select_v2(const luci::CircleSelectV2 *node) loco::NodeShape infer_shape(const luci::CircleShape *node) { - auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>(); + auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>(); loco::TensorShape output_shape; @@ -1274,7 +1298,7 @@ loco::NodeShape infer_slice(const luci::CircleSlice *node) const loco::DataType S32 = loco::DataType::S32; const loco::DataType S64 = loco::DataType::S64; - auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>(); + auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>(); auto const_begin = loco::must_cast<luci::CircleConst *>(node->begin()); auto const_size = loco::must_cast<luci::CircleConst *>(node->size()); @@ -1318,7 +1342,7 @@ loco::NodeShape infer_space_to_batch_nd(const luci::CircleSpaceToBatchND *node) { const loco::DataType S32 = loco::DataType::S32; - auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>(); + auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>(); // Support only input rank is 3 and 4 assert(input_shape.rank() == 3 || input_shape.rank() == 4); @@ -1330,8 +1354,8 @@ loco::NodeShape infer_space_to_batch_nd(const luci::CircleSpaceToBatchND *node) auto const_paddings = loco::must_cast<luci::CircleConst *>(node->paddings()); LUCI_ASSERT(const_paddings->dtype() == S32, "Only support int32 paddings"); - auto const_block_shape_shape = loco::shape_get(const_block_shape).as<loco::TensorShape>(); - auto const_paddings_shape = loco::shape_get(const_paddings).as<loco::TensorShape>(); + auto const_block_shape_shape = luci::shape_get(const_block_shape).as<loco::TensorShape>(); + auto const_paddings_shape = luci::shape_get(const_paddings).as<loco::TensorShape>(); assert(const_block_shape_shape.rank() == 1); assert(const_paddings_shape.rank() == 2); @@ -1374,7 +1398,7 @@ loco::NodeShape infer_space_to_batch_nd(const luci::CircleSpaceToBatchND *node) loco::NodeShape infer_space_to_depth(const luci::CircleSpaceToDepth *node) { - auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>(); + auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>(); LUCI_ASSERT(input_shape.rank() == 4, "Only input rank 4 is supported"); // Only data format NHWC is supported @@ -1412,19 +1436,33 @@ loco::NodeShape infer_sparse_to_dense(const luci::CircleSparseToDense *node) auto output_shape_node = dynamic_cast<luci::CircleConst *>(node->output_shape()); if (output_shape_node != nullptr) { - // Only support node with S32 - LUCI_ASSERT(output_shape_node->dtype() == loco::DataType::S32, - "Only support int32 CircleConst"); + const auto output_shape_type = output_shape_node->dtype(); if (output_shape_node->rank() != 1) INTERNAL_EXN_V("Only support rank 1 CircleConst", oops::to_uint32(output_shape_node->rank())); - shape.rank(output_shape_node->size<loco::DataType::S32>()); + if (output_shape_type == loco::DataType::S32) + { + shape.rank(output_shape_node->size<loco::DataType::S32>()); - for (uint32_t axis = 0; axis < shape.rank(); ++axis) + for (uint32_t axis = 0; axis < shape.rank(); ++axis) + { + shape.dim(axis) = output_shape_node->at<loco::DataType::S32>(axis); + } + } + else if (output_shape_type == loco::DataType::S64) { - shape.dim(axis) = output_shape_node->at<loco::DataType::S32>(axis); + shape.rank(output_shape_node->size<loco::DataType::S64>()); + + for (uint32_t axis = 0; axis < shape.rank(); ++axis) + { + shape.dim(axis) = output_shape_node->at<loco::DataType::S64>(axis); + } + } + else + { + INTERNAL_EXN("Output shape of SparseToDense must be either int32 or int64"); } } else @@ -1453,7 +1491,7 @@ loco::NodeShape infer_strided_slice(const luci::CircleStridedSlice *node) loco::NodeShape infer_squeeze(const luci::CircleSqueeze *node) { - auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>(); + auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>(); // TODO input shape may be unknown before runtime std::vector<bool> do_squeeze(input_shape.rank(), false); @@ -1508,7 +1546,7 @@ loco::NodeShape infer_tile(const luci::CircleTile *node) { const loco::DataType S32 = loco::DataType::S32; - auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>(); + auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>(); auto multiples = loco::must_cast<luci::CircleConst *>(node->multiples()); // TODO support non-const case @@ -1534,7 +1572,7 @@ loco::NodeShape infer_tile(const luci::CircleTile *node) loco::NodeShape infer_transpose(const luci::CircleTranspose *node) { - auto input_shape = loco::shape_get(node->a()).as<loco::TensorShape>(); + auto input_shape = luci::shape_get(node->a()).as<loco::TensorShape>(); auto perm_node = loco::must_cast<luci::CircleConst *>(node->perm()); @@ -1576,7 +1614,7 @@ loco::NodeShape infer_unpack(const luci::CircleUnpack *node) // CircleUnpack provides list(array) of Tensors which has one less dimension of the input // We'll set shape of CircleUnpack to shape of actual outputs // TODO fix this if any problem rises - auto value_shape = loco::shape_get(node->value()).as<loco::TensorShape>(); + auto value_shape = luci::shape_get(node->value()).as<loco::TensorShape>(); auto axis = node->axis(); auto num = node->num(); @@ -1610,9 +1648,9 @@ loco::NodeShape infer_unpack(const luci::CircleUnpack *node) loco::NodeShape infer_unidirectionalsequencelstm(const luci::CircleUnidirectionalSequenceLSTM *node) { - auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>(); + auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>(); auto recurrent_to_output_weights = - loco::shape_get(node->recurrent_to_output_weights()).as<loco::TensorShape>(); + luci::shape_get(node->recurrent_to_output_weights()).as<loco::TensorShape>(); auto rank = input_shape.rank(); loco::TensorShape output_shape; output_shape.rank(rank); @@ -1626,7 +1664,7 @@ loco::NodeShape infer_unidirectionalsequencelstm(const luci::CircleUnidirectiona loco::NodeShape infer_unique(const luci::CircleUnique *node) { - auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>(); + auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>(); assert(input_shape.rank() == 1); @@ -1641,7 +1679,7 @@ loco::NodeShape infer_bcq_fully_connected(const luci::CircleBCQFullyConnected *n { loco::TensorShape out_shape; - auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>(); + auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>(); auto weights_clusters = loco::must_cast<luci::CircleConst *>(node->weights_clusters()); LUCI_ASSERT(input_shape.rank() == 2, "Input rank of BCQFullyConnected should be 2"); @@ -1664,8 +1702,8 @@ loco::NodeShape infer_bcq_gather(const luci::CircleBCQGather *node) loco::TensorShape input_shape; loco::TensorShape output_shape; - const auto input_binary_shape = loco::shape_get(node->input_binary()).as<loco::TensorShape>(); - const auto indices_shape = loco::shape_get(node->indices()).as<loco::TensorShape>(); + const auto input_binary_shape = luci::shape_get(node->input_binary()).as<loco::TensorShape>(); + const auto indices_shape = luci::shape_get(node->indices()).as<loco::TensorShape>(); auto axis = node->axis(); auto input_clusters = loco::must_cast<luci::CircleConst *>(node->input_clusters()); @@ -1712,46 +1750,6 @@ loco::NodeShape infer_output(const luci::CircleOutput *node) return loco::NodeShape{*output_shape}; } -loco::NodeShape infer_if_out(const luci::CircleIfOut *node) -{ - /** - * @note IF operator type and shape are that of the "then" and "else" - * Graph Outputs. - */ - auto circle_if = dynamic_cast<const luci::CircleIf *>(node->input()); - if (circle_if == nullptr) - { - INTERNAL_EXN("CircleIf IR is not configured correctly"); - } - - auto index = node->index(); - auto then_graph = circle_if->then_graph(); - auto else_graph = circle_if->else_graph(); - assert(then_graph != nullptr); - assert(else_graph != nullptr); - - // shape and type are assumed to be same - // these are checked at post_import_graph() in Import - auto then_outputs = loco::output_nodes(then_graph); - auto else_outputs = loco::output_nodes(else_graph); - assert(then_outputs.size() == else_outputs.size()); - assert(index < static_cast<int32_t>(then_outputs.size())); - - auto then_out = loco::must_cast<luci::CircleOutput *>(then_outputs.at(index)); - auto else_out = loco::must_cast<luci::CircleOutput *>(else_outputs.at(index)); - - auto then_graph_outputs = then_graph->outputs(); // loco::GraphOutput items - auto else_graph_outputs = else_graph->outputs(); - assert(then_graph_outputs->size() == else_graph_outputs->size()); - - auto then_graph_output = then_graph_outputs->at(then_out->index()); - auto else_graph_output = else_graph_outputs->at(else_out->index()); - (void)else_graph_output; // make compiler happy for unused variable warnings - assert(*then_graph_output->shape() == *else_graph_output->shape()); - - return loco::NodeShape{*then_graph_output->shape()}; -} - loco::NodeShape infer_non_max_suppression_v4_out(const luci::CircleNonMaxSuppressionV4Out *node) { const loco::DataType S32 = loco::DataType::S32; @@ -1818,7 +1816,7 @@ loco::NodeShape infer_split_out(const luci::CircleSplitOut *node) loco::NodeShape unknown; - auto split_shape = loco::shape_get(split).as<loco::TensorShape>(); + auto split_shape = luci::shape_get(split).as<loco::TensorShape>(); auto split_dim = dynamic_cast<const luci::CircleConst *>(split->split_dim()); if (split_dim == nullptr) @@ -1852,7 +1850,7 @@ loco::NodeShape infer_split_v_out(const luci::CircleSplitVOut *node) loco::NodeShape unknown; - auto split_shape = loco::shape_get(split).as<loco::TensorShape>(); + auto split_shape = luci::shape_get(split).as<loco::TensorShape>(); auto size_splits = dynamic_cast<const luci::CircleConst *>(split->size_splits()); if (size_splits == nullptr) @@ -1913,7 +1911,7 @@ loco::NodeShape infer_top_k_v2_out(const luci::CircleTopKV2Out *node) INTERNAL_EXN("CircleSplit IR is not configured correctly"); // shape of topkv2 is same as topkv2->input() - auto input_shape = loco::shape_get(topkv2).as<loco::TensorShape>(); + auto input_shape = luci::shape_get(topkv2).as<loco::TensorShape>(); auto node_k = loco::must_cast<const luci::CircleConst *>(topkv2->k()); LUCI_ASSERT(node_k->dtype() == S32, "Only support Int32"); @@ -1940,7 +1938,7 @@ loco::NodeShape infer_unique_out(const luci::CircleUniqueOut *node) } assert(node->index() == 1); auto unique = loco::must_cast<luci::CircleUnique *>(node->input()); - auto unique_shape = loco::shape_get(unique->input()).as<loco::TensorShape>(); + auto unique_shape = luci::shape_get(unique->input()).as<loco::TensorShape>(); assert(unique_shape.rank() == 1); @@ -1958,7 +1956,7 @@ loco::NodeShape infer_unpack_out(const luci::CircleUnpackOut *node) INTERNAL_EXN("CircleUnpack IR is not configured correctly"); } - auto unpack_shape = loco::shape_get(unpack).as<loco::TensorShape>(); + auto unpack_shape = luci::shape_get(unpack).as<loco::TensorShape>(); return loco::NodeShape{unpack_shape}; } @@ -2025,8 +2023,8 @@ public: loco::NodeShape visit(const luci::CircleBatchMatMul *node) final { - auto x_shape = loco::shape_get(node->x()).as<loco::TensorShape>(); - auto y_shape = loco::shape_get(node->y()).as<loco::TensorShape>(); + auto x_shape = luci::shape_get(node->x()).as<loco::TensorShape>(); + auto y_shape = luci::shape_get(node->y()).as<loco::TensorShape>(); return infer_batchmatmul_shape(x_shape, y_shape, node->adj_x(), node->adj_y()); } @@ -2065,7 +2063,7 @@ public: loco::NodeShape visit(const luci::CircleDequantize *node) final { - const auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>(); + const auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>(); return loco::NodeShape{input_shape}; } @@ -2073,7 +2071,7 @@ public: loco::NodeShape visit(const luci::CircleElu *node) final { - auto input_shape = loco::shape_get(node->features()).as<loco::TensorShape>(); + auto input_shape = luci::shape_get(node->features()).as<loco::TensorShape>(); return loco::NodeShape{input_shape}; } @@ -2087,6 +2085,8 @@ public: return infer_expand_dims(node); } + loco::NodeShape visit(const luci::CircleFakeQuant *node) final { return use_inputs(node); } + loco::NodeShape visit(const luci::CircleFill *node) final { return infer_fill(node); } loco::NodeShape visit(const luci::CircleFloor *node) final { return use_x(node); } @@ -2112,7 +2112,7 @@ public: { // Shape of CircleIf is not used. Just use input 0 assert(node->input_count() > 0); - const auto input_shape = loco::shape_get(node->input(0)).as<loco::TensorShape>(); + const auto input_shape = luci::shape_get(node->input(0)).as<loco::TensorShape>(); return loco::NodeShape{input_shape}; } @@ -2125,7 +2125,7 @@ public: loco::NodeShape visit(const luci::CircleLeakyRelu *node) final { - const auto input_shape = loco::shape_get(node->features()).as<loco::TensorShape>(); + const auto input_shape = luci::shape_get(node->features()).as<loco::TensorShape>(); return loco::NodeShape{input_shape}; } @@ -2135,7 +2135,7 @@ public: loco::NodeShape visit(const luci::CircleLocalResponseNormalization *node) final { - const auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>(); + const auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>(); return loco::NodeShape{input_shape}; } @@ -2184,13 +2184,13 @@ public: loco::NodeShape visit(const luci::CircleNonMaxSuppressionV4 *node) final { - const auto boxes_shape = loco::shape_get(node->boxes()).as<loco::TensorShape>(); + const auto boxes_shape = luci::shape_get(node->boxes()).as<loco::TensorShape>(); return loco::NodeShape{boxes_shape}; } loco::NodeShape visit(const luci::CircleNonMaxSuppressionV5 *node) final { - const auto boxes_shape = loco::shape_get(node->boxes()).as<loco::TensorShape>(); + const auto boxes_shape = luci::shape_get(node->boxes()).as<loco::TensorShape>(); return loco::NodeShape{boxes_shape}; } @@ -2244,21 +2244,21 @@ public: loco::NodeShape visit(const luci::CircleRelu *node) final { - auto input_shape = loco::shape_get(node->features()).as<loco::TensorShape>(); + auto input_shape = luci::shape_get(node->features()).as<loco::TensorShape>(); return loco::NodeShape{input_shape}; } loco::NodeShape visit(const luci::CircleRelu6 *node) final { - auto input_shape = loco::shape_get(node->features()).as<loco::TensorShape>(); + auto input_shape = luci::shape_get(node->features()).as<loco::TensorShape>(); return loco::NodeShape{input_shape}; } loco::NodeShape visit(const luci::CircleReluN1To1 *node) final { - auto input_shape = loco::shape_get(node->features()).as<loco::TensorShape>(); + auto input_shape = luci::shape_get(node->features()).as<loco::TensorShape>(); return loco::NodeShape{input_shape}; } @@ -2284,7 +2284,7 @@ public: loco::NodeShape visit(const luci::CircleReverseSequence *node) final { - auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>(); + auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>(); return loco::NodeShape{input_shape}; } @@ -2293,9 +2293,9 @@ public: loco::NodeShape visit(const luci::CircleReverseV2 *node) final { - auto input_shape = loco::shape_get(node->tensor()).as<loco::TensorShape>(); + auto input_shape = luci::shape_get(node->tensor()).as<loco::TensorShape>(); - LUCI_ASSERT(loco::shape_get(node->axis()).as<loco::TensorShape>().rank() == 1, + LUCI_ASSERT(luci::shape_get(node->axis()).as<loco::TensorShape>().rank() == 1, "Tensor must be 1-D"); return loco::NodeShape{input_shape}; @@ -2340,14 +2340,14 @@ public: loco::NodeShape visit(const luci::CircleSplit *node) final { // We'll set Split output as same as input so that SplitOut can handle it's own shape - auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>(); + auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>(); return loco::NodeShape{input_shape}; } loco::NodeShape visit(const luci::CircleSplitV *node) final { // We'll set SplitV output as same as input so that SplitOut can handle it's own shape - auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>(); + auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>(); return loco::NodeShape{input_shape}; } @@ -2382,7 +2382,7 @@ public: loco::NodeShape visit(const luci::CircleTopKV2 *node) final { // set shape of this node as same as input - const auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>(); + const auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>(); return loco::NodeShape{input_shape}; } @@ -2408,13 +2408,13 @@ public: { // Shape of CircleWhile is not used. Just use input 0 assert(node->arity() > 0); - const auto input_shape = loco::shape_get(node->input(0)).as<loco::TensorShape>(); + const auto input_shape = luci::shape_get(node->input(0)).as<loco::TensorShape>(); return loco::NodeShape{input_shape}; } loco::NodeShape visit(const luci::CircleZerosLike *node) final { - auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>(); + auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>(); return loco::NodeShape{input_shape}; } @@ -2429,7 +2429,7 @@ public: loco::NodeShape visit(const luci::CircleInstanceNorm *node) final { - auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>(); + auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>(); return loco::NodeShape{input_shape}; } @@ -2445,8 +2445,6 @@ public: loco::NodeShape visit(const luci::CircleCustomOut *node) final { return use_own(node); } - loco::NodeShape visit(const luci::CircleIfOut *node) final { return infer_if_out(node); } - loco::NodeShape visit(const luci::CircleNonMaxSuppressionV4Out *node) final { return infer_non_max_suppression_v4_out(node); diff --git a/compiler/luci/service/src/CircleShapeInferenceRule.test.cpp b/compiler/luci/service/src/CircleShapeInferenceRule.test.cpp deleted file mode 100644 index ac27db3bd..000000000 --- a/compiler/luci/service/src/CircleShapeInferenceRule.test.cpp +++ /dev/null @@ -1,626 +0,0 @@ -/* - * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "TestGraph.h" -#include "luci/Service/CircleShapeInferenceRule.h" - -#include <luci/IR/CircleNodes.h> -#include <luci/IR/CircleDialect.h> - -#include <loco.h> -#include <loco/IR/CanonicalDialect.h> -#include <loco/Service/ShapeInference.h> -#include <loco/Service/CanonicalShapeInferenceRule.h> -#include <loco/Service/MultiDialectShapeInferenceRule.h> - -#include <oops/InternalExn.h> - -#include <gtest/gtest.h> - -#include <memory> - -namespace -{ - -bool shape_pass(loco::Graph *g) -{ - loco::CanonicalShapeInferenceRule canonical_rule; - luci::CircleShapeInferenceRule circle_rule; - loco::MultiDialectShapeInferenceRule rules; - - rules.bind(loco::CanonicalDialect::get(), &canonical_rule) - .bind(luci::CircleDialect::get(), &circle_rule); - - return loco::apply(&rules).to(g); -} - -} // namespace - -TEST(CircleShapeInferenceRuleTest, minimal_with_CircleRelu) -{ - // Create a simple network - luci::test::TestGraph graph; - auto relu_node = graph.append<luci::CircleRelu>(graph.input_node); - graph.complete(relu_node); - - // set shape - { - graph.input_node->rank(2); - graph.input_node->dim(0) = 3; - graph.input_node->dim(1) = 4; - - graph.output_node->rank(2); - graph.output_node->dim(0) = 3; - graph.output_node->dim(1) = 4; - - luci::test::graph_input_shape(graph.input_node); - luci::test::graph_output_shape(graph.output_node); - } - - // pre-check - ASSERT_FALSE(loco::shape_known(relu_node)); - - // shape inference - while (shape_pass(graph.graph()) == true) - ; - - // Verify - { - ASSERT_TRUE(loco::shape_known(relu_node)); - ASSERT_EQ(loco::Domain::Tensor, loco::shape_get(relu_node).domain()); - - auto shape = loco::shape_get(relu_node).as<loco::TensorShape>(); - ASSERT_EQ(2, shape.rank()); - ASSERT_EQ(3, shape.dim(0)); - ASSERT_EQ(4, shape.dim(1)); - } -} - -// based on the case shown in -// https://www.corvil.com/kb/what-is-the-difference-between-same-and-valid-padding-in-tf-nn-max-pool-of-tensorflow -TEST(CircleShapeInferenceRuleTest, avgpool2d_valid) -{ - luci::test::TestGraph graph; - auto avg_node = graph.append<luci::CircleAveragePool2D>(graph.input_node); - graph.complete(); - - auto input_node = graph.input_node; - { - input_node->shape({1, 4, 3, 1}); - luci::test::graph_input_shape(input_node); - } - auto output_node = graph.output_node; - { - output_node->shape({1, 2, 1, 1}); - luci::test::graph_output_shape(output_node); - } - // setting CircleAveragePool2D - { - avg_node->filter()->h(2); - avg_node->filter()->w(2); - avg_node->stride()->h(2); - avg_node->stride()->w(2); - avg_node->fusedActivationFunction(luci::FusedActFunc::NONE); - avg_node->padding(luci::Padding::VALID); - } - ASSERT_FALSE(loco::shape_known(avg_node)); - - // shape inference - while (shape_pass(graph.graph()) == true) - ; - - // Verify - { - ASSERT_TRUE(loco::shape_known(avg_node)); - ASSERT_EQ(loco::Domain::Tensor, loco::shape_get(avg_node).domain()); - - auto shape = loco::shape_get(avg_node).as<loco::TensorShape>(); - ASSERT_EQ(4, shape.rank()); - ASSERT_EQ(1, shape.dim(0).value()); - ASSERT_EQ(2, shape.dim(1).value()); - ASSERT_EQ(1, shape.dim(2).value()); - ASSERT_EQ(1, shape.dim(3).value()); - } -} - -TEST(CircleShapeInferenceRuleTest, avgpool2d_same) -{ - luci::test::TestGraph graph; - auto avg_node = graph.append<luci::CircleAveragePool2D>(graph.input_node); - graph.complete(); - - auto input_node = graph.input_node; - { - input_node->shape({1, 4, 3, 1}); - luci::test::graph_input_shape(input_node); - } - auto output_node = graph.output_node; - { - output_node->shape({1, 2, 2, 1}); - luci::test::graph_output_shape(output_node); - } - - // setting CircleAveragePool2D - { - avg_node->filter()->h(2); - avg_node->filter()->w(2); - avg_node->stride()->h(2); - avg_node->stride()->w(2); - avg_node->fusedActivationFunction(luci::FusedActFunc::NONE); - avg_node->padding(luci::Padding::SAME); - } - - ASSERT_FALSE(loco::shape_known(avg_node)); - - // shape inference - while (shape_pass(graph.graph()) == true) - ; - - // Verify - { - ASSERT_TRUE(loco::shape_known(avg_node)); - ASSERT_EQ(loco::Domain::Tensor, loco::shape_get(avg_node).domain()); - - auto shape = loco::shape_get(avg_node).as<loco::TensorShape>(); - ASSERT_EQ(4, shape.rank()); - ASSERT_EQ(1, shape.dim(0).value()); - ASSERT_EQ(2, shape.dim(1).value()); - ASSERT_EQ(2, shape.dim(2).value()); - ASSERT_EQ(1, shape.dim(3).value()); - } -} - -/** - * @note Function to test: Shape inference of two different input shapes - * - * Rank expansion to higher input side - * x(2,1,5) + y(3,5) --> x(2,1,5) + y(1,3,5) - * Do output shape inference like numpy - * x(2,1,5) + y(1,3,5) --> output(2,3,5) - * For each axis, dim value should be same OR one of them should be 1 - */ -TEST(CircleShapeInferenceRuleTest, TFAdd_shapeinf_different) -{ - auto g = loco::make_graph(); - - auto x_node = g->nodes()->create<luci::CircleInput>(); - { - x_node->rank(3); - x_node->dim(0) = 2; - x_node->dim(1) = 1; - x_node->dim(2) = 5; - } - auto y_node = g->nodes()->create<luci::CircleInput>(); - { - y_node->rank(2); - y_node->dim(0) = 3; - y_node->dim(1) = 5; - } - auto add_node = g->nodes()->create<luci::CircleAdd>(); - { - add_node->x(x_node); - add_node->y(y_node); - } - auto output_node = g->nodes()->create<luci::CircleOutput>(); - { - output_node->from(add_node); - } - - auto x_input = g->inputs()->create(); - { - x_input->name("x"); - luci::link(x_input, x_node); - } - auto y_input = g->inputs()->create(); - { - y_input->name("y"); - luci::link(y_input, y_node); - } - auto output = g->outputs()->create(); - { - output->name("output"); - luci::link(output, output_node); - } - - luci::test::graph_input_shape(x_node); - luci::test::graph_input_shape(y_node); - luci::test::graph_output_shape(output_node); - - // pre-check - ASSERT_FALSE(loco::shape_known(add_node)); - - // shape inference - while (shape_pass(g.get()) == true) - ; - - // Verify - { - ASSERT_TRUE(loco::shape_known(add_node)); - ASSERT_EQ(loco::Domain::Tensor, loco::shape_get(add_node).domain()); - - auto shape = loco::shape_get(add_node).as<loco::TensorShape>(); - ASSERT_EQ(3, shape.rank()); - ASSERT_EQ(2, shape.dim(0)); - ASSERT_EQ(3, shape.dim(1)); - ASSERT_EQ(5, shape.dim(2)); - } -} - -TEST(CircleShapeInferenceRuleTest, CircleTranspose_simple) -{ - luci::test::ExampleGraph<luci::test::ExampleGraphType::CircleTranspose> g; - - g.input_node->rank(3); - g.input_node->dim(0) = 3; - g.input_node->dim(1) = 8; - g.input_node->dim(2) = 1; - - g.const_perm->dtype(loco::DataType::S32); - g.const_perm->rank(1); - g.const_perm->dim(0) = 3; - g.const_perm->size<loco::DataType::S32>(3); - g.const_perm->at<loco::DataType::S32>(0) = 1; - g.const_perm->at<loco::DataType::S32>(1) = 2; - g.const_perm->at<loco::DataType::S32>(2) = 0; - - luci::test::graph_input_shape(g.input_node); - luci::test::graph_output_shape(g.output_node); - - // pre-check - ASSERT_FALSE(loco::shape_known(g.transpose_node)); - - // shape inference - while (shape_pass(g.graph()) == true) - ; - - // Verify - { - ASSERT_TRUE(loco::shape_known(g.transpose_node)); - - auto shape = loco::shape_get(g.transpose_node).as<loco::TensorShape>(); - ASSERT_EQ(3, shape.rank()); - ASSERT_EQ(8, shape.dim(0)); - ASSERT_EQ(1, shape.dim(1)); - ASSERT_EQ(3, shape.dim(2)); - } -} - -TEST(CircleShapeInferenceRuleTest, CircleSqueeze) -{ - luci::test::TestGraph graph; - auto squeeze_node = graph.append<luci::CircleSqueeze>(graph.input_node); - graph.complete(); - - auto input_node = graph.input_node; - { - input_node->shape({1, 4, 3, 1}); - } - auto output_node = graph.output_node; - { - output_node->shape({4, 3, 1}); - } - - luci::test::graph_input_shape(input_node); - luci::test::graph_output_shape(output_node); - - squeeze_node->squeeze_dims({0}); - - // pre-check - ASSERT_FALSE(loco::shape_known(squeeze_node)); - - // shape inference - while (shape_pass(graph.graph()) == true) - ; - - // Verify - { - ASSERT_TRUE(loco::shape_known(squeeze_node)); - - auto shape = loco::shape_get(squeeze_node).as<loco::TensorShape>(); - ASSERT_EQ(3, shape.rank()); - ASSERT_EQ(4, shape.dim(0)); - ASSERT_EQ(3, shape.dim(1)); - ASSERT_EQ(1, shape.dim(2)); - } -} - -TEST(CircleShapeInferenceRuleTest, CircleExpandDims) -{ - luci::test::TestGraph graph; - auto axis = graph.append<luci::CircleConst>(); - axis->dtype(loco::DataType::S32); - axis->rank(0); - axis->size<loco::DataType::S32>(1); - axis->at<loco::DataType::S32>(0) = 1; - - auto expand_dims = graph.append<luci::CircleExpandDims>(graph.input_node, axis); - graph.complete(); - - auto input_node = graph.input_node; - { - input_node->shape({4, 3}); - } - - auto output_node = graph.output_node; - { - output_node->from(expand_dims); - } - - luci::test::graph_input_shape(input_node); - luci::test::graph_output_shape(output_node); - - // shape inference - while (shape_pass(graph.graph())) - ; - - // validation - { - ASSERT_TRUE(loco::shape_known(expand_dims)); - - auto shape = loco::shape_get(expand_dims).as<loco::TensorShape>(); - - ASSERT_EQ(3, shape.rank()); - ASSERT_EQ(4, shape.dim(0)); - ASSERT_EQ(1, shape.dim(1)); - ASSERT_EQ(3, shape.dim(2)); - } -} - -TEST(CircleShapeInferenceRuleTest, CircleSqueezeAll) -{ - luci::test::TestGraph graph; - auto squeeze_node = graph.append<luci::CircleSqueeze>(graph.input_node); - graph.complete(); - - auto input_node = graph.input_node; - { - input_node->shape({1, 4, 3, 1}); - } - auto output_node = graph.output_node; - { - input_node->shape({4, 3}); - } - - luci::test::graph_input_shape(input_node); - luci::test::graph_output_shape(output_node); - - squeeze_node->squeeze_dims({}); - - // pre-check - ASSERT_FALSE(loco::shape_known(squeeze_node)); - - // shape inference - while (shape_pass(graph.graph()) == true) - ; - - // Verify - { - ASSERT_TRUE(loco::shape_known(squeeze_node)); - - auto shape = loco::shape_get(squeeze_node).as<loco::TensorShape>(); - ASSERT_EQ(2, shape.rank()); - ASSERT_EQ(4, shape.dim(0)); - ASSERT_EQ(3, shape.dim(1)); - } -} - -TEST(CircleShapeInferenceRuleTest, CircleGatherNd_simple) -{ - luci::test::TestGraph graph; - auto indices_const = graph.append<luci::CircleConst>(); - auto gather_nd_node = graph.append<luci::CircleGatherNd>(graph.input_node, indices_const); - graph.complete(); - - { - auto input_node = graph.input_node; - input_node->shape({1, 4, 4, 3}); - luci::test::graph_input_shape(input_node); - } - { - auto output_node = graph.output_node; - output_node->shape({1, 2, 2, 3}); - luci::test::graph_output_shape(output_node); - } - - { - indices_const->shape({1, 2, 3}); - } - - // pre-check - ASSERT_FALSE(loco::shape_known(gather_nd_node)); - - // shape inference - while (shape_pass(graph.graph()) == true) - ; - - // Verify - { - ASSERT_TRUE(loco::shape_known(gather_nd_node)); - - auto shape = loco::shape_get(gather_nd_node).as<loco::TensorShape>(); - ASSERT_EQ(3, shape.rank()); - ASSERT_EQ(1, shape.dim(0)); - ASSERT_EQ(2, shape.dim(1)); - ASSERT_EQ(3, shape.dim(2)); - } -} - -TEST(CircleShapeInferenceRuleTest, CircleGatherNd_slices) -{ - luci::test::TestGraph graph; - auto indices_const = graph.append<luci::CircleConst>(); - auto gather_nd_node = graph.append<luci::CircleGatherNd>(graph.input_node, indices_const); - graph.complete(); - - { - auto input_node = graph.input_node; - input_node->shape({1, 4, 4, 3}); - luci::test::graph_input_shape(input_node); - } - { - auto output_node = graph.output_node; - output_node->shape({1, 2, 4, 4, 3}); - luci::test::graph_output_shape(output_node); - } - - { - indices_const->shape({1, 2, 1}); - } - - // pre-check - ASSERT_FALSE(loco::shape_known(gather_nd_node)); - - // shape inference - while (shape_pass(graph.graph()) == true) - ; - - // Verify - { - ASSERT_TRUE(loco::shape_known(gather_nd_node)); - - auto shape = loco::shape_get(gather_nd_node).as<loco::TensorShape>(); - ASSERT_EQ(5, shape.rank()); - ASSERT_EQ(1, shape.dim(0)); - ASSERT_EQ(2, shape.dim(1)); - ASSERT_EQ(4, shape.dim(2)); - ASSERT_EQ(4, shape.dim(3)); - ASSERT_EQ(3, shape.dim(4)); - } -} - -TEST(CircleShapeInferenceRuleTest, CircleGatherNd_NEG) -{ - luci::test::TestGraph graph; - auto indices_const = graph.append<luci::CircleConst>(); - auto gather_nd_node = graph.append<luci::CircleGatherNd>(graph.input_node, indices_const); - graph.complete(); - - { - auto input_node = graph.input_node; - input_node->shape({1, 4, 4, 3}); - luci::test::graph_input_shape(input_node); - } - { - // Does not matter, because test should fail anyway - auto output_node = graph.output_node; - output_node->shape({0, 0, 0}); - luci::test::graph_output_shape(output_node); - } - - { - indices_const->shape({1, 2, 5}); - } - - // pre-check - ASSERT_FALSE(loco::shape_known(gather_nd_node)); - - // had to pack into lambda to check throw - auto lambda = [&]() { - // shape inference - while (shape_pass(graph.graph()) == true) - ; - }; - - ASSERT_THROW(lambda(), oops::InternalExn); -} - -TEST(CircleShapeInferenceRuleTest, CircleResizeNearestNeighbor) -{ - luci::test::TestGraph graph; - auto size_const = graph.append<luci::CircleConst>(); - size_const->dtype(loco::DataType::S32); - size_const->rank(1); - size_const->dim(0) = 2; - size_const->size<loco::DataType::S32>(2); - size_const->at<loco::DataType::S32>(0) = 16; - size_const->at<loco::DataType::S32>(1) = 16; - auto resize_node = graph.append<luci::CircleResizeNearestNeighbor>(graph.input_node, size_const); - graph.complete(); - - { - auto input_node = graph.input_node; - input_node->shape({1, 4, 4, 3}); - luci::test::graph_input_shape(input_node); - } - { - auto output_node = graph.output_node; - output_node->from(resize_node); - luci::test::graph_output_shape(output_node); - } - - // pre-check - ASSERT_FALSE(loco::shape_known(resize_node)); - - // shape inference - while (shape_pass(graph.graph()) == true) - ; - - // Verify - { - ASSERT_TRUE(loco::shape_known(resize_node)); - - auto shape = loco::shape_get(resize_node).as<loco::TensorShape>(); - ASSERT_EQ(4, shape.rank()); - ASSERT_EQ(1, shape.dim(0)); - ASSERT_EQ(16, shape.dim(1)); - ASSERT_EQ(16, shape.dim(2)); - ASSERT_EQ(3, shape.dim(3)); - } -} - -TEST(CircleShapeInferenceRuleTest, CircleResizeBilinear) -{ - luci::test::TestGraph graph; - auto size_const = graph.append<luci::CircleConst>(); - size_const->dtype(loco::DataType::S32); - size_const->rank(1); - size_const->dim(0) = 2; - size_const->size<loco::DataType::S32>(2); - size_const->at<loco::DataType::S32>(0) = 16; - size_const->at<loco::DataType::S32>(1) = 16; - auto resize_node = graph.append<luci::CircleResizeBilinear>(graph.input_node, size_const); - graph.complete(); - - { - auto input_node = graph.input_node; - input_node->shape({1, 4, 4, 3}); - luci::test::graph_input_shape(input_node); - } - { - auto output_node = graph.output_node; - output_node->from(resize_node); - luci::test::graph_output_shape(output_node); - } - - // pre-check - ASSERT_FALSE(loco::shape_known(resize_node)); - - // shape inference - while (shape_pass(graph.graph()) == true) - ; - - // Verify - { - ASSERT_TRUE(loco::shape_known(resize_node)); - - auto shape = loco::shape_get(resize_node).as<loco::TensorShape>(); - ASSERT_EQ(4, shape.rank()); - ASSERT_EQ(1, shape.dim(0)); - ASSERT_EQ(16, shape.dim(1)); - ASSERT_EQ(16, shape.dim(2)); - ASSERT_EQ(3, shape.dim(3)); - } -} diff --git a/compiler/luci/service/src/CircleShapeSignatureInference.cpp b/compiler/luci/service/src/CircleShapeSignatureInference.cpp deleted file mode 100644 index 1ccaa19d5..000000000 --- a/compiler/luci/service/src/CircleShapeSignatureInference.cpp +++ /dev/null @@ -1,64 +0,0 @@ -/* - * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "luci/Service/CircleShapeSignatureInference.h" - -#include <luci/Log.h> - -namespace -{ - -std::ostream &operator<<(std::ostream &os, const luci::ShapeSignature &shape_signature) -{ - os << "["; - for (uint32_t r = 0; r < shape_signature.rank(); ++r) - { - if (r) - os << ","; - os << shape_signature.dim(r); - } - os << "]"; - return os; -} - -} // namespace - -namespace luci -{ - -namespace ssinf -{ - -bool Rule::infer(const luci::CircleNode *circle_node, ShapeSignature &shape_signature) const -{ - LOGGER(l); - - // There is nothing to check before ShapeSignatureInference. - - Algorithm alg; - - shape_signature = circle_node->accept(&alg); - - VERBOSE(l, 1) << "[luci] Shape Signature( " << circle_node->name() << " )"; - VERBOSE(l, 1) << " before: " << circle_node->shape_signature(); - VERBOSE(l, 1) << " after: " << shape_signature; - - return true; -} - -} // namespace ssinf - -} // namespace luci diff --git a/compiler/luci/service/src/CircleShapeSignatureInferenceHelper.cpp b/compiler/luci/service/src/CircleShapeSignatureInferenceHelper.cpp deleted file mode 100644 index d7d1a24e8..000000000 --- a/compiler/luci/service/src/CircleShapeSignatureInferenceHelper.cpp +++ /dev/null @@ -1,160 +0,0 @@ -/* - * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "luci/Service/CircleShapeSignatureInferenceHelper.h" - -#include <loco.h> - -#include <luci/Log.h> - -#include <oops/InternalExn.h> - -namespace luci -{ - -namespace ssinf -{ - -luci::ShapeSignature legalized_signature(const luci::ShapeSignature &signature) -{ - // If shape signature has at least one -1, it is not static. - for (uint32_t i = 0; i < signature.rank(); ++i) - if (signature.dim(i) == -1) - return signature; - - // If all dimensions are static, return empty shape signature. - return luci::ShapeSignature(); -} - -ShapeSignature reduced_signature(const loco::Node *node, const loco::Node *indices, bool keep_dims) -{ - LOGGER(l); - - ShapeSignature input_signature; - ShapeSignature output_signature; - - auto circle_node = loco::must_cast<const luci::CircleNode *>(node); - if (circle_node->shape_signature().rank() > 0) - input_signature = circle_node->shape_signature(); - else - { - input_signature.rank(circle_node->rank()); - for (uint32_t i = 0; i < circle_node->rank(); ++i) - input_signature.dim(i) = circle_node->dim(i).value(); - } - - // If input rank is 0, it means that one of following case is occurred. - // - Input is scalar : result is always scalar - // - Input shape signature is not inferenced : cannot infer output shape signauture - // Therefore, when input signature rank is 0, always return empty signature. - if (input_signature.rank() == 0) - return output_signature; - - // When reduction_indices is not constant - auto reduction_indices = dynamic_cast<const luci::CircleConst *>(indices); - if (reduction_indices == nullptr) - { - if (keep_dims) - { - // If keep_dims is true, rank is not changed. - output_signature.rank(input_signature.rank()); - for (uint32_t i = 0; i < output_signature.rank(); ++i) - output_signature.dim(i) = -1; - } - else - { - // There is no way to inference for this case. - // Do nothing to return empty signature. - INFO(l) << "[CircleShapeSignatureInferenceHelper] " << circle_node->name() << std::endl; - INFO(l) << " reduced_signature : cannot infer because of non-constant node" << std::endl; - } - - return output_signature; - } - - std::vector<int32_t> reduction_values; - if (reduction_indices->dtype() == loco::DataType::S32) - { - auto reduction_size = reduction_indices->size<loco::DataType::S32>(); - for (uint32_t i = 0; i < reduction_size; ++i) - { - int32_t axis = reduction_indices->at<loco::DataType::S32>(i); - if (axis < 0) - axis += input_signature.rank(); - - if (!(0 <= axis && axis < static_cast<int32_t>(input_signature.rank()))) - INTERNAL_EXN_V("Invalid reduction axis for REDUCER", oops::to_uint32(axis)); - - reduction_values.push_back(axis); - } - } - else if (reduction_indices->dtype() == loco::DataType::S64) - { - auto reduction_size = reduction_indices->size<loco::DataType::S64>(); - for (uint32_t i = 0; i < reduction_size; ++i) - { - int32_t axis = static_cast<int32_t>(reduction_indices->at<loco::DataType::S64>(i)); - if (axis < 0) - axis += input_signature.rank(); - - if (!(0 <= axis && axis < static_cast<int32_t>(input_signature.rank()))) - INTERNAL_EXN_V("Invalid reduction axis for REDUCER", oops::to_uint32(axis)); - - reduction_values.push_back(axis); - } - } - else - { - INTERNAL_EXN("Wrong reduction axis type, Only INT32, INT64 supported."); - } - - if (keep_dims) - { - output_signature.rank(input_signature.rank()); - for (uint32_t i = 0; i < input_signature.rank(); ++i) - output_signature.dim(i) = input_signature.dim(i); - for (uint32_t i = 0; i < reduction_values.size(); ++i) - output_signature.dim(reduction_values.at(i)) = 1; - } - else - { - std::vector<bool> check_reduce(input_signature.rank(), false); - for (uint32_t i = 0; i < reduction_values.size(); ++i) - check_reduce.at(reduction_values.at(i)) = true; - - uint32_t reduce_cnt = 0; - for (uint32_t i = 0; i < check_reduce.size(); ++i) - if (check_reduce.at(i)) - ++reduce_cnt; - - output_signature.rank(input_signature.rank() - reduce_cnt); - for (uint32_t i = 0, j = 0; i < check_reduce.size(); ++i) - if (check_reduce.at(i) == false) - output_signature.dim(j++) = input_signature.dim(i); - } - - return output_signature; -} - -ShapeSignature input_arg_signature(const luci::CircleNode *node, uint32_t index) -{ - auto circle_input = loco::must_cast<luci::CircleNode *>(node->arg(index)); - return circle_input->shape_signature(); -} - -} // namespace ssinf - -} // namespace luci diff --git a/compiler/luci/service/src/CircleTypeInference.cpp b/compiler/luci/service/src/CircleTypeInference.cpp index b4755b51a..db9a37cb0 100644 --- a/compiler/luci/service/src/CircleTypeInference.cpp +++ b/compiler/luci/service/src/CircleTypeInference.cpp @@ -15,72 +15,23 @@ */ #include "luci/Service/CircleTypeInference.h" +#include "CircleTypeInferenceHelper.h" #include <luci/Log.h> #include <loco.h> -#include <loco/Service/TypeInference.h> - -#include <mio/circle/schema_generated.h> -#include <oops/InternalExn.h> #include <type_traits> namespace { -circle::TensorType translateLocoTypeToCircle(loco::DataType dtype) -{ - switch (dtype) - { - case loco::DataType::U8: - return circle::TensorType_UINT8; - // case loco::DataType::U16: unsupported - // case loco::DataType::U32: unsupported - // case loco::DataType::U64: unsupported - case loco::DataType::S8: - return circle::TensorType_INT8; - case loco::DataType::S16: - return circle::TensorType_INT16; - case loco::DataType::S32: - return circle::TensorType_INT32; - case loco::DataType::S64: - return circle::TensorType_INT64; - case loco::DataType::FLOAT16: - return circle::TensorType_FLOAT16; - case loco::DataType::FLOAT32: - return circle::TensorType_FLOAT32; - // case loco::DataType::FLOAT64: unsupported - case loco::DataType::BOOL: - return circle::TensorType_BOOL; - default: - break; - } - - INTERNAL_EXN_V("Invalid loco dtype", oops::to_uint32(dtype)); -} - -} // namespace - -namespace luci -{ - -circle::TensorType TypeInference::get(loco::Node *node) -{ - assert(loco::dtype_known(node)); - return translateLocoTypeToCircle(loco::dtype_get(node)); -} - -} // namespace luci - -namespace -{ - bool inputs_dtype_ready(const luci::CircleNode *node) { for (uint32_t arity = 0; arity < node->arity(); ++arity) { - if (node->dtype() == loco::DataType::Unknown) + auto input_node = loco::must_cast<luci::CircleNode *>(node->arg(arity)); + if (input_node->dtype() == loco::DataType::Unknown) return false; } diff --git a/compiler/luci/service/src/CircleTypeInferenceHelper.cpp b/compiler/luci/service/src/CircleTypeInferenceHelper.cpp index 75cd9f7b2..06edd70f2 100644 --- a/compiler/luci/service/src/CircleTypeInferenceHelper.cpp +++ b/compiler/luci/service/src/CircleTypeInferenceHelper.cpp @@ -14,7 +14,23 @@ * limitations under the License. */ -#include "luci/Service/CircleTypeInferenceHelper.h" +#include "CircleTypeInferenceHelper.h" + +namespace luci +{ + +loco::DataType dtype_get(const loco::Node *node) +{ + assert(luci::dtype_known(node)); + return loco::must_cast<const luci::CircleNode *>(node)->dtype(); +} + +bool dtype_known(const loco::Node *node) +{ + return loco::must_cast<const luci::CircleNode *>(node)->dtype() != loco::DataType::Unknown; +} + +} // namespace luci namespace luci { diff --git a/compiler/luci/service/include/luci/Service/CircleTypeInferenceHelper.h b/compiler/luci/service/src/CircleTypeInferenceHelper.h index 296f99355..751340cc7 100644 --- a/compiler/luci/service/include/luci/Service/CircleTypeInferenceHelper.h +++ b/compiler/luci/service/src/CircleTypeInferenceHelper.h @@ -23,6 +23,20 @@ namespace luci { + +// NOTE Functions in this namespace will be removed after new inference +// algorithms are fully implemented. + +// This function is temporary function for deprecating loco::dtype_get +loco::DataType dtype_get(const loco::Node *node); + +// This function is temporary function for deprecating loco::dtype_known +bool dtype_known(const loco::Node *node); + +} // namespace luci + +namespace luci +{ namespace tinf // Namespace for Type Inference { diff --git a/compiler/luci/service/src/CircleTypeInferenceRule.cpp b/compiler/luci/service/src/CircleTypeInferenceRule.cpp index f738ab5a8..0b8d2af9e 100644 --- a/compiler/luci/service/src/CircleTypeInferenceRule.cpp +++ b/compiler/luci/service/src/CircleTypeInferenceRule.cpp @@ -15,6 +15,7 @@ */ #include "luci/Service/CircleTypeInferenceRule.h" +#include "CircleTypeInferenceHelper.h" #include <luci/IR/CircleDialect.h> #include <luci/IR/CircleNodeVisitor.h> @@ -29,24 +30,24 @@ struct TypeInferenceAlgorithm final : public luci::CircleNodeVisitor<loco::DataT { // TODO Given a tensor x of complex numbers, Abs operation returns a tensor of type float32 or // float64. - loco::DataType visit(const luci::CircleAbs *node) final { return loco::dtype_get(node->x()); } + loco::DataType visit(const luci::CircleAbs *node) final { return luci::dtype_get(node->x()); } - loco::DataType visit(const luci::CircleAdd *node) final { return loco::dtype_get(node->x()); } + loco::DataType visit(const luci::CircleAdd *node) final { return luci::dtype_get(node->x()); } loco::DataType visit(const luci::CircleAddN *node) final { - auto dtype = loco::dtype_get(node->inputs(0)); + auto dtype = luci::dtype_get(node->inputs(0)); for (uint32_t idx = 1; idx < node->arity(); ++idx) { - auto dtype_idx = loco::dtype_get(node->inputs(idx)); + auto dtype_idx = luci::dtype_get(node->inputs(idx)); if (dtype != dtype_idx) { INTERNAL_EXN_V("ADD_N dtype not same as the first input: ", idx); } } - return loco::dtype_get(node->inputs(0)); + return luci::dtype_get(node->inputs(0)); } loco::DataType visit(const luci::CircleArgMax *node) final { return node->output_type(); } @@ -55,22 +56,22 @@ struct TypeInferenceAlgorithm final : public luci::CircleNodeVisitor<loco::DataT loco::DataType visit(const luci::CircleAveragePool2D *node) final { - return loco::dtype_get(node->value()); + return luci::dtype_get(node->value()); } loco::DataType visit(const luci::CircleBatchMatMul *node) final { - return loco::dtype_get(node->x()); + return luci::dtype_get(node->x()); } loco::DataType visit(const luci::CircleBatchToSpaceND *node) final { - return loco::dtype_get(node->input()); + return luci::dtype_get(node->input()); } loco::DataType visit(const luci::CircleCast *node) final { return node->dtype(); } - loco::DataType visit(const luci::CircleCeil *node) final { return loco::dtype_get(node->x()); } + loco::DataType visit(const luci::CircleCeil *node) final { return luci::dtype_get(node->x()); } loco::DataType visit(const luci::CircleConcatenation *node) final { @@ -78,87 +79,92 @@ struct TypeInferenceAlgorithm final : public luci::CircleNodeVisitor<loco::DataT assert(node->numValues() > 0); for (uint32_t i = 1; i < node->numValues(); ++i) - assert(loco::dtype_get(node->values(i - 1)) == loco::dtype_get(node->values(i))); + assert(luci::dtype_get(node->values(i - 1)) == luci::dtype_get(node->values(i))); - return loco::dtype_get(node->values(0)); + return luci::dtype_get(node->values(0)); } loco::DataType visit(const luci::CircleConst *node) final { return node->dtype(); } loco::DataType visit(const luci::CircleConv2D *node) final { - return loco::dtype_get(node->input()); + return luci::dtype_get(node->input()); } - loco::DataType visit(const luci::CircleCos *node) final { return loco::dtype_get(node->x()); } + loco::DataType visit(const luci::CircleCos *node) final { return luci::dtype_get(node->x()); } loco::DataType visit(const luci::CircleCustom *node) final { if (node->custom_code() == "BatchMatMulV2") { - return loco::dtype_get(node->inputs(0)); + return luci::dtype_get(node->inputs(0)); } return node->dtype(); } loco::DataType visit(const luci::CircleDepthToSpace *node) final { - return loco::dtype_get(node->input()); + return luci::dtype_get(node->input()); } loco::DataType visit(const luci::CircleDepthwiseConv2D *node) final { - return loco::dtype_get(node->input()); + return luci::dtype_get(node->input()); } loco::DataType visit(const luci::CircleDequantize *) final { return loco::DataType::FLOAT32; } - loco::DataType visit(const luci::CircleDiv *node) final { return loco::dtype_get(node->x()); } + loco::DataType visit(const luci::CircleDiv *node) final { return luci::dtype_get(node->x()); } loco::DataType visit(const luci::CircleElu *node) final { - return loco::dtype_get(node->features()); + return luci::dtype_get(node->features()); } loco::DataType visit(const luci::CircleEqual *) final { return loco::DataType::BOOL; } - loco::DataType visit(const luci::CircleExp *node) final { return loco::dtype_get(node->x()); } + loco::DataType visit(const luci::CircleExp *node) final { return luci::dtype_get(node->x()); } loco::DataType visit(const luci::CircleExpandDims *node) final { - return loco::dtype_get(node->input()); + return luci::dtype_get(node->input()); + } + + loco::DataType visit(const luci::CircleFakeQuant *node) final + { + return luci::dtype_get(node->inputs()); } loco::DataType visit(const luci::CircleFill *node) final { - return loco::dtype_get(node->value()); + return luci::dtype_get(node->value()); } - loco::DataType visit(const luci::CircleFloor *node) final { return loco::dtype_get(node->x()); } + loco::DataType visit(const luci::CircleFloor *node) final { return luci::dtype_get(node->x()); } loco::DataType visit(const luci::CircleFloorDiv *node) final { - return loco::dtype_get(node->x()); + return luci::dtype_get(node->x()); } loco::DataType visit(const luci::CircleFloorMod *node) final { - return loco::dtype_get(node->x()); + return luci::dtype_get(node->x()); } loco::DataType visit(const luci::CircleFullyConnected *node) final { - return loco::dtype_get(node->input()); + return luci::dtype_get(node->input()); } loco::DataType visit(const luci::CircleGather *node) final { - return loco::dtype_get(node->params()); + return luci::dtype_get(node->params()); } loco::DataType visit(const luci::CircleGatherNd *node) final { - return loco::dtype_get(node->params()); + return luci::dtype_get(node->params()); } loco::DataType visit(const luci::CircleGreater *) final { return loco::DataType::BOOL; } @@ -169,22 +175,22 @@ struct TypeInferenceAlgorithm final : public luci::CircleNodeVisitor<loco::DataT { // Type of If is not used. Just use input 0 assert(node->input_count() > 0); - return loco::dtype_get(node->input(0)); + return luci::dtype_get(node->input(0)); } loco::DataType visit(const luci::CircleL2Normalize *node) final { - return loco::dtype_get(node->x()); + return luci::dtype_get(node->x()); } loco::DataType visit(const luci::CircleL2Pool2D *node) final { - return loco::dtype_get(node->value()); + return luci::dtype_get(node->value()); } loco::DataType visit(const luci::CircleLeakyRelu *node) final { - return loco::dtype_get(node->features()); + return luci::dtype_get(node->features()); } loco::DataType visit(const luci::CircleLess *) final { return loco::DataType::BOOL; } @@ -193,75 +199,75 @@ struct TypeInferenceAlgorithm final : public luci::CircleNodeVisitor<loco::DataT loco::DataType visit(const luci::CircleLocalResponseNormalization *node) final { - return loco::dtype_get(node->input()); + return luci::dtype_get(node->input()); } - loco::DataType visit(const luci::CircleLog *node) final { return loco::dtype_get(node->x()); } + loco::DataType visit(const luci::CircleLog *node) final { return luci::dtype_get(node->x()); } loco::DataType visit(const luci::CircleLogicalAnd *node) final { - return loco::dtype_get(node->x()); + return luci::dtype_get(node->x()); } loco::DataType visit(const luci::CircleLogicalNot *node) final { - return loco::dtype_get(node->x()); + return luci::dtype_get(node->x()); } loco::DataType visit(const luci::CircleLogicalOr *node) final { - return loco::dtype_get(node->x()); + return luci::dtype_get(node->x()); } loco::DataType visit(const luci::CircleLogistic *node) final { - return loco::dtype_get(node->x()); + return luci::dtype_get(node->x()); } loco::DataType visit(const luci::CircleLogSoftmax *node) final { - return loco::dtype_get(node->logits()); + return luci::dtype_get(node->logits()); } loco::DataType visit(const luci::CircleMatrixDiag *node) final { - return loco::dtype_get(node->diagonal()); + return luci::dtype_get(node->diagonal()); } loco::DataType visit(const luci::CircleMatrixSetDiag *node) final { - return loco::dtype_get(node->input()); + return luci::dtype_get(node->input()); } - loco::DataType visit(const luci::CircleMaximum *node) final { return loco::dtype_get(node->x()); } + loco::DataType visit(const luci::CircleMaximum *node) final { return luci::dtype_get(node->x()); } loco::DataType visit(const luci::CircleMaxPool2D *node) final { - return loco::dtype_get(node->value()); + return luci::dtype_get(node->value()); } loco::DataType visit(const luci::CircleMean *node) final { - return loco::dtype_get(node->input()); + return luci::dtype_get(node->input()); } - loco::DataType visit(const luci::CircleMinimum *node) final { return loco::dtype_get(node->x()); } + loco::DataType visit(const luci::CircleMinimum *node) final { return luci::dtype_get(node->x()); } loco::DataType visit(const luci::CircleMirrorPad *node) final { - return loco::dtype_get(node->input()); + return luci::dtype_get(node->input()); } - loco::DataType visit(const luci::CircleNeg *node) final { return loco::dtype_get(node->x()); } + loco::DataType visit(const luci::CircleNeg *node) final { return luci::dtype_get(node->x()); } loco::DataType visit(const luci::CircleNonMaxSuppressionV4 *node) final { - return loco::dtype_get(node->boxes()); + return luci::dtype_get(node->boxes()); } loco::DataType visit(const luci::CircleNonMaxSuppressionV5 *node) final { - return loco::dtype_get(node->boxes()); + return luci::dtype_get(node->boxes()); } loco::DataType visit(const luci::CircleNotEqual *) final { return loco::DataType::BOOL; } @@ -271,25 +277,25 @@ struct TypeInferenceAlgorithm final : public luci::CircleNodeVisitor<loco::DataT // Only support CirclePack with one or more inputs assert(node->values_count() > 0); - auto first_value_type = loco::dtype_get(node->values(0)); + auto first_value_type = luci::dtype_get(node->values(0)); for (uint32_t i = 1; i < node->values_count(); ++i) - assert(first_value_type == loco::dtype_get(node->values(i))); + assert(first_value_type == luci::dtype_get(node->values(i))); return first_value_type; } - loco::DataType visit(const luci::CirclePad *node) final { return loco::dtype_get(node->input()); } + loco::DataType visit(const luci::CirclePad *node) final { return luci::dtype_get(node->input()); } loco::DataType visit(const luci::CirclePadV2 *node) final { - return loco::dtype_get(node->input()); + return luci::dtype_get(node->input()); } loco::DataType visit(const luci::CirclePow *node) final { // TODO make sure types cannot differ - auto x_type = loco::dtype_get(node->x()); - auto y_type = loco::dtype_get(node->y()); + auto x_type = luci::dtype_get(node->x()); + auto y_type = luci::dtype_get(node->y()); if (x_type != y_type) INTERNAL_EXN("Different datatype for x and y are not supported"); @@ -299,8 +305,8 @@ struct TypeInferenceAlgorithm final : public luci::CircleNodeVisitor<loco::DataT loco::DataType visit(const luci::CirclePRelu *node) final { - auto input_type = loco::dtype_get(node->input()); - auto alpha_type = loco::dtype_get(node->alpha()); + auto input_type = luci::dtype_get(node->input()); + auto alpha_type = luci::dtype_get(node->alpha()); if (input_type != alpha_type) INTERNAL_EXN("Different datatype for input and alpha are not supported"); @@ -310,201 +316,201 @@ struct TypeInferenceAlgorithm final : public luci::CircleNodeVisitor<loco::DataT loco::DataType visit(const luci::CircleRange *node) final { - return loco::dtype_get(node->start()); + return luci::dtype_get(node->start()); } loco::DataType visit(const luci::CircleRank *) final { return loco::DataType::S32; } - loco::DataType visit(const luci::CircleMul *node) final { return loco::dtype_get(node->x()); } + loco::DataType visit(const luci::CircleMul *node) final { return luci::dtype_get(node->x()); } loco::DataType visit(const luci::CircleOneHot *node) final { - return loco::dtype_get(node->on_value()); + return luci::dtype_get(node->on_value()); } loco::DataType visit(const luci::CircleReduceAny *node) final { - return loco::dtype_get(node->input()); + return luci::dtype_get(node->input()); } loco::DataType visit(const luci::CircleReduceMax *node) final { - return loco::dtype_get(node->input()); + return luci::dtype_get(node->input()); } loco::DataType visit(const luci::CircleReduceMin *node) final { - return loco::dtype_get(node->input()); + return luci::dtype_get(node->input()); } loco::DataType visit(const luci::CircleReduceProd *node) final { - return loco::dtype_get(node->input()); + return luci::dtype_get(node->input()); } loco::DataType visit(const luci::CircleRelu *node) final { - return loco::dtype_get(node->features()); + return luci::dtype_get(node->features()); } loco::DataType visit(const luci::CircleRelu6 *node) final { - return loco::dtype_get(node->features()); + return luci::dtype_get(node->features()); } loco::DataType visit(const luci::CircleReluN1To1 *node) final { - return loco::dtype_get(node->features()); + return luci::dtype_get(node->features()); } loco::DataType visit(const luci::CircleReshape *node) final { - return loco::dtype_get(node->tensor()); + return luci::dtype_get(node->tensor()); } loco::DataType visit(const luci::CircleResizeBilinear *node) final { - return loco::dtype_get(node->input()); + return luci::dtype_get(node->input()); } loco::DataType visit(const luci::CircleResizeNearestNeighbor *node) final { - return loco::dtype_get(node->input()); + return luci::dtype_get(node->input()); } loco::DataType visit(const luci::CircleReverseSequence *node) final { - return loco::dtype_get(node->input()); + return luci::dtype_get(node->input()); } loco::DataType visit(const luci::CircleReverseV2 *node) final { - return loco::dtype_get(node->tensor()); + return luci::dtype_get(node->tensor()); } - loco::DataType visit(const luci::CircleRound *node) final { return loco::dtype_get(node->x()); } + loco::DataType visit(const luci::CircleRound *node) final { return luci::dtype_get(node->x()); } - loco::DataType visit(const luci::CircleRsqrt *node) final { return loco::dtype_get(node->x()); } + loco::DataType visit(const luci::CircleRsqrt *node) final { return luci::dtype_get(node->x()); } loco::DataType visit(const luci::CircleScatterNd *node) final { - return loco::dtype_get(node->updates()); + return luci::dtype_get(node->updates()); } loco::DataType visit(const luci::CircleSegmentSum *node) final { - return loco::dtype_get(node->input()); + return luci::dtype_get(node->input()); } loco::DataType visit(const luci::CircleSelect *node) final { - assert(loco::dtype_get(node->t()) == loco::dtype_get(node->e())); - return loco::dtype_get(node->t()); + assert(luci::dtype_get(node->t()) == luci::dtype_get(node->e())); + return luci::dtype_get(node->t()); } loco::DataType visit(const luci::CircleSelectV2 *node) final { - assert(loco::dtype_get(node->t()) == loco::dtype_get(node->e())); - return loco::dtype_get(node->t()); + assert(luci::dtype_get(node->t()) == luci::dtype_get(node->e())); + return luci::dtype_get(node->t()); } loco::DataType visit(const luci::CircleShape *node) final { return node->out_type(); } - loco::DataType visit(const luci::CircleSin *node) final { return loco::dtype_get(node->x()); } + loco::DataType visit(const luci::CircleSin *node) final { return luci::dtype_get(node->x()); } loco::DataType visit(const luci::CircleSlice *node) final { - return loco::dtype_get(node->input()); + return luci::dtype_get(node->input()); } loco::DataType visit(const luci::CircleSoftmax *node) final { - return loco::dtype_get(node->logits()); + return luci::dtype_get(node->logits()); } loco::DataType visit(const luci::CircleSpaceToBatchND *node) final { - return loco::dtype_get(node->input()); + return luci::dtype_get(node->input()); } loco::DataType visit(const luci::CircleSpaceToDepth *node) final { - return loco::dtype_get(node->input()); + return luci::dtype_get(node->input()); } loco::DataType visit(const luci::CircleSparseToDense *node) final { - return loco::dtype_get(node->values()); + return luci::dtype_get(node->values()); } loco::DataType visit(const luci::CircleSplit *node) final { - return loco::dtype_get(node->input()); + return luci::dtype_get(node->input()); } loco::DataType visit(const luci::CircleSplitV *node) final { - return loco::dtype_get(node->input()); + return luci::dtype_get(node->input()); } - loco::DataType visit(const luci::CircleSqrt *node) final { return loco::dtype_get(node->x()); } + loco::DataType visit(const luci::CircleSqrt *node) final { return luci::dtype_get(node->x()); } - loco::DataType visit(const luci::CircleSquare *node) final { return loco::dtype_get(node->x()); } + loco::DataType visit(const luci::CircleSquare *node) final { return luci::dtype_get(node->x()); } loco::DataType visit(const luci::CircleSquaredDifference *node) final { - return loco::dtype_get(node->x()); + return luci::dtype_get(node->x()); } loco::DataType visit(const luci::CircleSqueeze *node) final { - return loco::dtype_get(node->input()); + return luci::dtype_get(node->input()); } loco::DataType visit(const luci::CircleStridedSlice *node) final { - return loco::dtype_get(node->input()); + return luci::dtype_get(node->input()); } - loco::DataType visit(const luci::CircleSub *node) final { return loco::dtype_get(node->x()); } + loco::DataType visit(const luci::CircleSub *node) final { return luci::dtype_get(node->x()); } - loco::DataType visit(const luci::CircleSum *node) final { return loco::dtype_get(node->input()); } + loco::DataType visit(const luci::CircleSum *node) final { return luci::dtype_get(node->input()); } - loco::DataType visit(const luci::CircleTanh *node) final { return loco::dtype_get(node->x()); } + loco::DataType visit(const luci::CircleTanh *node) final { return luci::dtype_get(node->x()); } loco::DataType visit(const luci::CircleTile *node) final { - return loco::dtype_get(node->input()); + return luci::dtype_get(node->input()); } loco::DataType visit(const luci::CircleTopKV2 *node) final { - return loco::dtype_get(node->input()); + return luci::dtype_get(node->input()); } loco::DataType visit(const luci::CircleTranspose *node) final { - return loco::dtype_get(node->a()); + return luci::dtype_get(node->a()); } loco::DataType visit(const luci::CircleTransposeConv *node) final { - return loco::dtype_get(node->outBackprop()); + return luci::dtype_get(node->outBackprop()); } loco::DataType visit(const luci::CircleUnidirectionalSequenceLSTM *node) final { - return loco::dtype_get(node->input()); + return luci::dtype_get(node->input()); } loco::DataType visit(const luci::CircleUnique *node) final { - return loco::dtype_get(node->input()); + return luci::dtype_get(node->input()); } loco::DataType visit(const luci::CircleUnpack *node) final { - return loco::dtype_get(node->value()); + return luci::dtype_get(node->value()); } loco::DataType visit(const luci::CircleWhere *) final { return loco::DataType::S64; } @@ -513,12 +519,12 @@ struct TypeInferenceAlgorithm final : public luci::CircleNodeVisitor<loco::DataT { // Type of While is not used. Just use input 0 assert(node->input_count() > 0); - return loco::dtype_get(node->input(0)); + return luci::dtype_get(node->input(0)); } loco::DataType visit(const luci::CircleZerosLike *node) final { - return loco::dtype_get(node->input()); + return luci::dtype_get(node->input()); } // Circle Only @@ -531,7 +537,7 @@ struct TypeInferenceAlgorithm final : public luci::CircleNodeVisitor<loco::DataT loco::DataType visit(const luci::CircleInstanceNorm *node) final { - return loco::dtype_get(node->input()); + return luci::dtype_get(node->input()); } // Virtual @@ -548,7 +554,7 @@ struct TypeInferenceAlgorithm final : public luci::CircleNodeVisitor<loco::DataT { // We don't care for the type if from() is CircleOutputDummy or CircleOutputExclude // from() type should match that of CircleOutput - assert(output_dtype == loco::dtype_get(node->from())); + assert(output_dtype == luci::dtype_get(node->from())); } return output_dtype; } @@ -559,46 +565,6 @@ struct TypeInferenceAlgorithm final : public luci::CircleNodeVisitor<loco::DataT loco::DataType visit(const luci::CircleCustomOut *node) final { return node->dtype(); } - loco::DataType visit(const luci::CircleIfOut *node) final - { - /** - * @note IF operator type and shape are that of the "then" and "else" - * Graph Outputs. - */ - auto circle_if = dynamic_cast<const luci::CircleIf *>(node->input()); - if (circle_if == nullptr) - { - INTERNAL_EXN("CircleIf IR is not configured correctly"); - } - - auto index = node->index(); - auto then_graph = circle_if->then_graph(); - auto else_graph = circle_if->else_graph(); - assert(then_graph != nullptr); - assert(else_graph != nullptr); - - // shape and type are assumed to be same - // these are checked at post_import_graph() in Import - auto then_outputs = loco::output_nodes(then_graph); - auto else_outputs = loco::output_nodes(else_graph); - assert(then_outputs.size() == else_outputs.size()); - assert(index < static_cast<int32_t>(then_outputs.size())); - - auto then_out = loco::must_cast<luci::CircleOutput *>(then_outputs.at(index)); - auto else_out = loco::must_cast<luci::CircleOutput *>(else_outputs.at(index)); - - auto then_graph_outputs = then_graph->outputs(); // loco::GraphOutput items - auto else_graph_outputs = else_graph->outputs(); - assert(then_graph_outputs->size() == else_graph_outputs->size()); - - auto then_graph_output = then_graph_outputs->at(then_out->index()); - auto else_graph_output = else_graph_outputs->at(else_out->index()); - (void)else_graph_output; // make compiler happy for unused variable warnings - assert(then_graph_output->dtype() == else_graph_output->dtype()); - - return then_graph_output->dtype(); - } - loco::DataType visit(const luci::CircleNonMaxSuppressionV4Out *node) final { (void)node; @@ -619,19 +585,19 @@ struct TypeInferenceAlgorithm final : public luci::CircleNodeVisitor<loco::DataT loco::DataType visit(const luci::CircleSplitOut *node) final { - return loco::dtype_get(node->input()); + return luci::dtype_get(node->input()); } loco::DataType visit(const luci::CircleSplitVOut *node) final { - return loco::dtype_get(node->input()); + return luci::dtype_get(node->input()); } loco::DataType visit(const luci::CircleTopKV2Out *node) final { // First output is same as input if (node->index() == 0) - return loco::dtype_get(node->input()); + return luci::dtype_get(node->input()); // Second outout is always S32 assert(node->index() == 1); return loco::DataType::S32; @@ -641,7 +607,7 @@ struct TypeInferenceAlgorithm final : public luci::CircleNodeVisitor<loco::DataT { if (node->index() == 0) { - return loco::dtype_get(node->input()); + return luci::dtype_get(node->input()); } assert(node->index() == 1); auto unique = loco::must_cast<luci::CircleUnique *>(node->input()); @@ -650,7 +616,7 @@ struct TypeInferenceAlgorithm final : public luci::CircleNodeVisitor<loco::DataT loco::DataType visit(const luci::CircleUnpackOut *node) final { - return loco::dtype_get(node->input()); + return luci::dtype_get(node->input()); } loco::DataType visit(const luci::CircleWhileOut *node) final diff --git a/compiler/luci/service/src/CircleTypeInferenceRule.test.cpp b/compiler/luci/service/src/CircleTypeInferenceRule.test.cpp deleted file mode 100644 index 711a489af..000000000 --- a/compiler/luci/service/src/CircleTypeInferenceRule.test.cpp +++ /dev/null @@ -1,63 +0,0 @@ -/* - * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "TestGraph.h" -#include <luci/Service/CircleTypeInferenceRule.h> - -#include <luci/IR/CircleNodes.h> -#include <luci/IR/CircleDialect.h> - -#include <loco.h> -#include <loco/IR/CanonicalDialect.h> -#include <loco/Service/TypeInference.h> - -#include <gtest/gtest.h> - -#include <memory> - -TEST(CircleTypeInferenceRuleTest, minimal_with_CircleRelu) -{ - // Create a simple network - luci::test::TestGraph graph; - auto relu_node = graph.append<luci::CircleRelu>(graph.input_node); - graph.complete(relu_node); - - // set dtype for nodes; like setting them in import - graph.input_node->dtype(loco::DataType::S32); - relu_node->dtype(loco::DataType::S32); - graph.output_node->dtype(loco::DataType::S32); - - luci::test::graph_input_dtype(graph.input_node); - luci::test::graph_output_dtype(graph.output_node); - - // pre-check - ASSERT_FALSE(loco::dtype_known(relu_node)); - - // type inference - luci::CircleTypeInferenceRule circle_rule; - loco::CanonicalTypeInferenceRule canon_rule; - loco::MultiDialectTypeInferenceRule rules; - - rules.bind(loco::CanonicalDialect::get(), &canon_rule); - rules.bind(luci::CircleDialect::get(), &circle_rule); - - loco::apply(&rules).to(graph.g.get()); - - // Verify - ASSERT_TRUE(loco::dtype_known(relu_node)); - auto type = loco::dtype_get(relu_node); - ASSERT_EQ(loco::DataType::S32, type); -} diff --git a/compiler/luci/service/src/Nodes/CircleAbs.cpp b/compiler/luci/service/src/Nodes/CircleAbs.cpp new file mode 100644 index 000000000..132760957 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleAbs.cpp @@ -0,0 +1,27 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleCloneNode.h" + +namespace luci +{ + +luci::CircleNode *CloneNode::visit(const luci::CircleAbs *) +{ + return _graph->nodes()->create<luci::CircleAbs>(); +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleAbs.test.cpp b/compiler/luci/service/src/Nodes/CircleAbs.test.cpp new file mode 100644 index 000000000..885b395b8 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleAbs.test.cpp @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include <gtest/gtest.h> + +TEST(CloneNodeTest, clone_Abs) +{ + auto g = loco::make_graph(); + auto node_abs = g->nodes()->create<luci::CircleAbs>(); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_abs, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_abs = dynamic_cast<luci::CircleAbs *>(cloned); + ASSERT_NE(nullptr, cloned_abs); +} diff --git a/compiler/luci/pass/include/luci/Pass/TypeInferencePass.h b/compiler/luci/service/src/Nodes/CircleAdd.cpp index 9d964bdd6..08634320e 100644 --- a/compiler/luci/pass/include/luci/Pass/TypeInferencePass.h +++ b/compiler/luci/service/src/Nodes/CircleAdd.cpp @@ -1,6 +1,5 @@ - /* - * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,29 +14,20 @@ * limitations under the License. */ -#ifndef __LUCI_TYPE_INFERENCE_PASS_H__ -#define __LUCI_TYPE_INFERENCE_PASS_H__ - -#include <loco.h> - -#include <luci/ModulePass.h> +#include "CircleCloneNode.h" namespace luci { -/** - * @brief Pass to infer type of nodes - */ -class TypeInferencePass : public luci::Pass +luci::CircleNode *CloneNode::visit(const luci::CircleAdd *node) { -public: - virtual const char *name(void) const { return "luci::TypeInferencePass"; } + if (node->fusedActivationFunction() == luci::FusedActFunc::UNDEFINED) + return nullptr; -public: - bool run(luci::Module *m); - bool run(loco::Graph *graph); -}; + auto *cloned = _graph->nodes()->create<luci::CircleAdd>(); + if (cloned != nullptr) + cloned->fusedActivationFunction(node->fusedActivationFunction()); + return cloned; +} } // namespace luci - -#endif //__LUCI_TYPE_INFERENCE_PASS_H__ diff --git a/compiler/luci/service/src/Nodes/CircleAdd.test.cpp b/compiler/luci/service/src/Nodes/CircleAdd.test.cpp new file mode 100644 index 000000000..41a818b0a --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleAdd.test.cpp @@ -0,0 +1,84 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include <luci/IR/CircleNodes.h> +#include <luci/Service/CircleShapeInference.h> + +#include <loco/IR/TensorShape.h> + +#include <gtest/gtest.h> + +/** + * @note Function to test: Shape inference of two different input shapes + * + * Rank expansion to higher input side + * x(2,1,5) + y(3,5) --> x(2,1,5) + y(1,3,5) + * Do output shape inference like numpy + * x(2,1,5) + y(1,3,5) --> output(2,3,5) + * For each axis, dim value should be same OR one of them should be 1 + */ +TEST(ShapeRuleTest, different_input_shapes_add) +{ + luci::CircleInput input1; + luci::CircleInput input2; + luci::CircleAdd add; + + input1.shape({2, 1, 5}); + input1.shape_status(luci::ShapeStatus::VALID); + input2.shape({3, 5}); + input2.shape_status(luci::ShapeStatus::VALID); + + add.x(&input1); + add.y(&input2); + + loco::TensorShape shape; + luci::sinf::Rule shape_inf_rule; + + ASSERT_TRUE(shape_inf_rule.infer(&add, shape)); + ASSERT_EQ(3, shape.rank()); + ASSERT_EQ(2, shape.dim(0).value()); + ASSERT_EQ(3, shape.dim(1).value()); + ASSERT_EQ(5, shape.dim(2).value()); +} + +TEST(CloneNodeTest, clone_Add) +{ + auto g = loco::make_graph(); + auto node_add = g->nodes()->create<luci::CircleAdd>(); + node_add->fusedActivationFunction(luci::FusedActFunc::RELU); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_add, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_add = dynamic_cast<luci::CircleAdd *>(cloned); + ASSERT_NE(nullptr, cloned_add); + ASSERT_EQ(node_add->fusedActivationFunction(), cloned_add->fusedActivationFunction()); +} + +TEST(CloneNodeTest, clone_Add_NEG) +{ + auto g = loco::make_graph(); + auto node_add = g->nodes()->create<luci::CircleAdd>(); + node_add->fusedActivationFunction(luci::FusedActFunc::UNDEFINED); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_add, gc.get()); + ASSERT_EQ(nullptr, cloned); +} diff --git a/compiler/luci/service/src/Nodes/CircleAddN.cpp b/compiler/luci/service/src/Nodes/CircleAddN.cpp new file mode 100644 index 000000000..e536e54bb --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleAddN.cpp @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleCloneNode.h" + +namespace luci +{ + +luci::CircleNode *CloneNode::visit(const luci::CircleAddN *node) +{ + auto arity = node->arity(); + return _graph->nodes()->create<luci::CircleAddN>(arity); +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleAddN.test.cpp b/compiler/luci/service/src/Nodes/CircleAddN.test.cpp new file mode 100644 index 000000000..5d5b82247 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleAddN.test.cpp @@ -0,0 +1,34 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include <gtest/gtest.h> + +TEST(CloneNodeTest, clone_AddN) +{ + auto g = loco::make_graph(); + auto node_addn = g->nodes()->create<luci::CircleAddN>(3); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_addn, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_addn = dynamic_cast<luci::CircleAddN *>(cloned); + ASSERT_NE(nullptr, cloned_addn); + ASSERT_EQ(node_addn->arity(), cloned_addn->arity()); +} diff --git a/compiler/luci/service/src/Nodes/CircleArgMax.cpp b/compiler/luci/service/src/Nodes/CircleArgMax.cpp new file mode 100644 index 000000000..1b3bafa86 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleArgMax.cpp @@ -0,0 +1,30 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleCloneNode.h" + +namespace luci +{ + +luci::CircleNode *CloneNode::visit(const luci::CircleArgMax *node) +{ + auto *cloned = _graph->nodes()->create<luci::CircleArgMax>(); + if (cloned != nullptr) + cloned->output_type(node->output_type()); + return cloned; +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleArgMax.test.cpp b/compiler/luci/service/src/Nodes/CircleArgMax.test.cpp new file mode 100644 index 000000000..bb7588403 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleArgMax.test.cpp @@ -0,0 +1,35 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include <gtest/gtest.h> + +TEST(CloneNodeTest, clone_ArgMax) +{ + auto g = loco::make_graph(); + auto node_argmax = g->nodes()->create<luci::CircleArgMax>(); + node_argmax->output_type(loco::DataType::FLOAT32); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_argmax, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_argmax = dynamic_cast<luci::CircleArgMax *>(cloned); + ASSERT_NE(nullptr, cloned_argmax); + ASSERT_EQ(node_argmax->output_type(), cloned_argmax->output_type()); +} diff --git a/compiler/luci/service/src/Nodes/CircleArgMin.cpp b/compiler/luci/service/src/Nodes/CircleArgMin.cpp new file mode 100644 index 000000000..fa54f7b76 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleArgMin.cpp @@ -0,0 +1,30 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleCloneNode.h" + +namespace luci +{ + +luci::CircleNode *CloneNode::visit(const luci::CircleArgMin *node) +{ + auto *cloned = _graph->nodes()->create<luci::CircleArgMin>(); + if (cloned != nullptr) + cloned->output_type(node->output_type()); + return cloned; +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleArgMin.test.cpp b/compiler/luci/service/src/Nodes/CircleArgMin.test.cpp new file mode 100644 index 000000000..ca57946f9 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleArgMin.test.cpp @@ -0,0 +1,35 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include <gtest/gtest.h> + +TEST(CloneNodeTest, clone_ArgMin) +{ + auto g = loco::make_graph(); + auto node_argmin = g->nodes()->create<luci::CircleArgMin>(); + node_argmin->output_type(loco::DataType::FLOAT32); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_argmin, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_argmin = dynamic_cast<luci::CircleArgMin *>(cloned); + ASSERT_NE(nullptr, cloned_argmin); + ASSERT_EQ(node_argmin->output_type(), cloned_argmin->output_type()); +} diff --git a/compiler/luci/service/src/Nodes/CircleAveragePool2D.cpp b/compiler/luci/service/src/Nodes/CircleAveragePool2D.cpp new file mode 100644 index 000000000..4d2791833 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleAveragePool2D.cpp @@ -0,0 +1,42 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleCloneNode.h" + +namespace luci +{ + +luci::CircleNode *CloneNode::visit(const luci::CircleAveragePool2D *node) +{ + if (node->fusedActivationFunction() == luci::FusedActFunc::UNDEFINED) + return nullptr; + if (node->padding() == luci::Padding::UNDEFINED) + return nullptr; + + auto *cloned = _graph->nodes()->create<luci::CircleAveragePool2D>(); + if (cloned != nullptr) + { + cloned->fusedActivationFunction(node->fusedActivationFunction()); + cloned->padding(node->padding()); + cloned->filter()->h(node->filter()->h()); + cloned->filter()->w(node->filter()->w()); + cloned->stride()->h(node->stride()->h()); + cloned->stride()->w(node->stride()->w()); + } + return cloned; +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleAveragePool2D.test.cpp b/compiler/luci/service/src/Nodes/CircleAveragePool2D.test.cpp new file mode 100644 index 000000000..d048d1426 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleAveragePool2D.test.cpp @@ -0,0 +1,128 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include <luci/IR/CircleNodes.h> +#include <luci/Service/CircleShapeInference.h> + +#include <loco/IR/TensorShape.h> + +#include <gtest/gtest.h> + +TEST(ShapeRuleTest, simple_valid_pad_avgpool2d) +{ + luci::CircleInput input; + luci::CircleAveragePool2D avgpool_2d; + + input.shape({1, 4, 3, 1}); + input.shape_status(luci::ShapeStatus::VALID); + + avgpool_2d.value(&input); + avgpool_2d.filter()->h(2); + avgpool_2d.filter()->w(2); + avgpool_2d.stride()->h(2); + avgpool_2d.stride()->w(2); + avgpool_2d.fusedActivationFunction(luci::FusedActFunc::NONE); + avgpool_2d.padding(luci::Padding::VALID); + + loco::TensorShape shape; + luci::sinf::Rule shape_inf_rule; + + ASSERT_TRUE(shape_inf_rule.infer(&avgpool_2d, shape)); + ASSERT_EQ(4, shape.rank()); + ASSERT_EQ(1, shape.dim(0).value()); + ASSERT_EQ(2, shape.dim(1).value()); + ASSERT_EQ(1, shape.dim(2).value()); + ASSERT_EQ(1, shape.dim(3).value()); +} + +TEST(ShapeRuleTest, simple_same_pad_avgpool2d) +{ + luci::CircleInput input; + luci::CircleAveragePool2D avgpool_2d; + + input.shape({1, 4, 3, 1}); + input.shape_status(luci::ShapeStatus::VALID); + + avgpool_2d.value(&input); + avgpool_2d.filter()->h(2); + avgpool_2d.filter()->w(2); + avgpool_2d.stride()->h(2); + avgpool_2d.stride()->w(2); + avgpool_2d.fusedActivationFunction(luci::FusedActFunc::NONE); + avgpool_2d.padding(luci::Padding::SAME); + + loco::TensorShape shape; + luci::sinf::Rule shape_inf_rule; + + ASSERT_TRUE(shape_inf_rule.infer(&avgpool_2d, shape)); + ASSERT_EQ(4, shape.rank()); + ASSERT_EQ(1, shape.dim(0).value()); + ASSERT_EQ(2, shape.dim(1).value()); + ASSERT_EQ(2, shape.dim(2).value()); + ASSERT_EQ(1, shape.dim(3).value()); +} + +TEST(CloneNodeTest, clone_AveragePool2D) +{ + auto g = loco::make_graph(); + auto node_avgpool2d = g->nodes()->create<luci::CircleAveragePool2D>(); + node_avgpool2d->fusedActivationFunction(luci::FusedActFunc::RELU); + node_avgpool2d->padding(luci::Padding::SAME); + node_avgpool2d->filter()->h(1); + node_avgpool2d->filter()->w(2); + node_avgpool2d->stride()->h(3); + node_avgpool2d->stride()->w(4); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_avgpool2d, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_avgpool2d = dynamic_cast<luci::CircleAveragePool2D *>(cloned); + ASSERT_NE(nullptr, cloned_avgpool2d); + ASSERT_EQ(node_avgpool2d->fusedActivationFunction(), cloned_avgpool2d->fusedActivationFunction()); + ASSERT_EQ(node_avgpool2d->padding(), cloned_avgpool2d->padding()); + ASSERT_EQ(node_avgpool2d->filter()->h(), cloned_avgpool2d->filter()->h()); + ASSERT_EQ(node_avgpool2d->filter()->w(), cloned_avgpool2d->filter()->w()); + ASSERT_EQ(node_avgpool2d->stride()->h(), cloned_avgpool2d->stride()->h()); + ASSERT_EQ(node_avgpool2d->stride()->w(), cloned_avgpool2d->stride()->w()); +} + +TEST(CloneNodeTest, clone_AveragePool2D_fusedact_NEG) +{ + auto g = loco::make_graph(); + auto node_avgpool2d = g->nodes()->create<luci::CircleAveragePool2D>(); + node_avgpool2d->fusedActivationFunction(luci::FusedActFunc::UNDEFINED); + node_avgpool2d->padding(luci::Padding::SAME); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_avgpool2d, gc.get()); + ASSERT_EQ(nullptr, cloned); +} + +TEST(CloneNodeTest, clone_AveragePool2D_padding_NEG) +{ + auto g = loco::make_graph(); + auto node_avgpool2d = g->nodes()->create<luci::CircleAveragePool2D>(); + node_avgpool2d->fusedActivationFunction(luci::FusedActFunc::RELU); + node_avgpool2d->padding(luci::Padding::UNDEFINED); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_avgpool2d, gc.get()); + ASSERT_EQ(nullptr, cloned); +} diff --git a/compiler/luci/pass/include/luci/Pass/ShapeSignatureInferencePass.h b/compiler/luci/service/src/Nodes/CircleBCQFullyConnected.cpp index 2c6ffcf4e..3edc06ab8 100644 --- a/compiler/luci/pass/include/luci/Pass/ShapeSignatureInferencePass.h +++ b/compiler/luci/service/src/Nodes/CircleBCQFullyConnected.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,29 +14,23 @@ * limitations under the License. */ -#ifndef __LUCI_SHAPE_SIGNATURE_INFERENCE_PASS_H__ -#define __LUCI_SHAPE_SIGNATURE_INFERENCE_PASS_H__ - -#include <loco.h> - -#include <luci/ModulePass.h> +#include "CircleCloneNode.h" namespace luci { -/** - * @brief Pass to infer shape_signature of nodes - */ -class ShapeSignatureInferencePass : public luci::Pass +luci::CircleNode *CloneNode::visit(const luci::CircleBCQFullyConnected *node) { -public: - virtual const char *name(void) const { return "luci::ShapeSignatureInferencePass"; } - -public: - bool run(luci::Module *m); - bool run(loco::Graph *graph); -}; + if (node->fusedActivationFunction() == luci::FusedActFunc::UNDEFINED) + return nullptr; + + auto *cloned = _graph->nodes()->create<luci::CircleBCQFullyConnected>(); + if (cloned != nullptr) + { + cloned->fusedActivationFunction(node->fusedActivationFunction()); + cloned->weights_hidden_size(node->weights_hidden_size()); + } + return cloned; +} } // namespace luci - -#endif //__LUCI_SHAPE_SIGNATURE_INFERENCE_PASS_H__ diff --git a/compiler/luci/service/src/Nodes/CircleBCQFullyConnected.test.cpp b/compiler/luci/service/src/Nodes/CircleBCQFullyConnected.test.cpp new file mode 100644 index 000000000..90c192e07 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleBCQFullyConnected.test.cpp @@ -0,0 +1,48 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include <gtest/gtest.h> + +TEST(CloneNodeTest, clone_BCQFullyConnected) +{ + auto g = loco::make_graph(); + auto node_fc = g->nodes()->create<luci::CircleBCQFullyConnected>(); + node_fc->fusedActivationFunction(luci::FusedActFunc::RELU); + node_fc->weights_hidden_size(3); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_fc, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_fc = dynamic_cast<luci::CircleBCQFullyConnected *>(cloned); + ASSERT_NE(nullptr, cloned_fc); + ASSERT_EQ(node_fc->fusedActivationFunction(), cloned_fc->fusedActivationFunction()); + ASSERT_EQ(node_fc->weights_hidden_size(), cloned_fc->weights_hidden_size()); +} + +TEST(CloneNodeTest, clone_BCQFullyConnected_fusedact_NEG) +{ + auto g = loco::make_graph(); + auto node_fc = g->nodes()->create<luci::CircleBCQFullyConnected>(); + node_fc->fusedActivationFunction(luci::FusedActFunc::UNDEFINED); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_fc, gc.get()); + ASSERT_EQ(nullptr, cloned); +} diff --git a/compiler/luci/service/src/Nodes/CircleBCQGather.cpp b/compiler/luci/service/src/Nodes/CircleBCQGather.cpp new file mode 100644 index 000000000..35b6be744 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleBCQGather.cpp @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleCloneNode.h" + +namespace luci +{ + +luci::CircleNode *CloneNode::visit(const luci::CircleBCQGather *node) +{ + auto *cloned = _graph->nodes()->create<luci::CircleBCQGather>(); + if (cloned != nullptr) + { + cloned->axis(node->axis()); + cloned->input_hidden_size(node->input_hidden_size()); + } + return cloned; +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleBCQGather.test.cpp b/compiler/luci/service/src/Nodes/CircleBCQGather.test.cpp new file mode 100644 index 000000000..a3f9e8850 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleBCQGather.test.cpp @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include <gtest/gtest.h> + +TEST(CloneNodeTest, clone_BCQGather) +{ + auto g = loco::make_graph(); + auto node_gat = g->nodes()->create<luci::CircleBCQGather>(); + node_gat->axis(3); + node_gat->input_hidden_size(5); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_gat, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_gat = dynamic_cast<luci::CircleBCQGather *>(cloned); + ASSERT_NE(nullptr, cloned_gat); + ASSERT_EQ(node_gat->axis(), cloned_gat->axis()); + ASSERT_EQ(node_gat->input_hidden_size(), cloned_gat->input_hidden_size()); +} diff --git a/compiler/luci/service/src/Nodes/CircleBatchMatMul.cpp b/compiler/luci/service/src/Nodes/CircleBatchMatMul.cpp new file mode 100644 index 000000000..c7a8bbd52 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleBatchMatMul.cpp @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleCloneNode.h" + +namespace luci +{ + +luci::CircleNode *CloneNode::visit(const luci::CircleBatchMatMul *node) +{ + auto *cloned = _graph->nodes()->create<luci::CircleBatchMatMul>(); + if (cloned != nullptr) + { + cloned->adj_x(node->adj_x()); + cloned->adj_y(node->adj_y()); + } + return cloned; +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleBatchMatMul.test.cpp b/compiler/luci/service/src/Nodes/CircleBatchMatMul.test.cpp new file mode 100644 index 000000000..e013feae8 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleBatchMatMul.test.cpp @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include <gtest/gtest.h> + +TEST(CloneNodeTest, clone_BatchMatMul) +{ + auto g = loco::make_graph(); + auto node_bmm = g->nodes()->create<luci::CircleBatchMatMul>(); + node_bmm->adj_x(true); + node_bmm->adj_y(true); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_bmm, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_bmm = dynamic_cast<luci::CircleBatchMatMul *>(cloned); + ASSERT_NE(nullptr, cloned_bmm); + ASSERT_EQ(node_bmm->adj_x(), cloned_bmm->adj_x()); + ASSERT_EQ(node_bmm->adj_y(), cloned_bmm->adj_y()); +} diff --git a/compiler/luci/service/src/Nodes/CircleBatchToSpaceND.cpp b/compiler/luci/service/src/Nodes/CircleBatchToSpaceND.cpp new file mode 100644 index 000000000..70aa05f72 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleBatchToSpaceND.cpp @@ -0,0 +1,27 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleCloneNode.h" + +namespace luci +{ + +luci::CircleNode *CloneNode::visit(const luci::CircleBatchToSpaceND *) +{ + return _graph->nodes()->create<luci::CircleBatchToSpaceND>(); +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleBatchToSpaceND.test.cpp b/compiler/luci/service/src/Nodes/CircleBatchToSpaceND.test.cpp new file mode 100644 index 000000000..a45039fc7 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleBatchToSpaceND.test.cpp @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include <gtest/gtest.h> + +TEST(CloneNodeTest, clone_BatchToSpaceND) +{ + auto g = loco::make_graph(); + auto node_b2s = g->nodes()->create<luci::CircleBatchToSpaceND>(); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_b2s, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_b2s = dynamic_cast<luci::CircleBatchToSpaceND *>(cloned); + ASSERT_NE(nullptr, cloned_b2s); +} diff --git a/compiler/luci/service/src/Nodes/CircleCast.cpp b/compiler/luci/service/src/Nodes/CircleCast.cpp new file mode 100644 index 000000000..75f15f9de --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleCast.cpp @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleCloneNode.h" + +namespace luci +{ + +luci::CircleNode *CloneNode::visit(const luci::CircleCast *node) +{ + auto *cloned = _graph->nodes()->create<luci::CircleCast>(); + if (cloned != nullptr) + { + cloned->in_data_type(node->in_data_type()); + cloned->out_data_type(node->out_data_type()); + } + return cloned; +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleCast.test.cpp b/compiler/luci/service/src/Nodes/CircleCast.test.cpp new file mode 100644 index 000000000..1c4bacb73 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleCast.test.cpp @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include <gtest/gtest.h> + +TEST(CloneNodeTest, clone_Cast) +{ + auto g = loco::make_graph(); + auto node_cast = g->nodes()->create<luci::CircleCast>(); + node_cast->in_data_type(loco::DataType::U16); + node_cast->out_data_type(loco::DataType::S32); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_cast, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_cast = dynamic_cast<luci::CircleCast *>(cloned); + ASSERT_NE(nullptr, cloned_cast); + ASSERT_EQ(node_cast->in_data_type(), cloned_cast->in_data_type()); + ASSERT_EQ(node_cast->out_data_type(), cloned_cast->out_data_type()); +} diff --git a/compiler/luci/service/src/Nodes/CircleCeil.cpp b/compiler/luci/service/src/Nodes/CircleCeil.cpp new file mode 100644 index 000000000..92d039a7d --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleCeil.cpp @@ -0,0 +1,27 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleCloneNode.h" + +namespace luci +{ + +luci::CircleNode *CloneNode::visit(const luci::CircleCeil *) +{ + return _graph->nodes()->create<luci::CircleCeil>(); +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleCeil.test.cpp b/compiler/luci/service/src/Nodes/CircleCeil.test.cpp new file mode 100644 index 000000000..b182127d9 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleCeil.test.cpp @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include <gtest/gtest.h> + +TEST(CloneNodeTest, clone_Ceil) +{ + auto g = loco::make_graph(); + auto node_ceil = g->nodes()->create<luci::CircleCeil>(); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_ceil, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_ceil = dynamic_cast<luci::CircleCeil *>(cloned); + ASSERT_NE(nullptr, cloned_ceil); +} diff --git a/compiler/luci/service/src/Nodes/CircleConcatenation.cpp b/compiler/luci/service/src/Nodes/CircleConcatenation.cpp new file mode 100644 index 000000000..75d6a53e6 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleConcatenation.cpp @@ -0,0 +1,36 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleCloneNode.h" + +namespace luci +{ + +luci::CircleNode *CloneNode::visit(const luci::CircleConcatenation *node) +{ + if (node->fusedActivationFunction() == luci::FusedActFunc::UNDEFINED) + return nullptr; + + auto *cloned = _graph->nodes()->create<luci::CircleConcatenation>(node->numValues()); + if (cloned != nullptr) + { + cloned->fusedActivationFunction(node->fusedActivationFunction()); + cloned->axis(node->axis()); + } + return cloned; +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleConcatenation.test.cpp b/compiler/luci/service/src/Nodes/CircleConcatenation.test.cpp new file mode 100644 index 000000000..270068cf0 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleConcatenation.test.cpp @@ -0,0 +1,49 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include <gtest/gtest.h> + +TEST(CloneNodeTest, clone_Concatenation) +{ + auto g = loco::make_graph(); + auto node_concat = g->nodes()->create<luci::CircleConcatenation>(3); + node_concat->fusedActivationFunction(luci::FusedActFunc::RELU); + node_concat->axis(7); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_concat, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_concat = dynamic_cast<luci::CircleConcatenation *>(cloned); + ASSERT_NE(nullptr, cloned_concat); + ASSERT_EQ(node_concat->numValues(), cloned_concat->numValues()); + ASSERT_EQ(node_concat->fusedActivationFunction(), cloned_concat->fusedActivationFunction()); + ASSERT_EQ(node_concat->axis(), cloned_concat->axis()); +} + +TEST(CloneNodeTest, clone_Concatenation_NEG) +{ + auto g = loco::make_graph(); + auto node_concat = g->nodes()->create<luci::CircleConcatenation>(3); + node_concat->fusedActivationFunction(luci::FusedActFunc::UNDEFINED); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_concat, gc.get()); + ASSERT_EQ(nullptr, cloned); +} diff --git a/compiler/luci/service/src/Nodes/CircleConst.cpp b/compiler/luci/service/src/Nodes/CircleConst.cpp new file mode 100644 index 000000000..0306ef4eb --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleConst.cpp @@ -0,0 +1,118 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleCloneNode.h" + +#include "luci/Service/CircleNodeClone.h" + +#include <luci/IR/Nodes/CircleConst.h> + +#include <loco.h> +#include <loco/IR/Graph.h> + +#include <oops/UserExn.h> + +#include <cassert> + +namespace +{ + +template <loco::DataType T> +void copy_values(const luci::CircleConst *node, luci::CircleConst *cloned) +{ + assert(T == node->dtype()); + assert(T == cloned->dtype()); + + const auto size = node->size<T>(); + cloned->size<T>(size); + for (uint32_t i = 0; i < size; i++) + cloned->at<T>(i) = node->at<T>(i); +} + +luci::CircleConst *clone_circleconst(const luci::CircleConst *node, loco::Graph *graph) +{ + auto cloned = graph->nodes()->create<luci::CircleConst>(); + + if (cloned != nullptr) + { + // dtype/shape + cloned->dtype(node->dtype()); + cloned->rank(node->rank()); + + // values + switch (node->dtype()) + { + case loco::DataType::FLOAT32: + copy_values<loco::DataType::FLOAT32>(node, cloned); + break; + + case loco::DataType::U8: + copy_values<loco::DataType::U8>(node, cloned); + break; + + case loco::DataType::S8: + copy_values<loco::DataType::S8>(node, cloned); + break; + + case loco::DataType::S16: + copy_values<loco::DataType::S16>(node, cloned); + break; + + case loco::DataType::S32: + copy_values<loco::DataType::S32>(node, cloned); + break; + + case loco::DataType::S64: + copy_values<loco::DataType::S64>(node, cloned); + break; + + case loco::DataType::BOOL: + copy_values<loco::DataType::BOOL>(node, cloned); + break; + + default: + throw oops::UserExn("Unsupported tensor dtype"); + } + } + + return cloned; +} + +} // namespace + +namespace luci +{ + +luci::CircleConst *clone(luci::CircleConst *node) +{ + auto *cloned = clone_circleconst(node, node->graph()); + + copy_common_attributes(node, cloned); + + return cloned; +} + +} // namespace luci + +namespace luci +{ + +luci::CircleNode *CloneNode::visit(const luci::CircleConst *node) +{ + return clone_circleconst(node, _graph); +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleConst.test.cpp b/compiler/luci/service/src/Nodes/CircleConst.test.cpp new file mode 100644 index 000000000..5d94798f4 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleConst.test.cpp @@ -0,0 +1,177 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/Nodes/CircleConst.h" +#include "luci/Service/CircleNodeClone.h" + +#include <loco.h> +#include <loco/IR/Graph.h> + +#include <gtest/gtest.h> + +namespace +{ + +luci::CircleConst *new_const_s32(loco::Graph *g) +{ + // prepare source CircleConst + auto circle_const = g->nodes()->create<luci::CircleConst>(); + + const auto size = 2; + + circle_const->dtype(loco::DataType::S32); + circle_const->rank(1); + circle_const->dim(0).set(size); + circle_const->shape_status(luci::ShapeStatus::VALID); + + circle_const->size<loco::DataType::S32>(size); + for (uint32_t i = 0; i < size; i++) + circle_const->at<loco::DataType::S32>(i) = i; + + // quantparam + auto quantparam = std::make_unique<luci::CircleQuantParam>(); + quantparam->scale = {1.0}; + quantparam->zerop = {0}; + quantparam->min = {-127.0}; + quantparam->max = {127.0}; + quantparam->quantized_dimension = 1; + circle_const->quantparam(std::move(quantparam)); + + // sparsityparam + auto sparam = std::make_unique<luci::SparsityParam>(); + sparam->traversal_order = {1}; + sparam->block_map = {1}; + sparam->dim_metadata = {}; + circle_const->sparsityparam(std::move(sparam)); + + return circle_const; +} + +template <loco::DataType DT> luci::CircleConst *new_empty_const(loco::Graph *g) +{ + auto circle_const = g->nodes()->create<luci::CircleConst>(); + + const auto size = 0; + + circle_const->dtype(DT); + circle_const->rank(1); + circle_const->dim(0).set(size); + circle_const->shape_status(luci::ShapeStatus::VALID); + circle_const->size<DT>(size); + + return circle_const; +} + +} // namespace + +TEST(CircleConstTest, clone) +{ + auto g = loco::make_graph(); + + // prepare source CircleConst + auto circle_const = new_const_s32(g.get()); + + // make a clone + auto const_cloned = luci::clone(circle_const); + + // check attributes + ASSERT_EQ(loco::DataType::S32, const_cloned->dtype()); + ASSERT_EQ(1, const_cloned->rank()); + ASSERT_EQ(2, const_cloned->dim(0).value()); + ASSERT_EQ(2, const_cloned->size<loco::DataType::S32>()); + ASSERT_EQ(0, const_cloned->at<loco::DataType::S32>(0)); + ASSERT_EQ(1, const_cloned->at<loco::DataType::S32>(1)); + ASSERT_NE(nullptr, const_cloned->quantparam()); + ASSERT_NE(nullptr, const_cloned->sparsityparam()); +} + +TEST(CircleConstTest, clone_U8) +{ + auto g = loco::make_graph(); + + // prepare source CircleConst + auto circle_const = new_empty_const<loco::DataType::U8>(g.get()); + + // make a clone + auto const_cloned = luci::clone(circle_const); + + // check attributes + ASSERT_EQ(loco::DataType::U8, const_cloned->dtype()); +} + +TEST(CircleConstTest, clone_S8) +{ + auto g = loco::make_graph(); + + // prepare source CircleConst + auto circle_const = new_empty_const<loco::DataType::S8>(g.get()); + + // make a clone + auto const_cloned = luci::clone(circle_const); + + // check attributes + ASSERT_EQ(loco::DataType::S8, const_cloned->dtype()); +} + +TEST(CircleConstTest, clone_S64) +{ + auto g = loco::make_graph(); + + // prepare source CircleConst + auto circle_const = new_empty_const<loco::DataType::S64>(g.get()); + + // make a clone + auto const_cloned = luci::clone(circle_const); + + // check attributes + ASSERT_EQ(loco::DataType::S64, const_cloned->dtype()); +} + +TEST(CircleConstTest, clone_BOOL) +{ + auto g = loco::make_graph(); + + // prepare source CircleConst + auto circle_const = new_empty_const<loco::DataType::BOOL>(g.get()); + + // make a clone + auto const_cloned = luci::clone(circle_const); + + // check attributes + ASSERT_EQ(loco::DataType::BOOL, const_cloned->dtype()); +} + +TEST(CloneNodeTest, clone_Const) +{ + auto g = loco::make_graph(); + auto node_const = new_const_s32(g.get()); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_const, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_const = dynamic_cast<luci::CircleConst *>(cloned); + ASSERT_NE(nullptr, cloned_const); + ASSERT_EQ(loco::DataType::S32, cloned_const->dtype()); + ASSERT_EQ(1, cloned_const->rank()); + ASSERT_EQ(2, cloned_const->dim(0).value()); + ASSERT_EQ(2, cloned_const->size<loco::DataType::S32>()); + ASSERT_EQ(0, cloned_const->at<loco::DataType::S32>(0)); + ASSERT_EQ(1, cloned_const->at<loco::DataType::S32>(1)); + ASSERT_NE(nullptr, cloned_const->quantparam()); + ASSERT_NE(nullptr, cloned_const->sparsityparam()); +} diff --git a/compiler/luci/service/src/Nodes/CircleConv2D.cpp b/compiler/luci/service/src/Nodes/CircleConv2D.cpp new file mode 100644 index 000000000..08cd87ef7 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleConv2D.cpp @@ -0,0 +1,42 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleCloneNode.h" + +namespace luci +{ + +luci::CircleNode *CloneNode::visit(const luci::CircleConv2D *node) +{ + if (node->fusedActivationFunction() == luci::FusedActFunc::UNDEFINED) + return nullptr; + if (node->padding() == luci::Padding::UNDEFINED) + return nullptr; + + auto *cloned = _graph->nodes()->create<luci::CircleConv2D>(); + if (cloned != nullptr) + { + cloned->fusedActivationFunction(node->fusedActivationFunction()); + cloned->padding(node->padding()); + cloned->stride()->h(node->stride()->h()); + cloned->stride()->w(node->stride()->w()); + cloned->dilation()->h(node->dilation()->h()); + cloned->dilation()->w(node->dilation()->w()); + } + return cloned; +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleConv2D.test.cpp b/compiler/luci/service/src/Nodes/CircleConv2D.test.cpp new file mode 100644 index 000000000..c265d6cd1 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleConv2D.test.cpp @@ -0,0 +1,61 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include <gtest/gtest.h> + +TEST(CloneNodeTest, clone_Conv2D) +{ + auto g = loco::make_graph(); + auto node_conv2d = g->nodes()->create<luci::CircleConv2D>(); + node_conv2d->fusedActivationFunction(luci::FusedActFunc::RELU); + node_conv2d->padding(luci::Padding::SAME); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_conv2d, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_conv2d = dynamic_cast<luci::CircleConv2D *>(cloned); + ASSERT_NE(nullptr, cloned_conv2d); + ASSERT_EQ(node_conv2d->fusedActivationFunction(), cloned_conv2d->fusedActivationFunction()); + ASSERT_EQ(node_conv2d->padding(), cloned_conv2d->padding()); +} + +TEST(CloneNodeTest, clone_Conv2D_fusedact_NEG) +{ + auto g = loco::make_graph(); + auto node_conv2d = g->nodes()->create<luci::CircleConv2D>(); + node_conv2d->fusedActivationFunction(luci::FusedActFunc::UNDEFINED); + node_conv2d->padding(luci::Padding::SAME); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_conv2d, gc.get()); + ASSERT_EQ(nullptr, cloned); +} + +TEST(CloneNodeTest, clone_Conv2D_padding_NEG) +{ + auto g = loco::make_graph(); + auto node_conv2d = g->nodes()->create<luci::CircleConv2D>(); + node_conv2d->fusedActivationFunction(luci::FusedActFunc::RELU); + node_conv2d->padding(luci::Padding::UNDEFINED); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_conv2d, gc.get()); + ASSERT_EQ(nullptr, cloned); +} diff --git a/compiler/luci/service/src/Nodes/CircleCos.cpp b/compiler/luci/service/src/Nodes/CircleCos.cpp new file mode 100644 index 000000000..c46e3741b --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleCos.cpp @@ -0,0 +1,27 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleCloneNode.h" + +namespace luci +{ + +luci::CircleNode *CloneNode::visit(const luci::CircleCos *) +{ + return _graph->nodes()->create<luci::CircleCos>(); +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleCos.test.cpp b/compiler/luci/service/src/Nodes/CircleCos.test.cpp new file mode 100644 index 000000000..a25943b98 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleCos.test.cpp @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include <gtest/gtest.h> + +TEST(CloneNodeTest, clone_Cos) +{ + auto g = loco::make_graph(); + auto node_cos = g->nodes()->create<luci::CircleCos>(); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_cos, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_cos = dynamic_cast<luci::CircleCos *>(cloned); + ASSERT_NE(nullptr, cloned_cos); +} diff --git a/compiler/luci/service/src/Nodes/CircleCustom.cpp b/compiler/luci/service/src/Nodes/CircleCustom.cpp new file mode 100644 index 000000000..a9764c373 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleCustom.cpp @@ -0,0 +1,35 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleCloneNode.h" + +namespace luci +{ + +luci::CircleNode *CloneNode::visit(const luci::CircleCustom *node) +{ + uint32_t num_in = node->numInputs(); + uint32_t num_out = node->numOutputs(); + auto *cloned = _graph->nodes()->create<luci::CircleCustom>(num_in, num_out); + if (cloned != nullptr) + { + cloned->custom_options(node->custom_options()); + cloned->custom_code(node->custom_code()); + } + return cloned; +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleCustom.test.cpp b/compiler/luci/service/src/Nodes/CircleCustom.test.cpp new file mode 100644 index 000000000..6fee68e71 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleCustom.test.cpp @@ -0,0 +1,46 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include <gtest/gtest.h> + +#include <string> +#include <vector> + +TEST(CloneNodeTest, clone_Custom) +{ + auto g = loco::make_graph(); + auto node_custom = g->nodes()->create<luci::CircleCustom>(2, 3); + std::vector<uint8_t> options({0x55, 0x56, 0x57}); + std::string code = "hello"; + node_custom->custom_options(options); + node_custom->custom_code(code); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_custom, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_custom = dynamic_cast<luci::CircleCustom *>(cloned); + ASSERT_NE(nullptr, cloned_custom); + auto cloned_options = cloned_custom->custom_options(); + ASSERT_EQ(options.size(), cloned_options.size()); + auto size = options.size(); + for (size_t s = 0; s < size; ++s) + ASSERT_EQ(options.at(s), cloned_options.at(s)); + ASSERT_TRUE(node_custom->custom_code() == cloned_custom->custom_code()); +} diff --git a/compiler/luci/service/src/Nodes/CircleCustomOut.cpp b/compiler/luci/service/src/Nodes/CircleCustomOut.cpp new file mode 100644 index 000000000..84577f529 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleCustomOut.cpp @@ -0,0 +1,30 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleCloneNode.h" + +namespace luci +{ + +luci::CircleNode *CloneNode::visit(const luci::CircleCustomOut *node) +{ + auto *cloned = _graph->nodes()->create<luci::CircleCustomOut>(); + if (cloned != nullptr) + cloned->index(node->index()); + return cloned; +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleCustomOut.test.cpp b/compiler/luci/service/src/Nodes/CircleCustomOut.test.cpp new file mode 100644 index 000000000..15121bab6 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleCustomOut.test.cpp @@ -0,0 +1,35 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include <gtest/gtest.h> + +TEST(CloneNodeTest, clone_CustomOut) +{ + auto g = loco::make_graph(); + auto node_cout = g->nodes()->create<luci::CircleCustomOut>(); + node_cout->index(1); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_cout, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_cout = dynamic_cast<luci::CircleCustomOut *>(cloned); + ASSERT_NE(nullptr, cloned_cout); + ASSERT_EQ(node_cout->index(), cloned_cout->index()); +} diff --git a/compiler/luci/service/src/Nodes/CircleDepthToSpace.cpp b/compiler/luci/service/src/Nodes/CircleDepthToSpace.cpp new file mode 100644 index 000000000..7e0bc7d74 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleDepthToSpace.cpp @@ -0,0 +1,30 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleCloneNode.h" + +namespace luci +{ + +luci::CircleNode *CloneNode::visit(const luci::CircleDepthToSpace *node) +{ + auto *cloned = _graph->nodes()->create<luci::CircleDepthToSpace>(); + if (cloned != nullptr) + cloned->block_size(node->block_size()); + return cloned; +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleDepthToSpace.test.cpp b/compiler/luci/service/src/Nodes/CircleDepthToSpace.test.cpp new file mode 100644 index 000000000..192b10b90 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleDepthToSpace.test.cpp @@ -0,0 +1,35 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include <gtest/gtest.h> + +TEST(CloneNodeTest, clone_DepthToSpace) +{ + auto g = loco::make_graph(); + auto node_d2s = g->nodes()->create<luci::CircleDepthToSpace>(); + node_d2s->block_size(32); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_d2s, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_d2s = dynamic_cast<luci::CircleDepthToSpace *>(cloned); + ASSERT_NE(nullptr, cloned_d2s); + ASSERT_EQ(node_d2s->block_size(), cloned_d2s->block_size()); +} diff --git a/compiler/luci/service/src/Nodes/CircleDepthwiseConv2D.cpp b/compiler/luci/service/src/Nodes/CircleDepthwiseConv2D.cpp new file mode 100644 index 000000000..8e0b23d94 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleDepthwiseConv2D.cpp @@ -0,0 +1,43 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleCloneNode.h" + +namespace luci +{ + +luci::CircleNode *CloneNode::visit(const luci::CircleDepthwiseConv2D *node) +{ + if (node->fusedActivationFunction() == luci::FusedActFunc::UNDEFINED) + return nullptr; + if (node->padding() == luci::Padding::UNDEFINED) + return nullptr; + + auto *cloned = _graph->nodes()->create<luci::CircleDepthwiseConv2D>(); + if (cloned != nullptr) + { + cloned->fusedActivationFunction(node->fusedActivationFunction()); + cloned->padding(node->padding()); + cloned->stride()->h(node->stride()->h()); + cloned->stride()->w(node->stride()->w()); + cloned->depthMultiplier(node->depthMultiplier()); + cloned->dilation()->h(node->dilation()->h()); + cloned->dilation()->w(node->dilation()->w()); + } + return cloned; +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleDepthwiseConv2D.test.cpp b/compiler/luci/service/src/Nodes/CircleDepthwiseConv2D.test.cpp new file mode 100644 index 000000000..8657464bc --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleDepthwiseConv2D.test.cpp @@ -0,0 +1,61 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include <gtest/gtest.h> + +TEST(CloneNodeTest, clone_DepthwiseConv2D) +{ + auto g = loco::make_graph(); + auto node_dwconv2d = g->nodes()->create<luci::CircleDepthwiseConv2D>(); + node_dwconv2d->fusedActivationFunction(luci::FusedActFunc::RELU); + node_dwconv2d->padding(luci::Padding::SAME); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_dwconv2d, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_dwconv2d = dynamic_cast<luci::CircleDepthwiseConv2D *>(cloned); + ASSERT_NE(nullptr, cloned_dwconv2d); + ASSERT_EQ(node_dwconv2d->fusedActivationFunction(), cloned_dwconv2d->fusedActivationFunction()); + ASSERT_EQ(node_dwconv2d->padding(), cloned_dwconv2d->padding()); +} + +TEST(CloneNodeTest, clone_DepthwiseConv2D_fusedact_NEG) +{ + auto g = loco::make_graph(); + auto node_dwconv2d = g->nodes()->create<luci::CircleDepthwiseConv2D>(); + node_dwconv2d->fusedActivationFunction(luci::FusedActFunc::UNDEFINED); + node_dwconv2d->padding(luci::Padding::SAME); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_dwconv2d, gc.get()); + ASSERT_EQ(nullptr, cloned); +} + +TEST(CloneNodeTest, clone_DepthwiseConv2D_padding_NEG) +{ + auto g = loco::make_graph(); + auto node_dwconv2d = g->nodes()->create<luci::CircleDepthwiseConv2D>(); + node_dwconv2d->fusedActivationFunction(luci::FusedActFunc::RELU); + node_dwconv2d->padding(luci::Padding::UNDEFINED); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_dwconv2d, gc.get()); + ASSERT_EQ(nullptr, cloned); +} diff --git a/compiler/luci/service/src/Nodes/CircleDequantize.cpp b/compiler/luci/service/src/Nodes/CircleDequantize.cpp new file mode 100644 index 000000000..79983e4d3 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleDequantize.cpp @@ -0,0 +1,27 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleCloneNode.h" + +namespace luci +{ + +luci::CircleNode *CloneNode::visit(const luci::CircleDequantize *) +{ + return _graph->nodes()->create<luci::CircleDequantize>(); +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleDequantize.test.cpp b/compiler/luci/service/src/Nodes/CircleDequantize.test.cpp new file mode 100644 index 000000000..e1c563acf --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleDequantize.test.cpp @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include <gtest/gtest.h> + +TEST(CloneNodeTest, clone_Dequantize) +{ + auto g = loco::make_graph(); + auto node_dq = g->nodes()->create<luci::CircleDequantize>(); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_dq, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_dq = dynamic_cast<luci::CircleDequantize *>(cloned); + ASSERT_NE(nullptr, cloned_dq); +} diff --git a/compiler/luci/service/src/Nodes/CircleDiv.cpp b/compiler/luci/service/src/Nodes/CircleDiv.cpp new file mode 100644 index 000000000..7c48d8b76 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleDiv.cpp @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleCloneNode.h" + +namespace luci +{ + +luci::CircleNode *CloneNode::visit(const luci::CircleDiv *node) +{ + if (node->fusedActivationFunction() == luci::FusedActFunc::UNDEFINED) + return nullptr; + + auto *cloned = _graph->nodes()->create<luci::CircleDiv>(); + if (cloned != nullptr) + cloned->fusedActivationFunction(node->fusedActivationFunction()); + return cloned; +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleDiv.test.cpp b/compiler/luci/service/src/Nodes/CircleDiv.test.cpp new file mode 100644 index 000000000..5182ac908 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleDiv.test.cpp @@ -0,0 +1,46 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include <gtest/gtest.h> + +TEST(CloneNodeTest, clone_Div) +{ + auto g = loco::make_graph(); + auto node_div = g->nodes()->create<luci::CircleDiv>(); + node_div->fusedActivationFunction(luci::FusedActFunc::RELU); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_div, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_div = dynamic_cast<luci::CircleDiv *>(cloned); + ASSERT_NE(nullptr, cloned_div); + ASSERT_EQ(node_div->fusedActivationFunction(), cloned_div->fusedActivationFunction()); +} + +TEST(CloneNodeTest, clone_Div_NEG) +{ + auto g = loco::make_graph(); + auto node_div = g->nodes()->create<luci::CircleDiv>(); + node_div->fusedActivationFunction(luci::FusedActFunc::UNDEFINED); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_div, gc.get()); + ASSERT_EQ(nullptr, cloned); +} diff --git a/compiler/luci/service/src/Nodes/CircleElu.cpp b/compiler/luci/service/src/Nodes/CircleElu.cpp new file mode 100644 index 000000000..e2df30285 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleElu.cpp @@ -0,0 +1,27 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleCloneNode.h" + +namespace luci +{ + +luci::CircleNode *CloneNode::visit(const luci::CircleElu *) +{ + return _graph->nodes()->create<luci::CircleElu>(); +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleElu.test.cpp b/compiler/luci/service/src/Nodes/CircleElu.test.cpp new file mode 100644 index 000000000..e75b3bcb1 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleElu.test.cpp @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include <gtest/gtest.h> + +TEST(CloneNodeTest, clone_Elu) +{ + auto g = loco::make_graph(); + auto node_elu = g->nodes()->create<luci::CircleElu>(); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_elu, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_elu = dynamic_cast<luci::CircleElu *>(cloned); + ASSERT_NE(nullptr, cloned_elu); +} diff --git a/compiler/luci/service/src/Nodes/CircleEqual.cpp b/compiler/luci/service/src/Nodes/CircleEqual.cpp new file mode 100644 index 000000000..5dd382d0b --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleEqual.cpp @@ -0,0 +1,27 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleCloneNode.h" + +namespace luci +{ + +luci::CircleNode *CloneNode::visit(const luci::CircleEqual *) +{ + return _graph->nodes()->create<luci::CircleEqual>(); +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleEqual.test.cpp b/compiler/luci/service/src/Nodes/CircleEqual.test.cpp new file mode 100644 index 000000000..99a5535fc --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleEqual.test.cpp @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include <gtest/gtest.h> + +TEST(CloneNodeTest, clone_Equal) +{ + auto g = loco::make_graph(); + auto node_eq = g->nodes()->create<luci::CircleEqual>(); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_eq, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_eq = dynamic_cast<luci::CircleEqual *>(cloned); + ASSERT_NE(nullptr, cloned_eq); +} diff --git a/compiler/luci/service/src/Nodes/CircleExp.cpp b/compiler/luci/service/src/Nodes/CircleExp.cpp new file mode 100644 index 000000000..3d4918320 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleExp.cpp @@ -0,0 +1,27 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleCloneNode.h" + +namespace luci +{ + +luci::CircleNode *CloneNode::visit(const luci::CircleExp *) +{ + return _graph->nodes()->create<luci::CircleExp>(); +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleExp.test.cpp b/compiler/luci/service/src/Nodes/CircleExp.test.cpp new file mode 100644 index 000000000..ff2bb65db --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleExp.test.cpp @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include <gtest/gtest.h> + +TEST(CloneNodeTest, clone_Exp) +{ + auto g = loco::make_graph(); + auto node_exp = g->nodes()->create<luci::CircleExp>(); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_exp, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_exp = dynamic_cast<luci::CircleExp *>(cloned); + ASSERT_NE(nullptr, cloned_exp); +} diff --git a/compiler/luci/service/src/Nodes/CircleExpandDims.cpp b/compiler/luci/service/src/Nodes/CircleExpandDims.cpp new file mode 100644 index 000000000..4dd1cec86 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleExpandDims.cpp @@ -0,0 +1,27 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleCloneNode.h" + +namespace luci +{ + +luci::CircleNode *CloneNode::visit(const luci::CircleExpandDims *) +{ + return _graph->nodes()->create<luci::CircleExpandDims>(); +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleExpandDims.test.cpp b/compiler/luci/service/src/Nodes/CircleExpandDims.test.cpp new file mode 100644 index 000000000..e3481bccd --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleExpandDims.test.cpp @@ -0,0 +1,66 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include <luci/IR/CircleNodes.h> +#include <luci/Service/CircleShapeInference.h> + +#include <loco/IR/TensorShape.h> + +#include <gtest/gtest.h> + +TEST(ShapeRuleTest, simple_expand_dims) +{ + luci::CircleInput input; + luci::CircleConst axis; + luci::CircleExpandDims expand_dims; + + input.shape({4, 3}); + input.shape_status(luci::ShapeStatus::VALID); + + axis.dtype(loco::DataType::S32); + axis.rank(0); + axis.size<loco::DataType::S32>(1); + axis.at<loco::DataType::S32>(0) = 1; + axis.shape_status(luci::ShapeStatus::VALID); + + expand_dims.input(&input); + expand_dims.axis(&axis); + + loco::TensorShape shape; + luci::sinf::Rule shape_inf_rule; + + ASSERT_TRUE(shape_inf_rule.infer(&expand_dims, shape)); + ASSERT_EQ(3, shape.rank()); + ASSERT_EQ(4, shape.dim(0).value()); + ASSERT_EQ(1, shape.dim(1).value()); + ASSERT_EQ(3, shape.dim(2).value()); +} + +TEST(CloneNodeTest, clone_ExpandDims) +{ + auto g = loco::make_graph(); + auto node_ed = g->nodes()->create<luci::CircleExpandDims>(); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_ed, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_ed = dynamic_cast<luci::CircleExpandDims *>(cloned); + ASSERT_NE(nullptr, cloned_ed); +} diff --git a/compiler/luci/service/src/Nodes/CircleFakeQuant.cpp b/compiler/luci/service/src/Nodes/CircleFakeQuant.cpp new file mode 100644 index 000000000..7abaca685 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleFakeQuant.cpp @@ -0,0 +1,35 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleCloneNode.h" + +namespace luci +{ + +luci::CircleNode *CloneNode::visit(const luci::CircleFakeQuant *node) +{ + auto *cloned = _graph->nodes()->create<luci::CircleFakeQuant>(); + if (cloned != nullptr) + { + cloned->min(node->min()); + cloned->max(node->max()); + cloned->num_bits(node->num_bits()); + cloned->narrow_range(node->narrow_range()); + } + return cloned; +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleFakeQuant.test.cpp b/compiler/luci/service/src/Nodes/CircleFakeQuant.test.cpp new file mode 100644 index 000000000..2c4e3b836 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleFakeQuant.test.cpp @@ -0,0 +1,41 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include <gtest/gtest.h> + +TEST(CloneNodeTest, clone_FakeQuant) +{ + auto g = loco::make_graph(); + auto node_fq = g->nodes()->create<luci::CircleFakeQuant>(); + node_fq->min(1.0f); + node_fq->max(2.0f); + node_fq->num_bits(8); + node_fq->narrow_range(true); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_fq, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_fq = dynamic_cast<luci::CircleFakeQuant *>(cloned); + ASSERT_NE(nullptr, cloned_fq); + ASSERT_EQ(node_fq->min(), cloned_fq->min()); + ASSERT_EQ(node_fq->max(), cloned_fq->max()); + ASSERT_EQ(node_fq->num_bits(), cloned_fq->num_bits()); + ASSERT_EQ(node_fq->narrow_range(), cloned_fq->narrow_range()); +} diff --git a/compiler/luci/service/src/Nodes/CircleFill.cpp b/compiler/luci/service/src/Nodes/CircleFill.cpp new file mode 100644 index 000000000..d9b74c63a --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleFill.cpp @@ -0,0 +1,27 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleCloneNode.h" + +namespace luci +{ + +luci::CircleNode *CloneNode::visit(const luci::CircleFill *) +{ + return _graph->nodes()->create<luci::CircleFill>(); +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleFill.test.cpp b/compiler/luci/service/src/Nodes/CircleFill.test.cpp new file mode 100644 index 000000000..56c807585 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleFill.test.cpp @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include <gtest/gtest.h> + +TEST(CloneNodeTest, clone_Fill) +{ + auto g = loco::make_graph(); + auto node_fill = g->nodes()->create<luci::CircleFill>(); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_fill, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_fill = dynamic_cast<luci::CircleFill *>(cloned); + ASSERT_NE(nullptr, cloned_fill); +} diff --git a/compiler/luci/service/src/Nodes/CircleFloor.cpp b/compiler/luci/service/src/Nodes/CircleFloor.cpp new file mode 100644 index 000000000..532808bc8 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleFloor.cpp @@ -0,0 +1,27 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleCloneNode.h" + +namespace luci +{ + +luci::CircleNode *CloneNode::visit(const luci::CircleFloor *) +{ + return _graph->nodes()->create<luci::CircleFloor>(); +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleFloor.test.cpp b/compiler/luci/service/src/Nodes/CircleFloor.test.cpp new file mode 100644 index 000000000..3d53fd2c3 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleFloor.test.cpp @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include <gtest/gtest.h> + +TEST(CloneNodeTest, clone_Floor) +{ + auto g = loco::make_graph(); + auto node_floor = g->nodes()->create<luci::CircleFloor>(); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_floor, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_floor = dynamic_cast<luci::CircleFloor *>(cloned); + ASSERT_NE(nullptr, cloned_floor); +} diff --git a/compiler/luci/service/src/Nodes/CircleFloorDiv.cpp b/compiler/luci/service/src/Nodes/CircleFloorDiv.cpp new file mode 100644 index 000000000..65be3e868 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleFloorDiv.cpp @@ -0,0 +1,27 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleCloneNode.h" + +namespace luci +{ + +luci::CircleNode *CloneNode::visit(const luci::CircleFloorDiv *) +{ + return _graph->nodes()->create<luci::CircleFloorDiv>(); +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleFloorDiv.test.cpp b/compiler/luci/service/src/Nodes/CircleFloorDiv.test.cpp new file mode 100644 index 000000000..6365ccd3b --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleFloorDiv.test.cpp @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include <gtest/gtest.h> + +TEST(CloneNodeTest, clone_FloorDiv) +{ + auto g = loco::make_graph(); + auto node_floordiv = g->nodes()->create<luci::CircleFloorDiv>(); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_floordiv, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_floordiv = dynamic_cast<luci::CircleFloorDiv *>(cloned); + ASSERT_NE(nullptr, cloned_floordiv); +} diff --git a/compiler/luci/service/src/Nodes/CircleFloorMod.cpp b/compiler/luci/service/src/Nodes/CircleFloorMod.cpp new file mode 100644 index 000000000..00e6a0499 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleFloorMod.cpp @@ -0,0 +1,27 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleCloneNode.h" + +namespace luci +{ + +luci::CircleNode *CloneNode::visit(const luci::CircleFloorMod *) +{ + return _graph->nodes()->create<luci::CircleFloorMod>(); +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleFloorMod.test.cpp b/compiler/luci/service/src/Nodes/CircleFloorMod.test.cpp new file mode 100644 index 000000000..ce91d5881 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleFloorMod.test.cpp @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include <gtest/gtest.h> + +TEST(CloneNodeTest, clone_FloorMod) +{ + auto g = loco::make_graph(); + auto node_floormod = g->nodes()->create<luci::CircleFloorMod>(); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_floormod, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_floormod = dynamic_cast<luci::CircleFloorMod *>(cloned); + ASSERT_NE(nullptr, cloned_floormod); +} diff --git a/compiler/luci/service/src/Nodes/CircleFullyConnected.cpp b/compiler/luci/service/src/Nodes/CircleFullyConnected.cpp new file mode 100644 index 000000000..8acb35cbf --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleFullyConnected.cpp @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleCloneNode.h" + +namespace luci +{ + +luci::CircleNode *CloneNode::visit(const luci::CircleFullyConnected *node) +{ + if (node->fusedActivationFunction() == luci::FusedActFunc::UNDEFINED) + return nullptr; + if (node->weights_format() == luci::CircleFullyConnected::WeightsFormat::UNDEFINED) + return nullptr; + + auto *cloned = _graph->nodes()->create<luci::CircleFullyConnected>(); + if (cloned != nullptr) + { + cloned->fusedActivationFunction(node->fusedActivationFunction()); + cloned->weights_format(node->weights_format()); + } + return cloned; +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleFullyConnected.test.cpp b/compiler/luci/service/src/Nodes/CircleFullyConnected.test.cpp new file mode 100644 index 000000000..965b59130 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleFullyConnected.test.cpp @@ -0,0 +1,61 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include <gtest/gtest.h> + +TEST(CloneNodeTest, clone_FullyConnected) +{ + auto g = loco::make_graph(); + auto node_fc = g->nodes()->create<luci::CircleFullyConnected>(); + node_fc->fusedActivationFunction(luci::FusedActFunc::RELU); + node_fc->weights_format(luci::CircleFullyConnected::WeightsFormat::DEFAULT); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_fc, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_fc = dynamic_cast<luci::CircleFullyConnected *>(cloned); + ASSERT_NE(nullptr, cloned_fc); + ASSERT_EQ(node_fc->fusedActivationFunction(), cloned_fc->fusedActivationFunction()); + ASSERT_EQ(node_fc->weights_format(), cloned_fc->weights_format()); +} + +TEST(CloneNodeTest, clone_FullyConnected_fusedact_NEG) +{ + auto g = loco::make_graph(); + auto node_fc = g->nodes()->create<luci::CircleFullyConnected>(); + node_fc->fusedActivationFunction(luci::FusedActFunc::UNDEFINED); + node_fc->weights_format(luci::CircleFullyConnected::WeightsFormat::DEFAULT); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_fc, gc.get()); + ASSERT_EQ(nullptr, cloned); +} + +TEST(CloneNodeTest, clone_FullyConnected_wf_NEG) +{ + auto g = loco::make_graph(); + auto node_fc = g->nodes()->create<luci::CircleFullyConnected>(); + node_fc->fusedActivationFunction(luci::FusedActFunc::RELU); + node_fc->weights_format(luci::CircleFullyConnected::WeightsFormat::UNDEFINED); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_fc, gc.get()); + ASSERT_EQ(nullptr, cloned); +} diff --git a/compiler/luci/service/src/Nodes/CircleGather.cpp b/compiler/luci/service/src/Nodes/CircleGather.cpp new file mode 100644 index 000000000..072bdeabc --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleGather.cpp @@ -0,0 +1,30 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleCloneNode.h" + +namespace luci +{ + +luci::CircleNode *CloneNode::visit(const luci::CircleGather *node) +{ + auto *cloned = _graph->nodes()->create<luci::CircleGather>(); + if (cloned != nullptr) + cloned->axis(node->axis()); + return cloned; +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleGather.test.cpp b/compiler/luci/service/src/Nodes/CircleGather.test.cpp new file mode 100644 index 000000000..f48dbdb67 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleGather.test.cpp @@ -0,0 +1,35 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include <gtest/gtest.h> + +TEST(CloneNodeTest, clone_Gather) +{ + auto g = loco::make_graph(); + auto node_gat = g->nodes()->create<luci::CircleGather>(); + node_gat->axis(3); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_gat, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_gat = dynamic_cast<luci::CircleGather *>(cloned); + ASSERT_NE(nullptr, cloned_gat); + ASSERT_EQ(node_gat->axis(), cloned_gat->axis()); +} diff --git a/compiler/luci/service/src/Nodes/CircleGatherNd.cpp b/compiler/luci/service/src/Nodes/CircleGatherNd.cpp new file mode 100644 index 000000000..df7ae6e79 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleGatherNd.cpp @@ -0,0 +1,27 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleCloneNode.h" + +namespace luci +{ + +luci::CircleNode *CloneNode::visit(const luci::CircleGatherNd *) +{ + return _graph->nodes()->create<luci::CircleGatherNd>(); +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleGatherNd.test.cpp b/compiler/luci/service/src/Nodes/CircleGatherNd.test.cpp new file mode 100644 index 000000000..3a705710c --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleGatherNd.test.cpp @@ -0,0 +1,113 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include <luci/IR/CircleNodes.h> +#include <luci/Service/CircleShapeInference.h> + +#include <loco/IR/TensorShape.h> + +#include <oops/InternalExn.h> + +#include <gtest/gtest.h> + +TEST(ShapeRuleTest, gather_nd_simple) +{ + luci::CircleInput input; + luci::CircleConst indices_const; + luci::CircleGatherNd gather_nd; + + input.shape({1, 4, 4, 3}); + indices_const.shape({1, 2, 3}); + + input.shape_status(luci::ShapeStatus::VALID); + indices_const.shape_status(luci::ShapeStatus::VALID); + + gather_nd.params(&input); + gather_nd.indices(&indices_const); + + loco::TensorShape shape; + luci::sinf::Rule shape_inf_rule; + + ASSERT_TRUE(shape_inf_rule.infer(&gather_nd, shape)); + ASSERT_EQ(3, shape.rank()); + ASSERT_EQ(1, shape.dim(0).value()); + ASSERT_EQ(2, shape.dim(1).value()); + ASSERT_EQ(3, shape.dim(2).value()); +} + +TEST(ShapeRuleTest, gather_nd_slices) +{ + luci::CircleInput input; + luci::CircleConst indices_const; + luci::CircleGatherNd gather_nd; + + input.shape({1, 4, 4, 3}); + indices_const.shape({1, 2, 1}); + + input.shape_status(luci::ShapeStatus::VALID); + indices_const.shape_status(luci::ShapeStatus::VALID); + + gather_nd.params(&input); + gather_nd.indices(&indices_const); + + loco::TensorShape shape; + luci::sinf::Rule shape_inf_rule; + + ASSERT_TRUE(shape_inf_rule.infer(&gather_nd, shape)); + ASSERT_EQ(5, shape.rank()); + ASSERT_EQ(1, shape.dim(0).value()); + ASSERT_EQ(2, shape.dim(1).value()); + ASSERT_EQ(4, shape.dim(2).value()); + ASSERT_EQ(4, shape.dim(3).value()); + ASSERT_EQ(3, shape.dim(4).value()); +} + +TEST(ShapeRuleTest, gather_nd_NEG) +{ + luci::CircleInput input; + luci::CircleConst indices_const; + luci::CircleGatherNd gather_nd; + + input.shape({1, 4, 4, 3}); + indices_const.shape({1, 2, 5}); + + input.shape_status(luci::ShapeStatus::VALID); + indices_const.shape_status(luci::ShapeStatus::VALID); + + gather_nd.params(&input); + gather_nd.indices(&indices_const); + + loco::TensorShape shape; + luci::sinf::Rule shape_inf_rule; + + ASSERT_THROW(shape_inf_rule.infer(&gather_nd, shape), oops::InternalExn); +} + +TEST(CloneNodeTest, clone_GatherNd) +{ + auto g = loco::make_graph(); + auto node_gtnd = g->nodes()->create<luci::CircleGatherNd>(); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_gtnd, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_gtnd = dynamic_cast<luci::CircleGatherNd *>(cloned); + ASSERT_NE(nullptr, cloned_gtnd); +} diff --git a/compiler/luci/service/src/Nodes/CircleGreater.cpp b/compiler/luci/service/src/Nodes/CircleGreater.cpp new file mode 100644 index 000000000..366d955bf --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleGreater.cpp @@ -0,0 +1,27 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleCloneNode.h" + +namespace luci +{ + +luci::CircleNode *CloneNode::visit(const luci::CircleGreater *) +{ + return _graph->nodes()->create<luci::CircleGreater>(); +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleGreater.test.cpp b/compiler/luci/service/src/Nodes/CircleGreater.test.cpp new file mode 100644 index 000000000..6d2df61f0 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleGreater.test.cpp @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include <gtest/gtest.h> + +TEST(CloneNodeTest, clone_Greater) +{ + auto g = loco::make_graph(); + auto node_gt = g->nodes()->create<luci::CircleGreater>(); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_gt, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_gt = dynamic_cast<luci::CircleGreater *>(cloned); + ASSERT_NE(nullptr, cloned_gt); +} diff --git a/compiler/luci/service/src/Nodes/CircleGreaterEqual.cpp b/compiler/luci/service/src/Nodes/CircleGreaterEqual.cpp new file mode 100644 index 000000000..9705bbe1e --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleGreaterEqual.cpp @@ -0,0 +1,27 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleCloneNode.h" + +namespace luci +{ + +luci::CircleNode *CloneNode::visit(const luci::CircleGreaterEqual *) +{ + return _graph->nodes()->create<luci::CircleGreaterEqual>(); +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleGreaterEqual.test.cpp b/compiler/luci/service/src/Nodes/CircleGreaterEqual.test.cpp new file mode 100644 index 000000000..10387df3a --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleGreaterEqual.test.cpp @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include <gtest/gtest.h> + +TEST(CloneNodeTest, clone_GreaterEqual) +{ + auto g = loco::make_graph(); + auto node_ge = g->nodes()->create<luci::CircleGreaterEqual>(); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_ge, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_ge = dynamic_cast<luci::CircleGreaterEqual *>(cloned); + ASSERT_NE(nullptr, cloned_ge); +} diff --git a/compiler/luci/service/src/Nodes/CircleIfOut.cpp b/compiler/luci/service/src/Nodes/CircleIfOut.cpp new file mode 100644 index 000000000..31ad7203f --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleIfOut.cpp @@ -0,0 +1,89 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <luci/Service/CircleShapeInference.h> +#include <luci/Service/CircleTypeInference.h> + +namespace +{ + +struct CircleIfOutGraphs +{ + loco::GraphOutput *then_graph_output; + loco::GraphOutput *else_graph_output; +}; + +} // namespace + +namespace +{ + +CircleIfOutGraphs get_out_graphs(const luci::CircleIfOut *node) +{ + CircleIfOutGraphs ret_out; + + /** + * @note IF operator type and shape are that of the "then" and "else" + * Graph Outputs. + */ + auto circle_if = loco::must_cast<const luci::CircleIf *>(node->input()); + + auto index = node->index(); + auto then_graph = circle_if->then_graph(); + auto else_graph = circle_if->else_graph(); + assert(then_graph != nullptr); + assert(else_graph != nullptr); + + // shape and type are assumed to be same + // these are checked at post_import_graph() in Import + auto then_outputs = loco::output_nodes(then_graph); + auto else_outputs = loco::output_nodes(else_graph); + assert(then_outputs.size() == else_outputs.size()); + assert(index < static_cast<int32_t>(then_outputs.size())); + + auto then_out = loco::must_cast<luci::CircleOutput *>(then_outputs.at(index)); + auto else_out = loco::must_cast<luci::CircleOutput *>(else_outputs.at(index)); + + auto then_graph_outputs = then_graph->outputs(); // loco::GraphOutput items + auto else_graph_outputs = else_graph->outputs(); + assert(then_graph_outputs->size() == else_graph_outputs->size()); + + ret_out.then_graph_output = then_graph_outputs->at(then_out->index()); + ret_out.else_graph_output = else_graph_outputs->at(else_out->index()); + + return ret_out; +} + +} // namespace + +namespace luci +{ + +loco::TensorShape sinf::Algorithm::visit(const luci::CircleIfOut *node) +{ + auto graphs = get_out_graphs(node); + assert(*graphs.then_graph_output->shape() == *graphs.else_graph_output->shape()); + return *graphs.then_graph_output->shape(); +} + +loco::DataType tinf::Algorithm::visit(const luci::CircleIfOut *node) +{ + auto graphs = get_out_graphs(node); + assert(graphs.then_graph_output->dtype() == graphs.else_graph_output->dtype()); + return graphs.then_graph_output->dtype(); +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleInstanceNorm.cpp b/compiler/luci/service/src/Nodes/CircleInstanceNorm.cpp new file mode 100644 index 000000000..d9e49d8ed --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleInstanceNorm.cpp @@ -0,0 +1,36 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleCloneNode.h" + +namespace luci +{ + +luci::CircleNode *CloneNode::visit(const luci::CircleInstanceNorm *node) +{ + if (node->fusedActivationFunction() == luci::FusedActFunc::UNDEFINED) + return nullptr; + + auto *cloned = _graph->nodes()->create<luci::CircleInstanceNorm>(); + if (cloned != nullptr) + { + cloned->fusedActivationFunction(node->fusedActivationFunction()); + cloned->epsilon(node->epsilon()); + } + return cloned; +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleInstanceNorm.test.cpp b/compiler/luci/service/src/Nodes/CircleInstanceNorm.test.cpp new file mode 100644 index 000000000..bae92b1ae --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleInstanceNorm.test.cpp @@ -0,0 +1,48 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include <gtest/gtest.h> + +TEST(CloneNodeTest, clone_InstanceNorm) +{ + auto g = loco::make_graph(); + auto node_fc = g->nodes()->create<luci::CircleInstanceNorm>(); + node_fc->fusedActivationFunction(luci::FusedActFunc::RELU); + node_fc->epsilon(3); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_fc, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_fc = dynamic_cast<luci::CircleInstanceNorm *>(cloned); + ASSERT_NE(nullptr, cloned_fc); + ASSERT_EQ(node_fc->fusedActivationFunction(), cloned_fc->fusedActivationFunction()); + ASSERT_EQ(node_fc->epsilon(), cloned_fc->epsilon()); +} + +TEST(CloneNodeTest, clone_InstanceNorm_fusedact_NEG) +{ + auto g = loco::make_graph(); + auto node_fc = g->nodes()->create<luci::CircleInstanceNorm>(); + node_fc->fusedActivationFunction(luci::FusedActFunc::UNDEFINED); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_fc, gc.get()); + ASSERT_EQ(nullptr, cloned); +} diff --git a/compiler/luci/service/src/Nodes/CircleL2Normalize.cpp b/compiler/luci/service/src/Nodes/CircleL2Normalize.cpp new file mode 100644 index 000000000..afa2a6acb --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleL2Normalize.cpp @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleCloneNode.h" + +namespace luci +{ + +luci::CircleNode *CloneNode::visit(const luci::CircleL2Normalize *node) +{ + if (node->fusedActivationFunction() == luci::FusedActFunc::UNDEFINED) + return nullptr; + + auto *cloned = _graph->nodes()->create<luci::CircleL2Normalize>(); + if (cloned != nullptr) + cloned->fusedActivationFunction(node->fusedActivationFunction()); + return cloned; +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleL2Normalize.test.cpp b/compiler/luci/service/src/Nodes/CircleL2Normalize.test.cpp new file mode 100644 index 000000000..0f148797e --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleL2Normalize.test.cpp @@ -0,0 +1,46 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include <gtest/gtest.h> + +TEST(CloneNodeTest, clone_L2Normalize) +{ + auto g = loco::make_graph(); + auto node_l2n = g->nodes()->create<luci::CircleL2Normalize>(); + node_l2n->fusedActivationFunction(luci::FusedActFunc::RELU); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_l2n, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_l2n = dynamic_cast<luci::CircleL2Normalize *>(cloned); + ASSERT_NE(nullptr, cloned_l2n); + ASSERT_EQ(node_l2n->fusedActivationFunction(), cloned_l2n->fusedActivationFunction()); +} + +TEST(CloneNodeTest, clone_L2Normalize_NEG) +{ + auto g = loco::make_graph(); + auto node_l2n = g->nodes()->create<luci::CircleL2Normalize>(); + node_l2n->fusedActivationFunction(luci::FusedActFunc::UNDEFINED); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_l2n, gc.get()); + ASSERT_EQ(nullptr, cloned); +} diff --git a/compiler/luci/service/src/Nodes/CircleL2Pool2D.cpp b/compiler/luci/service/src/Nodes/CircleL2Pool2D.cpp new file mode 100644 index 000000000..2d876c5bc --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleL2Pool2D.cpp @@ -0,0 +1,42 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleCloneNode.h" + +namespace luci +{ + +luci::CircleNode *CloneNode::visit(const luci::CircleL2Pool2D *node) +{ + if (node->fusedActivationFunction() == luci::FusedActFunc::UNDEFINED) + return nullptr; + if (node->padding() == luci::Padding::UNDEFINED) + return nullptr; + + auto *cloned = _graph->nodes()->create<luci::CircleL2Pool2D>(); + if (cloned != nullptr) + { + cloned->fusedActivationFunction(node->fusedActivationFunction()); + cloned->padding(node->padding()); + cloned->filter()->h(node->filter()->h()); + cloned->filter()->w(node->filter()->w()); + cloned->stride()->h(node->stride()->h()); + cloned->stride()->w(node->stride()->w()); + } + return cloned; +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleL2Pool2D.test.cpp b/compiler/luci/service/src/Nodes/CircleL2Pool2D.test.cpp new file mode 100644 index 000000000..37344fd9a --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleL2Pool2D.test.cpp @@ -0,0 +1,61 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include <gtest/gtest.h> + +TEST(CloneNodeTest, clone_L2Pool2D) +{ + auto g = loco::make_graph(); + auto node_l2n = g->nodes()->create<luci::CircleL2Pool2D>(); + node_l2n->fusedActivationFunction(luci::FusedActFunc::RELU); + node_l2n->padding(luci::Padding::SAME); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_l2n, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_l2n = dynamic_cast<luci::CircleL2Pool2D *>(cloned); + ASSERT_NE(nullptr, cloned_l2n); + ASSERT_EQ(node_l2n->fusedActivationFunction(), cloned_l2n->fusedActivationFunction()); + ASSERT_EQ(node_l2n->padding(), cloned_l2n->padding()); +} + +TEST(CloneNodeTest, clone_L2Normalize_fusedact_NEG) +{ + auto g = loco::make_graph(); + auto node_l2n = g->nodes()->create<luci::CircleL2Pool2D>(); + node_l2n->fusedActivationFunction(luci::FusedActFunc::UNDEFINED); + node_l2n->padding(luci::Padding::SAME); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_l2n, gc.get()); + ASSERT_EQ(nullptr, cloned); +} + +TEST(CloneNodeTest, clone_L2Normalize_padding_NEG) +{ + auto g = loco::make_graph(); + auto node_l2n = g->nodes()->create<luci::CircleL2Pool2D>(); + node_l2n->fusedActivationFunction(luci::FusedActFunc::RELU); + node_l2n->padding(luci::Padding::UNDEFINED); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_l2n, gc.get()); + ASSERT_EQ(nullptr, cloned); +} diff --git a/compiler/luci/service/src/Nodes/CircleLeakyRelu.cpp b/compiler/luci/service/src/Nodes/CircleLeakyRelu.cpp new file mode 100644 index 000000000..91030618c --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleLeakyRelu.cpp @@ -0,0 +1,30 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleCloneNode.h" + +namespace luci +{ + +luci::CircleNode *CloneNode::visit(const luci::CircleLeakyRelu *node) +{ + auto *cloned = _graph->nodes()->create<luci::CircleLeakyRelu>(); + if (cloned != nullptr) + cloned->alpha(node->alpha()); + return cloned; +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleLeakyRelu.test.cpp b/compiler/luci/service/src/Nodes/CircleLeakyRelu.test.cpp new file mode 100644 index 000000000..17fc1442a --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleLeakyRelu.test.cpp @@ -0,0 +1,35 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include <gtest/gtest.h> + +TEST(CloneNodeTest, clone_LeakyRelu) +{ + auto g = loco::make_graph(); + auto node_lr = g->nodes()->create<luci::CircleLeakyRelu>(); + node_lr->alpha(1.2f); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_lr, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_lr = dynamic_cast<luci::CircleLeakyRelu *>(cloned); + ASSERT_NE(nullptr, cloned_lr); + ASSERT_EQ(node_lr->alpha(), cloned_lr->alpha()); +} diff --git a/compiler/luci/service/src/Nodes/CircleLess.cpp b/compiler/luci/service/src/Nodes/CircleLess.cpp new file mode 100644 index 000000000..33b70b735 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleLess.cpp @@ -0,0 +1,27 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleCloneNode.h" + +namespace luci +{ + +luci::CircleNode *CloneNode::visit(const luci::CircleLess *) +{ + return _graph->nodes()->create<luci::CircleLess>(); +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleLess.test.cpp b/compiler/luci/service/src/Nodes/CircleLess.test.cpp new file mode 100644 index 000000000..43248948d --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleLess.test.cpp @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include <gtest/gtest.h> + +TEST(CloneNodeTest, clone_Less) +{ + auto g = loco::make_graph(); + auto node_less = g->nodes()->create<luci::CircleLess>(); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_less, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_less = dynamic_cast<luci::CircleLess *>(cloned); + ASSERT_NE(nullptr, cloned_less); +} diff --git a/compiler/luci/service/src/Nodes/CircleLessEqual.cpp b/compiler/luci/service/src/Nodes/CircleLessEqual.cpp new file mode 100644 index 000000000..22491365a --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleLessEqual.cpp @@ -0,0 +1,27 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleCloneNode.h" + +namespace luci +{ + +luci::CircleNode *CloneNode::visit(const luci::CircleLessEqual *) +{ + return _graph->nodes()->create<luci::CircleLessEqual>(); +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleLessEqual.test.cpp b/compiler/luci/service/src/Nodes/CircleLessEqual.test.cpp new file mode 100644 index 000000000..0a87daf5d --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleLessEqual.test.cpp @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include <gtest/gtest.h> + +TEST(CloneNodeTest, clone_LessEqual) +{ + auto g = loco::make_graph(); + auto node_le = g->nodes()->create<luci::CircleLessEqual>(); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_le, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_le = dynamic_cast<luci::CircleLessEqual *>(cloned); + ASSERT_NE(nullptr, cloned_le); +} diff --git a/compiler/luci/service/src/Nodes/CircleLocalResponseNormalization.cpp b/compiler/luci/service/src/Nodes/CircleLocalResponseNormalization.cpp new file mode 100644 index 000000000..bf69b5ef5 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleLocalResponseNormalization.cpp @@ -0,0 +1,35 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleCloneNode.h" + +namespace luci +{ + +luci::CircleNode *CloneNode::visit(const luci::CircleLocalResponseNormalization *node) +{ + auto *cloned = _graph->nodes()->create<luci::CircleLocalResponseNormalization>(); + if (cloned != nullptr) + { + cloned->radius(node->radius()); + cloned->bias(node->bias()); + cloned->alpha(node->alpha()); + cloned->beta(node->beta()); + } + return cloned; +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleLocalResponseNormalization.test.cpp b/compiler/luci/service/src/Nodes/CircleLocalResponseNormalization.test.cpp new file mode 100644 index 000000000..262b119bb --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleLocalResponseNormalization.test.cpp @@ -0,0 +1,41 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include <gtest/gtest.h> + +TEST(CloneNodeTest, clone_LocalResponseNormalization) +{ + auto g = loco::make_graph(); + auto node_lrn = g->nodes()->create<luci::CircleLocalResponseNormalization>(); + node_lrn->radius(32); + node_lrn->bias(1.2f); + node_lrn->alpha(3.4f); + node_lrn->beta(5.7f); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_lrn, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_lrn = dynamic_cast<luci::CircleLocalResponseNormalization *>(cloned); + ASSERT_NE(nullptr, cloned_lrn); + ASSERT_EQ(node_lrn->radius(), cloned_lrn->radius()); + ASSERT_EQ(node_lrn->bias(), cloned_lrn->bias()); + ASSERT_EQ(node_lrn->alpha(), cloned_lrn->alpha()); + ASSERT_EQ(node_lrn->beta(), cloned_lrn->beta()); +} diff --git a/compiler/luci/service/src/Nodes/CircleLog.cpp b/compiler/luci/service/src/Nodes/CircleLog.cpp new file mode 100644 index 000000000..5788f129f --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleLog.cpp @@ -0,0 +1,27 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleCloneNode.h" + +namespace luci +{ + +luci::CircleNode *CloneNode::visit(const luci::CircleLog *) +{ + return _graph->nodes()->create<luci::CircleLog>(); +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleLog.test.cpp b/compiler/luci/service/src/Nodes/CircleLog.test.cpp new file mode 100644 index 000000000..d1ee1428e --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleLog.test.cpp @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include <gtest/gtest.h> + +TEST(CloneNodeTest, clone_Log) +{ + auto g = loco::make_graph(); + auto node_log = g->nodes()->create<luci::CircleLog>(); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_log, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_log = dynamic_cast<luci::CircleLog *>(cloned); + ASSERT_NE(nullptr, cloned_log); +} diff --git a/compiler/luci/service/src/Nodes/CircleLogSoftmax.cpp b/compiler/luci/service/src/Nodes/CircleLogSoftmax.cpp new file mode 100644 index 000000000..352160aff --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleLogSoftmax.cpp @@ -0,0 +1,27 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleCloneNode.h" + +namespace luci +{ + +luci::CircleNode *CloneNode::visit(const luci::CircleLogSoftmax *) +{ + return _graph->nodes()->create<luci::CircleLogSoftmax>(); +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleLogSoftmax.test.cpp b/compiler/luci/service/src/Nodes/CircleLogSoftmax.test.cpp new file mode 100644 index 000000000..feebb79cb --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleLogSoftmax.test.cpp @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include <gtest/gtest.h> + +TEST(CloneNodeTest, clone_LogSoftmax) +{ + auto g = loco::make_graph(); + auto node_logs = g->nodes()->create<luci::CircleLogSoftmax>(); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_logs, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_logs = dynamic_cast<luci::CircleLogSoftmax *>(cloned); + ASSERT_NE(nullptr, cloned_logs); +} diff --git a/compiler/luci/service/src/Nodes/CircleLogicalAnd.cpp b/compiler/luci/service/src/Nodes/CircleLogicalAnd.cpp new file mode 100644 index 000000000..5df62b951 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleLogicalAnd.cpp @@ -0,0 +1,27 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleCloneNode.h" + +namespace luci +{ + +luci::CircleNode *CloneNode::visit(const luci::CircleLogicalAnd *) +{ + return _graph->nodes()->create<luci::CircleLogicalAnd>(); +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleLogicalAnd.test.cpp b/compiler/luci/service/src/Nodes/CircleLogicalAnd.test.cpp new file mode 100644 index 000000000..aa811edfa --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleLogicalAnd.test.cpp @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include <gtest/gtest.h> + +TEST(CloneNodeTest, clone_LogicalAnd) +{ + auto g = loco::make_graph(); + auto node_logand = g->nodes()->create<luci::CircleLogicalAnd>(); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_logand, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_logand = dynamic_cast<luci::CircleLogicalAnd *>(cloned); + ASSERT_NE(nullptr, cloned_logand); +} diff --git a/compiler/luci/service/src/Nodes/CircleLogicalNot.cpp b/compiler/luci/service/src/Nodes/CircleLogicalNot.cpp new file mode 100644 index 000000000..ac982829d --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleLogicalNot.cpp @@ -0,0 +1,27 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleCloneNode.h" + +namespace luci +{ + +luci::CircleNode *CloneNode::visit(const luci::CircleLogicalNot *) +{ + return _graph->nodes()->create<luci::CircleLogicalNot>(); +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleLogicalNot.test.cpp b/compiler/luci/service/src/Nodes/CircleLogicalNot.test.cpp new file mode 100644 index 000000000..9e55be944 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleLogicalNot.test.cpp @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include <gtest/gtest.h> + +TEST(CloneNodeTest, clone_LogicalNot) +{ + auto g = loco::make_graph(); + auto node_lognot = g->nodes()->create<luci::CircleLogicalNot>(); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_lognot, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_lognot = dynamic_cast<luci::CircleLogicalNot *>(cloned); + ASSERT_NE(nullptr, cloned_lognot); +} diff --git a/compiler/luci/service/src/Nodes/CircleLogicalOr.cpp b/compiler/luci/service/src/Nodes/CircleLogicalOr.cpp new file mode 100644 index 000000000..1201d6f34 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleLogicalOr.cpp @@ -0,0 +1,27 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleCloneNode.h" + +namespace luci +{ + +luci::CircleNode *CloneNode::visit(const luci::CircleLogicalOr *) +{ + return _graph->nodes()->create<luci::CircleLogicalOr>(); +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleLogicalOr.test.cpp b/compiler/luci/service/src/Nodes/CircleLogicalOr.test.cpp new file mode 100644 index 000000000..19b706dcd --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleLogicalOr.test.cpp @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include <gtest/gtest.h> + +TEST(CloneNodeTest, clone_LogicalOr) +{ + auto g = loco::make_graph(); + auto node_logor = g->nodes()->create<luci::CircleLogicalOr>(); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_logor, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_logor = dynamic_cast<luci::CircleLogicalOr *>(cloned); + ASSERT_NE(nullptr, cloned_logor); +} diff --git a/compiler/luci/service/src/Nodes/CircleLogistic.cpp b/compiler/luci/service/src/Nodes/CircleLogistic.cpp new file mode 100644 index 000000000..b21b187e9 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleLogistic.cpp @@ -0,0 +1,27 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleCloneNode.h" + +namespace luci +{ + +luci::CircleNode *CloneNode::visit(const luci::CircleLogistic *) +{ + return _graph->nodes()->create<luci::CircleLogistic>(); +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleLogistic.test.cpp b/compiler/luci/service/src/Nodes/CircleLogistic.test.cpp new file mode 100644 index 000000000..05dbe46e4 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleLogistic.test.cpp @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include <gtest/gtest.h> + +TEST(CloneNodeTest, clone_Logistic) +{ + auto g = loco::make_graph(); + auto node_log = g->nodes()->create<luci::CircleLogistic>(); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_log, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_log = dynamic_cast<luci::CircleLogistic *>(cloned); + ASSERT_NE(nullptr, cloned_log); +} diff --git a/compiler/luci/service/src/Nodes/CircleMatrixDiag.cpp b/compiler/luci/service/src/Nodes/CircleMatrixDiag.cpp new file mode 100644 index 000000000..2bffa07b1 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleMatrixDiag.cpp @@ -0,0 +1,27 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleCloneNode.h" + +namespace luci +{ + +luci::CircleNode *CloneNode::visit(const luci::CircleMatrixDiag *) +{ + return _graph->nodes()->create<luci::CircleMatrixDiag>(); +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleMatrixDiag.test.cpp b/compiler/luci/service/src/Nodes/CircleMatrixDiag.test.cpp new file mode 100644 index 000000000..c08c4cb94 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleMatrixDiag.test.cpp @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include <gtest/gtest.h> + +TEST(CloneNodeTest, clone_MatrixDiag) +{ + auto g = loco::make_graph(); + auto node_md = g->nodes()->create<luci::CircleMatrixDiag>(); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_md, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_md = dynamic_cast<luci::CircleMatrixDiag *>(cloned); + ASSERT_NE(nullptr, cloned_md); +} diff --git a/compiler/luci/service/src/Nodes/CircleMatrixSetDiag.cpp b/compiler/luci/service/src/Nodes/CircleMatrixSetDiag.cpp new file mode 100644 index 000000000..5ea2a5339 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleMatrixSetDiag.cpp @@ -0,0 +1,27 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleCloneNode.h" + +namespace luci +{ + +luci::CircleNode *CloneNode::visit(const luci::CircleMatrixSetDiag *) +{ + return _graph->nodes()->create<luci::CircleMatrixSetDiag>(); +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleMatrixSetDiag.test.cpp b/compiler/luci/service/src/Nodes/CircleMatrixSetDiag.test.cpp new file mode 100644 index 000000000..5ea77ba75 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleMatrixSetDiag.test.cpp @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include <gtest/gtest.h> + +TEST(CloneNodeTest, clone_MatrixSetDiag) +{ + auto g = loco::make_graph(); + auto node_msd = g->nodes()->create<luci::CircleMatrixSetDiag>(); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_msd, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_msd = dynamic_cast<luci::CircleMatrixSetDiag *>(cloned); + ASSERT_NE(nullptr, cloned_msd); +} diff --git a/compiler/luci/service/src/Nodes/CircleMaxPool2D.cpp b/compiler/luci/service/src/Nodes/CircleMaxPool2D.cpp new file mode 100644 index 000000000..b21610c7f --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleMaxPool2D.cpp @@ -0,0 +1,42 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleCloneNode.h" + +namespace luci +{ + +luci::CircleNode *CloneNode::visit(const luci::CircleMaxPool2D *node) +{ + if (node->fusedActivationFunction() == luci::FusedActFunc::UNDEFINED) + return nullptr; + if (node->padding() == luci::Padding::UNDEFINED) + return nullptr; + + auto *cloned = _graph->nodes()->create<luci::CircleMaxPool2D>(); + if (cloned != nullptr) + { + cloned->fusedActivationFunction(node->fusedActivationFunction()); + cloned->padding(node->padding()); + cloned->filter()->h(node->filter()->h()); + cloned->filter()->w(node->filter()->w()); + cloned->stride()->h(node->stride()->h()); + cloned->stride()->w(node->stride()->w()); + } + return cloned; +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleMaxPool2D.test.cpp b/compiler/luci/service/src/Nodes/CircleMaxPool2D.test.cpp new file mode 100644 index 000000000..415cf7c44 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleMaxPool2D.test.cpp @@ -0,0 +1,69 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include <gtest/gtest.h> + +TEST(CloneNodeTest, clone_MaxPool2D) +{ + auto g = loco::make_graph(); + auto node_mp = g->nodes()->create<luci::CircleMaxPool2D>(); + node_mp->fusedActivationFunction(luci::FusedActFunc::RELU); + node_mp->padding(luci::Padding::SAME); + node_mp->filter()->h(1); + node_mp->filter()->w(2); + node_mp->stride()->h(3); + node_mp->stride()->w(4); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_mp, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_mp = dynamic_cast<luci::CircleMaxPool2D *>(cloned); + ASSERT_NE(nullptr, cloned_mp); + ASSERT_EQ(node_mp->fusedActivationFunction(), cloned_mp->fusedActivationFunction()); + ASSERT_EQ(node_mp->padding(), cloned_mp->padding()); + ASSERT_EQ(node_mp->filter()->h(), cloned_mp->filter()->h()); + ASSERT_EQ(node_mp->filter()->w(), cloned_mp->filter()->w()); + ASSERT_EQ(node_mp->stride()->h(), cloned_mp->stride()->h()); + ASSERT_EQ(node_mp->stride()->w(), cloned_mp->stride()->w()); +} + +TEST(CloneNodeTest, clone_MaxPool2D_fusedact_NEG) +{ + auto g = loco::make_graph(); + auto node_mp = g->nodes()->create<luci::CircleMaxPool2D>(); + node_mp->fusedActivationFunction(luci::FusedActFunc::UNDEFINED); + node_mp->padding(luci::Padding::SAME); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_mp, gc.get()); + ASSERT_EQ(nullptr, cloned); +} + +TEST(CloneNodeTest, clone_MaxPool2D_padding_NEG) +{ + auto g = loco::make_graph(); + auto node_mp = g->nodes()->create<luci::CircleMaxPool2D>(); + node_mp->fusedActivationFunction(luci::FusedActFunc::RELU); + node_mp->padding(luci::Padding::UNDEFINED); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_mp, gc.get()); + ASSERT_EQ(nullptr, cloned); +} diff --git a/compiler/luci/service/src/Nodes/CircleMaximum.cpp b/compiler/luci/service/src/Nodes/CircleMaximum.cpp new file mode 100644 index 000000000..545f4ca21 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleMaximum.cpp @@ -0,0 +1,27 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleCloneNode.h" + +namespace luci +{ + +luci::CircleNode *CloneNode::visit(const luci::CircleMaximum *) +{ + return _graph->nodes()->create<luci::CircleMaximum>(); +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleMaximum.test.cpp b/compiler/luci/service/src/Nodes/CircleMaximum.test.cpp new file mode 100644 index 000000000..6f1ada060 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleMaximum.test.cpp @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include <gtest/gtest.h> + +TEST(CloneNodeTest, clone_Maximum) +{ + auto g = loco::make_graph(); + auto node_max = g->nodes()->create<luci::CircleMaximum>(); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_max, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_max = dynamic_cast<luci::CircleMaximum *>(cloned); + ASSERT_NE(nullptr, cloned_max); +} diff --git a/compiler/luci/service/src/Nodes/CircleMean.cpp b/compiler/luci/service/src/Nodes/CircleMean.cpp index a78713698..95bc54532 100644 --- a/compiler/luci/service/src/Nodes/CircleMean.cpp +++ b/compiler/luci/service/src/Nodes/CircleMean.cpp @@ -1,11 +1,11 @@ /* - * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -14,15 +14,17 @@ * limitations under the License. */ -#include <luci/Service/CircleShapeSignatureInference.h> +#include "CircleCloneNode.h" namespace luci { -ShapeSignature ssinf::Algorithm::visit(const luci::CircleMean *node) +luci::CircleNode *CloneNode::visit(const luci::CircleMean *node) { - return legalized_signature( - reduced_signature(node->input(), node->reduction_indices(), node->keep_dims())); + auto *cloned = _graph->nodes()->create<luci::CircleMean>(); + if (cloned != nullptr) + cloned->keep_dims(node->keep_dims()); + return cloned; } } // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleMean.test.cpp b/compiler/luci/service/src/Nodes/CircleMean.test.cpp new file mode 100644 index 000000000..aa1b88f13 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleMean.test.cpp @@ -0,0 +1,35 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include <gtest/gtest.h> + +TEST(CloneNodeTest, clone_Mean) +{ + auto g = loco::make_graph(); + auto node_mean = g->nodes()->create<luci::CircleMean>(); + node_mean->keep_dims(true); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_mean, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_mean = dynamic_cast<luci::CircleMean *>(cloned); + ASSERT_NE(nullptr, cloned_mean); + ASSERT_EQ(node_mean->keep_dims(), cloned_mean->keep_dims()); +} diff --git a/compiler/luci/service/src/Nodes/CircleMinimum.cpp b/compiler/luci/service/src/Nodes/CircleMinimum.cpp new file mode 100644 index 000000000..2c2755c55 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleMinimum.cpp @@ -0,0 +1,27 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleCloneNode.h" + +namespace luci +{ + +luci::CircleNode *CloneNode::visit(const luci::CircleMinimum *) +{ + return _graph->nodes()->create<luci::CircleMinimum>(); +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleMinimum.test.cpp b/compiler/luci/service/src/Nodes/CircleMinimum.test.cpp new file mode 100644 index 000000000..0a54be71c --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleMinimum.test.cpp @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include <gtest/gtest.h> + +TEST(CloneNodeTest, clone_Minimum) +{ + auto g = loco::make_graph(); + auto node_min = g->nodes()->create<luci::CircleMinimum>(); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_min, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_min = dynamic_cast<luci::CircleMinimum *>(cloned); + ASSERT_NE(nullptr, cloned_min); +} diff --git a/compiler/luci/service/src/Nodes/CircleMirrorPad.cpp b/compiler/luci/service/src/Nodes/CircleMirrorPad.cpp new file mode 100644 index 000000000..919221a0b --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleMirrorPad.cpp @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleCloneNode.h" + +namespace luci +{ + +luci::CircleNode *CloneNode::visit(const luci::CircleMirrorPad *node) +{ + if (node->mode() == luci::MirrorPadMode::UNDEFINED) + return nullptr; + + auto *cloned = _graph->nodes()->create<luci::CircleMirrorPad>(); + if (cloned != nullptr) + cloned->mode(node->mode()); + return cloned; +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleMirrorPad.test.cpp b/compiler/luci/service/src/Nodes/CircleMirrorPad.test.cpp new file mode 100644 index 000000000..911cf6d3b --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleMirrorPad.test.cpp @@ -0,0 +1,46 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include <gtest/gtest.h> + +TEST(CloneNodeTest, clone_MirrorPad) +{ + auto g = loco::make_graph(); + auto node_mp = g->nodes()->create<luci::CircleMirrorPad>(); + node_mp->mode(luci::MirrorPadMode::REFLECT); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_mp, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_mp = dynamic_cast<luci::CircleMirrorPad *>(cloned); + ASSERT_NE(nullptr, cloned_mp); + ASSERT_EQ(node_mp->mode(), cloned_mp->mode()); +} + +TEST(CloneNodeTest, clone_MirrorPad_mode_NEG) +{ + auto g = loco::make_graph(); + auto node_mp = g->nodes()->create<luci::CircleMirrorPad>(); + node_mp->mode(luci::MirrorPadMode::UNDEFINED); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_mp, gc.get()); + ASSERT_EQ(nullptr, cloned); +} diff --git a/compiler/luci/service/src/Nodes/CircleMul.cpp b/compiler/luci/service/src/Nodes/CircleMul.cpp new file mode 100644 index 000000000..096aed196 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleMul.cpp @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleCloneNode.h" + +namespace luci +{ + +luci::CircleNode *CloneNode::visit(const luci::CircleMul *node) +{ + if (node->fusedActivationFunction() == luci::FusedActFunc::UNDEFINED) + return nullptr; + + auto *cloned = _graph->nodes()->create<luci::CircleMul>(); + if (cloned != nullptr) + cloned->fusedActivationFunction(node->fusedActivationFunction()); + return cloned; +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleMul.test.cpp b/compiler/luci/service/src/Nodes/CircleMul.test.cpp new file mode 100644 index 000000000..dc5565f11 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleMul.test.cpp @@ -0,0 +1,46 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include <gtest/gtest.h> + +TEST(CloneNodeTest, clone_Mul) +{ + auto g = loco::make_graph(); + auto node_mul = g->nodes()->create<luci::CircleMul>(); + node_mul->fusedActivationFunction(luci::FusedActFunc::RELU); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_mul, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_mul = dynamic_cast<luci::CircleMul *>(cloned); + ASSERT_NE(nullptr, cloned_mul); + ASSERT_EQ(node_mul->fusedActivationFunction(), cloned_mul->fusedActivationFunction()); +} + +TEST(CloneNodeTest, clone_Mul_NEG) +{ + auto g = loco::make_graph(); + auto node_mul = g->nodes()->create<luci::CircleMul>(); + node_mul->fusedActivationFunction(luci::FusedActFunc::UNDEFINED); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_mul, gc.get()); + ASSERT_EQ(nullptr, cloned); +} diff --git a/compiler/luci/service/src/Nodes/CircleNeg.cpp b/compiler/luci/service/src/Nodes/CircleNeg.cpp new file mode 100644 index 000000000..312189e77 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleNeg.cpp @@ -0,0 +1,27 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleCloneNode.h" + +namespace luci +{ + +luci::CircleNode *CloneNode::visit(const luci::CircleNeg *) +{ + return _graph->nodes()->create<luci::CircleNeg>(); +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleNeg.test.cpp b/compiler/luci/service/src/Nodes/CircleNeg.test.cpp new file mode 100644 index 000000000..8c2880324 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleNeg.test.cpp @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include <gtest/gtest.h> + +TEST(CloneNodeTest, clone_Neg) +{ + auto g = loco::make_graph(); + auto node_neg = g->nodes()->create<luci::CircleNeg>(); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_neg, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_neg = dynamic_cast<luci::CircleNeg *>(cloned); + ASSERT_NE(nullptr, cloned_neg); +} diff --git a/compiler/luci/service/src/Nodes/CircleNonMaxSuppressionV4.cpp b/compiler/luci/service/src/Nodes/CircleNonMaxSuppressionV4.cpp new file mode 100644 index 000000000..4757e8314 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleNonMaxSuppressionV4.cpp @@ -0,0 +1,27 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleCloneNode.h" + +namespace luci +{ + +luci::CircleNode *CloneNode::visit(const luci::CircleNonMaxSuppressionV4 *) +{ + return _graph->nodes()->create<luci::CircleNonMaxSuppressionV4>(); +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleNonMaxSuppressionV4.test.cpp b/compiler/luci/service/src/Nodes/CircleNonMaxSuppressionV4.test.cpp new file mode 100644 index 000000000..34f5b0325 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleNonMaxSuppressionV4.test.cpp @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include <gtest/gtest.h> + +TEST(CloneNodeTest, clone_NonMaxSuppressionV4) +{ + auto g = loco::make_graph(); + auto node_nms = g->nodes()->create<luci::CircleNonMaxSuppressionV4>(); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_nms, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_nms = dynamic_cast<luci::CircleNonMaxSuppressionV4 *>(cloned); + ASSERT_NE(nullptr, cloned_nms); +} diff --git a/compiler/luci/service/src/Nodes/CircleNonMaxSuppressionV4Out.cpp b/compiler/luci/service/src/Nodes/CircleNonMaxSuppressionV4Out.cpp new file mode 100644 index 000000000..2a12f2a45 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleNonMaxSuppressionV4Out.cpp @@ -0,0 +1,30 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleCloneNode.h" + +namespace luci +{ + +luci::CircleNode *CloneNode::visit(const luci::CircleNonMaxSuppressionV4Out *node) +{ + auto *cloned = _graph->nodes()->create<luci::CircleNonMaxSuppressionV4Out>(); + if (cloned != nullptr) + cloned->index(node->index()); + return cloned; +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleNonMaxSuppressionV4Out.test.cpp b/compiler/luci/service/src/Nodes/CircleNonMaxSuppressionV4Out.test.cpp new file mode 100644 index 000000000..ed9e0e019 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleNonMaxSuppressionV4Out.test.cpp @@ -0,0 +1,35 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include <gtest/gtest.h> + +TEST(CloneNodeTest, clone_NonMaxSuppressionV4Out) +{ + auto g = loco::make_graph(); + auto node_nout = g->nodes()->create<luci::CircleNonMaxSuppressionV4Out>(); + node_nout->index(1); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_nout, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_nout = dynamic_cast<luci::CircleNonMaxSuppressionV4Out *>(cloned); + ASSERT_NE(nullptr, cloned_nout); + ASSERT_EQ(node_nout->index(), cloned_nout->index()); +} diff --git a/compiler/luci/service/src/Nodes/CircleNonMaxSuppressionV5.cpp b/compiler/luci/service/src/Nodes/CircleNonMaxSuppressionV5.cpp new file mode 100644 index 000000000..34d128072 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleNonMaxSuppressionV5.cpp @@ -0,0 +1,27 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleCloneNode.h" + +namespace luci +{ + +luci::CircleNode *CloneNode::visit(const luci::CircleNonMaxSuppressionV5 *) +{ + return _graph->nodes()->create<luci::CircleNonMaxSuppressionV5>(); +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleNonMaxSuppressionV5.test.cpp b/compiler/luci/service/src/Nodes/CircleNonMaxSuppressionV5.test.cpp new file mode 100644 index 000000000..faaee969e --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleNonMaxSuppressionV5.test.cpp @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include <gtest/gtest.h> + +TEST(CloneNodeTest, clone_NonMaxSuppressionV5) +{ + auto g = loco::make_graph(); + auto node_nms = g->nodes()->create<luci::CircleNonMaxSuppressionV5>(); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_nms, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_nms = dynamic_cast<luci::CircleNonMaxSuppressionV5 *>(cloned); + ASSERT_NE(nullptr, cloned_nms); +} diff --git a/compiler/luci/service/src/Nodes/CircleNonMaxSuppressionV5Out.cpp b/compiler/luci/service/src/Nodes/CircleNonMaxSuppressionV5Out.cpp new file mode 100644 index 000000000..e1d7875e7 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleNonMaxSuppressionV5Out.cpp @@ -0,0 +1,30 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleCloneNode.h" + +namespace luci +{ + +luci::CircleNode *CloneNode::visit(const luci::CircleNonMaxSuppressionV5Out *node) +{ + auto *cloned = _graph->nodes()->create<luci::CircleNonMaxSuppressionV5Out>(); + if (cloned != nullptr) + cloned->index(node->index()); + return cloned; +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleNonMaxSuppressionV5Out.test.cpp b/compiler/luci/service/src/Nodes/CircleNonMaxSuppressionV5Out.test.cpp new file mode 100644 index 000000000..ef0f766b9 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleNonMaxSuppressionV5Out.test.cpp @@ -0,0 +1,35 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include <gtest/gtest.h> + +TEST(CloneNodeTest, clone_NonMaxSuppressionV5Out) +{ + auto g = loco::make_graph(); + auto node_nout = g->nodes()->create<luci::CircleNonMaxSuppressionV5Out>(); + node_nout->index(1); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_nout, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_nout = dynamic_cast<luci::CircleNonMaxSuppressionV5Out *>(cloned); + ASSERT_NE(nullptr, cloned_nout); + ASSERT_EQ(node_nout->index(), cloned_nout->index()); +} diff --git a/compiler/luci/service/src/Nodes/CircleNotEqual.cpp b/compiler/luci/service/src/Nodes/CircleNotEqual.cpp new file mode 100644 index 000000000..4cb5320e8 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleNotEqual.cpp @@ -0,0 +1,27 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleCloneNode.h" + +namespace luci +{ + +luci::CircleNode *CloneNode::visit(const luci::CircleNotEqual *) +{ + return _graph->nodes()->create<luci::CircleNotEqual>(); +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleNotEqual.test.cpp b/compiler/luci/service/src/Nodes/CircleNotEqual.test.cpp new file mode 100644 index 000000000..20f7dbc4b --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleNotEqual.test.cpp @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include <gtest/gtest.h> + +TEST(CloneNodeTest, clone_NotEqual) +{ + auto g = loco::make_graph(); + auto node_ne = g->nodes()->create<luci::CircleNotEqual>(); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_ne, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_ne = dynamic_cast<luci::CircleNotEqual *>(cloned); + ASSERT_NE(nullptr, cloned_ne); +} diff --git a/compiler/luci/service/src/Nodes/CircleOneHot.cpp b/compiler/luci/service/src/Nodes/CircleOneHot.cpp new file mode 100644 index 000000000..a33c8ff26 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleOneHot.cpp @@ -0,0 +1,30 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleCloneNode.h" + +namespace luci +{ + +luci::CircleNode *CloneNode::visit(const luci::CircleOneHot *node) +{ + auto *cloned = _graph->nodes()->create<luci::CircleOneHot>(); + if (cloned != nullptr) + cloned->axis(node->axis()); + return cloned; +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleOneHot.test.cpp b/compiler/luci/service/src/Nodes/CircleOneHot.test.cpp new file mode 100644 index 000000000..dea927d1b --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleOneHot.test.cpp @@ -0,0 +1,35 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include <gtest/gtest.h> + +TEST(CloneNodeTest, clone_OneHot) +{ + auto g = loco::make_graph(); + auto node_oh = g->nodes()->create<luci::CircleOneHot>(); + node_oh->axis(3); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_oh, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_oh = dynamic_cast<luci::CircleOneHot *>(cloned); + ASSERT_NE(nullptr, cloned_oh); + ASSERT_EQ(node_oh->axis(), cloned_oh->axis()); +} diff --git a/compiler/luci/service/src/Nodes/CircleOutputDummy.cpp b/compiler/luci/service/src/Nodes/CircleOutputDummy.cpp index e0f13c439..ce94dff94 100644 --- a/compiler/luci/service/src/Nodes/CircleOutputDummy.cpp +++ b/compiler/luci/service/src/Nodes/CircleOutputDummy.cpp @@ -1,11 +1,11 @@ /* - * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -14,11 +14,14 @@ * limitations under the License. */ -#include <luci/Service/CircleShapeSignatureInference.h> +#include "CircleCloneNode.h" namespace luci { -ShapeSignature ssinf::Algorithm::visit(const luci::CircleOutputDummy *) { return ShapeSignature(); } +luci::CircleNode *CloneNode::visit(const luci::CircleOutputDummy *) +{ + return _graph->nodes()->create<luci::CircleOutputDummy>(); +} } // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleOutputDummy.test.cpp b/compiler/luci/service/src/Nodes/CircleOutputDummy.test.cpp new file mode 100644 index 000000000..6170c7c41 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleOutputDummy.test.cpp @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include <gtest/gtest.h> + +TEST(CloneNodeTest, clone_OutputDummy) +{ + auto g = loco::make_graph(); + auto node_dummy = g->nodes()->create<luci::CircleOutputDummy>(); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_dummy, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_dummy = dynamic_cast<luci::CircleOutputDummy *>(cloned); + ASSERT_NE(nullptr, cloned_dummy); +} diff --git a/compiler/luci/service/src/Nodes/CircleOutputExclude.cpp b/compiler/luci/service/src/Nodes/CircleOutputExclude.cpp index 75bbbb3c0..1b0f919c3 100644 --- a/compiler/luci/service/src/Nodes/CircleOutputExclude.cpp +++ b/compiler/luci/service/src/Nodes/CircleOutputExclude.cpp @@ -1,11 +1,11 @@ /* - * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -14,14 +14,14 @@ * limitations under the License. */ -#include <luci/Service/CircleShapeSignatureInference.h> +#include "CircleCloneNode.h" namespace luci { -ShapeSignature ssinf::Algorithm::visit(const luci::CircleOutputExclude *) +luci::CircleNode *CloneNode::visit(const luci::CircleOutputExclude *) { - return ShapeSignature(); + return _graph->nodes()->create<luci::CircleOutputExclude>(); } } // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleOutputExclude.test.cpp b/compiler/luci/service/src/Nodes/CircleOutputExclude.test.cpp new file mode 100644 index 000000000..120ffe86b --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleOutputExclude.test.cpp @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include <gtest/gtest.h> + +TEST(CloneNodeTest, clone_OutputExclude) +{ + auto g = loco::make_graph(); + auto node_outex = g->nodes()->create<luci::CircleOutputExclude>(); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_outex, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_outex = dynamic_cast<luci::CircleOutputExclude *>(cloned); + ASSERT_NE(nullptr, cloned_outex); +} diff --git a/compiler/luci/service/src/Nodes/CirclePRelu.cpp b/compiler/luci/service/src/Nodes/CirclePRelu.cpp new file mode 100644 index 000000000..8a34e507e --- /dev/null +++ b/compiler/luci/service/src/Nodes/CirclePRelu.cpp @@ -0,0 +1,27 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleCloneNode.h" + +namespace luci +{ + +luci::CircleNode *CloneNode::visit(const luci::CirclePRelu *) +{ + return _graph->nodes()->create<luci::CirclePRelu>(); +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CirclePRelu.test.cpp b/compiler/luci/service/src/Nodes/CirclePRelu.test.cpp new file mode 100644 index 000000000..1150e3fa4 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CirclePRelu.test.cpp @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include <gtest/gtest.h> + +TEST(CloneNodeTest, clone_PRelu) +{ + auto g = loco::make_graph(); + auto node_pr = g->nodes()->create<luci::CirclePRelu>(); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_pr, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_pr = dynamic_cast<luci::CirclePRelu *>(cloned); + ASSERT_NE(nullptr, cloned_pr); +} diff --git a/compiler/luci/service/src/Nodes/CirclePack.cpp b/compiler/luci/service/src/Nodes/CirclePack.cpp new file mode 100644 index 000000000..a3cee0bfd --- /dev/null +++ b/compiler/luci/service/src/Nodes/CirclePack.cpp @@ -0,0 +1,30 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleCloneNode.h" + +namespace luci +{ + +luci::CircleNode *CloneNode::visit(const luci::CirclePack *node) +{ + auto *cloned = _graph->nodes()->create<luci::CirclePack>(node->values_count()); + if (cloned != nullptr) + cloned->axis(node->axis()); + return cloned; +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CirclePack.test.cpp b/compiler/luci/service/src/Nodes/CirclePack.test.cpp new file mode 100644 index 000000000..b808956dc --- /dev/null +++ b/compiler/luci/service/src/Nodes/CirclePack.test.cpp @@ -0,0 +1,36 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include <gtest/gtest.h> + +TEST(CloneNodeTest, clone_Pack) +{ + auto g = loco::make_graph(); + auto node_pack = g->nodes()->create<luci::CirclePack>(3); + node_pack->axis(7); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_pack, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_pack = dynamic_cast<luci::CirclePack *>(cloned); + ASSERT_NE(nullptr, cloned_pack); + ASSERT_EQ(node_pack->values_count(), cloned_pack->values_count()); + ASSERT_EQ(node_pack->axis(), cloned_pack->axis()); +} diff --git a/compiler/luci/service/src/Nodes/CirclePad.cpp b/compiler/luci/service/src/Nodes/CirclePad.cpp new file mode 100644 index 000000000..425bdce4d --- /dev/null +++ b/compiler/luci/service/src/Nodes/CirclePad.cpp @@ -0,0 +1,27 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleCloneNode.h" + +namespace luci +{ + +luci::CircleNode *CloneNode::visit(const luci::CirclePad *) +{ + return _graph->nodes()->create<luci::CirclePad>(); +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CirclePad.test.cpp b/compiler/luci/service/src/Nodes/CirclePad.test.cpp new file mode 100644 index 000000000..1d5f8375e --- /dev/null +++ b/compiler/luci/service/src/Nodes/CirclePad.test.cpp @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include <gtest/gtest.h> + +TEST(CloneNodeTest, clone_Pad) +{ + auto g = loco::make_graph(); + auto node_pad = g->nodes()->create<luci::CirclePad>(); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_pad, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_pad = dynamic_cast<luci::CirclePad *>(cloned); + ASSERT_NE(nullptr, cloned_pad); +} diff --git a/compiler/luci/service/src/Nodes/CirclePadV2.cpp b/compiler/luci/service/src/Nodes/CirclePadV2.cpp new file mode 100644 index 000000000..0e93869b6 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CirclePadV2.cpp @@ -0,0 +1,27 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleCloneNode.h" + +namespace luci +{ + +luci::CircleNode *CloneNode::visit(const luci::CirclePadV2 *) +{ + return _graph->nodes()->create<luci::CirclePadV2>(); +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CirclePadV2.test.cpp b/compiler/luci/service/src/Nodes/CirclePadV2.test.cpp new file mode 100644 index 000000000..d011f69f8 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CirclePadV2.test.cpp @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include <gtest/gtest.h> + +TEST(CloneNodeTest, clone_PadV2) +{ + auto g = loco::make_graph(); + auto node_pad = g->nodes()->create<luci::CirclePadV2>(); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_pad, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_pad = dynamic_cast<luci::CirclePadV2 *>(cloned); + ASSERT_NE(nullptr, cloned_pad); +} diff --git a/compiler/luci/service/src/Nodes/CirclePow.cpp b/compiler/luci/service/src/Nodes/CirclePow.cpp new file mode 100644 index 000000000..bf9388913 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CirclePow.cpp @@ -0,0 +1,27 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleCloneNode.h" + +namespace luci +{ + +luci::CircleNode *CloneNode::visit(const luci::CirclePow *) +{ + return _graph->nodes()->create<luci::CirclePow>(); +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CirclePow.test.cpp b/compiler/luci/service/src/Nodes/CirclePow.test.cpp new file mode 100644 index 000000000..946298932 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CirclePow.test.cpp @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include <gtest/gtest.h> + +TEST(CloneNodeTest, clone_Pow) +{ + auto g = loco::make_graph(); + auto node_pow = g->nodes()->create<luci::CirclePow>(); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_pow, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_pow = dynamic_cast<luci::CirclePow *>(cloned); + ASSERT_NE(nullptr, cloned_pow); +} diff --git a/compiler/luci/service/src/Nodes/CircleRange.cpp b/compiler/luci/service/src/Nodes/CircleRange.cpp new file mode 100644 index 000000000..9c6f7b494 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleRange.cpp @@ -0,0 +1,27 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleCloneNode.h" + +namespace luci +{ + +luci::CircleNode *CloneNode::visit(const luci::CircleRange *) +{ + return _graph->nodes()->create<luci::CircleRange>(); +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleRange.test.cpp b/compiler/luci/service/src/Nodes/CircleRange.test.cpp new file mode 100644 index 000000000..b2fb29617 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleRange.test.cpp @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include <gtest/gtest.h> + +TEST(CloneNodeTest, clone_Range) +{ + auto g = loco::make_graph(); + auto node_range = g->nodes()->create<luci::CircleRange>(); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_range, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_range = dynamic_cast<luci::CircleRange *>(cloned); + ASSERT_NE(nullptr, cloned_range); +} diff --git a/compiler/luci/service/src/Nodes/CircleRank.cpp b/compiler/luci/service/src/Nodes/CircleRank.cpp new file mode 100644 index 000000000..db8171c51 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleRank.cpp @@ -0,0 +1,27 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleCloneNode.h" + +namespace luci +{ + +luci::CircleNode *CloneNode::visit(const luci::CircleRank *) +{ + return _graph->nodes()->create<luci::CircleRank>(); +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleRank.test.cpp b/compiler/luci/service/src/Nodes/CircleRank.test.cpp new file mode 100644 index 000000000..0e81fb254 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleRank.test.cpp @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include <gtest/gtest.h> + +TEST(CloneNodeTest, clone_Rank) +{ + auto g = loco::make_graph(); + auto node_rank = g->nodes()->create<luci::CircleRank>(); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_rank, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_rank = dynamic_cast<luci::CircleRank *>(cloned); + ASSERT_NE(nullptr, cloned_rank); +} diff --git a/compiler/luci/service/src/Nodes/CircleReduceAny.cpp b/compiler/luci/service/src/Nodes/CircleReduceAny.cpp index 27da81466..3ab0b3b59 100644 --- a/compiler/luci/service/src/Nodes/CircleReduceAny.cpp +++ b/compiler/luci/service/src/Nodes/CircleReduceAny.cpp @@ -1,11 +1,11 @@ /* - * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -14,15 +14,17 @@ * limitations under the License. */ -#include <luci/Service/CircleShapeSignatureInference.h> +#include "CircleCloneNode.h" namespace luci { -ShapeSignature ssinf::Algorithm::visit(const luci::CircleReduceAny *node) +luci::CircleNode *CloneNode::visit(const luci::CircleReduceAny *node) { - return legalized_signature( - reduced_signature(node->input(), node->reduction_indices(), node->keep_dims())); + auto *cloned = _graph->nodes()->create<luci::CircleReduceAny>(); + if (cloned != nullptr) + cloned->keep_dims(node->keep_dims()); + return cloned; } } // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleReduceAny.test.cpp b/compiler/luci/service/src/Nodes/CircleReduceAny.test.cpp new file mode 100644 index 000000000..904b5a139 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleReduceAny.test.cpp @@ -0,0 +1,35 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include <gtest/gtest.h> + +TEST(CloneNodeTest, clone_ReduceAny) +{ + auto g = loco::make_graph(); + auto node_ra = g->nodes()->create<luci::CircleReduceAny>(); + node_ra->keep_dims(true); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_ra, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_ra = dynamic_cast<luci::CircleReduceAny *>(cloned); + ASSERT_NE(nullptr, cloned_ra); + ASSERT_EQ(node_ra->keep_dims(), cloned_ra->keep_dims()); +} diff --git a/compiler/luci/service/src/Nodes/CircleReduceMax.cpp b/compiler/luci/service/src/Nodes/CircleReduceMax.cpp index 48d9cb970..c026905ca 100644 --- a/compiler/luci/service/src/Nodes/CircleReduceMax.cpp +++ b/compiler/luci/service/src/Nodes/CircleReduceMax.cpp @@ -1,11 +1,11 @@ /* - * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -14,15 +14,17 @@ * limitations under the License. */ -#include <luci/Service/CircleShapeSignatureInference.h> +#include "CircleCloneNode.h" namespace luci { -ShapeSignature ssinf::Algorithm::visit(const luci::CircleReduceMax *node) +luci::CircleNode *CloneNode::visit(const luci::CircleReduceMax *node) { - return legalized_signature( - reduced_signature(node->input(), node->reduction_indices(), node->keep_dims())); + auto *cloned = _graph->nodes()->create<luci::CircleReduceMax>(); + if (cloned != nullptr) + cloned->keep_dims(node->keep_dims()); + return cloned; } } // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleReduceMax.test.cpp b/compiler/luci/service/src/Nodes/CircleReduceMax.test.cpp new file mode 100644 index 000000000..b3f3c881e --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleReduceMax.test.cpp @@ -0,0 +1,35 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include <gtest/gtest.h> + +TEST(CloneNodeTest, clone_ReduceMax) +{ + auto g = loco::make_graph(); + auto node_rmax = g->nodes()->create<luci::CircleReduceMax>(); + node_rmax->keep_dims(true); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_rmax, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_rmax = dynamic_cast<luci::CircleReduceMax *>(cloned); + ASSERT_NE(nullptr, cloned_rmax); + ASSERT_EQ(node_rmax->keep_dims(), cloned_rmax->keep_dims()); +} diff --git a/compiler/luci/service/src/Nodes/CircleReduceMin.cpp b/compiler/luci/service/src/Nodes/CircleReduceMin.cpp index 9a9997118..3dfa19680 100644 --- a/compiler/luci/service/src/Nodes/CircleReduceMin.cpp +++ b/compiler/luci/service/src/Nodes/CircleReduceMin.cpp @@ -1,11 +1,11 @@ /* - * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -14,15 +14,17 @@ * limitations under the License. */ -#include <luci/Service/CircleShapeSignatureInference.h> +#include "CircleCloneNode.h" namespace luci { -ShapeSignature ssinf::Algorithm::visit(const luci::CircleReduceMin *node) +luci::CircleNode *CloneNode::visit(const luci::CircleReduceMin *node) { - return legalized_signature( - reduced_signature(node->input(), node->reduction_indices(), node->keep_dims())); + auto *cloned = _graph->nodes()->create<luci::CircleReduceMin>(); + if (cloned != nullptr) + cloned->keep_dims(node->keep_dims()); + return cloned; } } // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleReduceMin.test.cpp b/compiler/luci/service/src/Nodes/CircleReduceMin.test.cpp new file mode 100644 index 000000000..b3faa68da --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleReduceMin.test.cpp @@ -0,0 +1,35 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include <gtest/gtest.h> + +TEST(CloneNodeTest, clone_ReduceMin) +{ + auto g = loco::make_graph(); + auto node_rmin = g->nodes()->create<luci::CircleReduceMin>(); + node_rmin->keep_dims(true); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_rmin, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_rmin = dynamic_cast<luci::CircleReduceMin *>(cloned); + ASSERT_NE(nullptr, cloned_rmin); + ASSERT_EQ(node_rmin->keep_dims(), cloned_rmin->keep_dims()); +} diff --git a/compiler/luci/service/src/Nodes/CircleReduceProd.cpp b/compiler/luci/service/src/Nodes/CircleReduceProd.cpp index a9d381a74..418a8ce32 100644 --- a/compiler/luci/service/src/Nodes/CircleReduceProd.cpp +++ b/compiler/luci/service/src/Nodes/CircleReduceProd.cpp @@ -1,11 +1,11 @@ /* - * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -14,15 +14,17 @@ * limitations under the License. */ -#include <luci/Service/CircleShapeSignatureInference.h> +#include "CircleCloneNode.h" namespace luci { -ShapeSignature ssinf::Algorithm::visit(const luci::CircleReduceProd *node) +luci::CircleNode *CloneNode::visit(const luci::CircleReduceProd *node) { - return legalized_signature( - reduced_signature(node->input(), node->reduction_indices(), node->keep_dims())); + auto *cloned = _graph->nodes()->create<luci::CircleReduceProd>(); + if (cloned != nullptr) + cloned->keep_dims(node->keep_dims()); + return cloned; } } // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleReduceProd.test.cpp b/compiler/luci/service/src/Nodes/CircleReduceProd.test.cpp new file mode 100644 index 000000000..8caf8e91f --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleReduceProd.test.cpp @@ -0,0 +1,35 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include <gtest/gtest.h> + +TEST(CloneNodeTest, clone_ReduceProd) +{ + auto g = loco::make_graph(); + auto node_rp = g->nodes()->create<luci::CircleReduceProd>(); + node_rp->keep_dims(true); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_rp, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_rp = dynamic_cast<luci::CircleReduceProd *>(cloned); + ASSERT_NE(nullptr, cloned_rp); + ASSERT_EQ(node_rp->keep_dims(), cloned_rp->keep_dims()); +} diff --git a/compiler/luci/service/src/Nodes/CircleRelu.cpp b/compiler/luci/service/src/Nodes/CircleRelu.cpp index a7a7f6f0a..7447eea0c 100644 --- a/compiler/luci/service/src/Nodes/CircleRelu.cpp +++ b/compiler/luci/service/src/Nodes/CircleRelu.cpp @@ -1,11 +1,11 @@ /* - * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -14,14 +14,14 @@ * limitations under the License. */ -#include <luci/Service/CircleShapeSignatureInference.h> +#include "CircleCloneNode.h" namespace luci { -ShapeSignature ssinf::Algorithm::visit(const luci::CircleRelu *node) +luci::CircleNode *CloneNode::visit(const luci::CircleRelu *) { - return input_arg_signature(node, 0); + return _graph->nodes()->create<luci::CircleRelu>(); } } // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleRelu.test.cpp b/compiler/luci/service/src/Nodes/CircleRelu.test.cpp new file mode 100644 index 000000000..6154376ba --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleRelu.test.cpp @@ -0,0 +1,74 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include <luci/IR/CircleNodes.h> +#include <luci/Service/CircleShapeInference.h> +#include <luci/Service/CircleTypeInference.h> + +#include <loco/IR/TensorShape.h> + +#include <gtest/gtest.h> + +TEST(ShapeRuleTest, simple_relu) +{ + luci::CircleInput input; + luci::CircleRelu relu; + + input.shape({3, 4}); + input.shape_status(luci::ShapeStatus::VALID); + + relu.features(&input); + + loco::TensorShape shape; + luci::sinf::Rule shape_inf_rule; + + ASSERT_TRUE(shape_inf_rule.infer(&relu, shape)); + ASSERT_EQ(2, shape.rank()); + ASSERT_EQ(3, shape.dim(0).value()); + ASSERT_EQ(4, shape.dim(1).value()); +} + +TEST(DataTypeRuleTest, simple_relu) +{ + luci::CircleInput input; + luci::CircleRelu relu; + + input.dtype(loco::DataType::S32); + + relu.features(&input); + + loco::DataType dtype; + luci::tinf::Rule type_inf_rule; + + ASSERT_TRUE(type_inf_rule.infer(&relu, dtype)); + ASSERT_EQ(loco::DataType::S32, dtype); +} + +TEST(CloneNodeTest, clone_Relu) +{ + auto g = loco::make_graph(); + auto node_relu = g->nodes()->create<luci::CircleRelu>(); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_relu, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_relu = dynamic_cast<luci::CircleRelu *>(cloned); + ASSERT_NE(nullptr, cloned_relu); +} diff --git a/compiler/luci/service/src/Nodes/CircleRelu6.cpp b/compiler/luci/service/src/Nodes/CircleRelu6.cpp index 92a596d08..7b98311ed 100644 --- a/compiler/luci/service/src/Nodes/CircleRelu6.cpp +++ b/compiler/luci/service/src/Nodes/CircleRelu6.cpp @@ -1,11 +1,11 @@ /* - * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -14,14 +14,14 @@ * limitations under the License. */ -#include <luci/Service/CircleShapeSignatureInference.h> +#include "CircleCloneNode.h" namespace luci { -ShapeSignature ssinf::Algorithm::visit(const luci::CircleRelu6 *node) +luci::CircleNode *CloneNode::visit(const luci::CircleRelu6 *) { - return input_arg_signature(node, 0); + return _graph->nodes()->create<luci::CircleRelu6>(); } } // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleRelu6.test.cpp b/compiler/luci/service/src/Nodes/CircleRelu6.test.cpp new file mode 100644 index 000000000..213dbcb09 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleRelu6.test.cpp @@ -0,0 +1,35 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include <loco/IR/TensorShape.h> + +#include <gtest/gtest.h> + +TEST(CloneNodeTest, clone_Relu6) +{ + auto g = loco::make_graph(); + auto node_relu6 = g->nodes()->create<luci::CircleRelu6>(); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_relu6, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_relu6 = dynamic_cast<luci::CircleRelu6 *>(cloned); + ASSERT_NE(nullptr, cloned_relu6); +} diff --git a/compiler/luci/service/src/Nodes/CircleReluN1To1.cpp b/compiler/luci/service/src/Nodes/CircleReluN1To1.cpp index 1e8d9971d..4efedb9fc 100644 --- a/compiler/luci/service/src/Nodes/CircleReluN1To1.cpp +++ b/compiler/luci/service/src/Nodes/CircleReluN1To1.cpp @@ -1,11 +1,11 @@ /* - * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -14,14 +14,14 @@ * limitations under the License. */ -#include <luci/Service/CircleShapeSignatureInference.h> +#include "CircleCloneNode.h" namespace luci { -ShapeSignature ssinf::Algorithm::visit(const luci::CircleReluN1To1 *node) +luci::CircleNode *CloneNode::visit(const luci::CircleReluN1To1 *) { - return input_arg_signature(node, 0); + return _graph->nodes()->create<luci::CircleReluN1To1>(); } } // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleReluN1To1.test.cpp b/compiler/luci/service/src/Nodes/CircleReluN1To1.test.cpp new file mode 100644 index 000000000..b828e795c --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleReluN1To1.test.cpp @@ -0,0 +1,35 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include <loco/IR/TensorShape.h> + +#include <gtest/gtest.h> + +TEST(CloneNodeTest, clone_ReluN1To1) +{ + auto g = loco::make_graph(); + auto node_relun1 = g->nodes()->create<luci::CircleReluN1To1>(); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_relun1, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_relun1 = dynamic_cast<luci::CircleReluN1To1 *>(cloned); + ASSERT_NE(nullptr, cloned_relun1); +} diff --git a/compiler/luci/service/src/Nodes/CircleReshape.cpp b/compiler/luci/service/src/Nodes/CircleReshape.cpp new file mode 100644 index 000000000..07a81b306 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleReshape.cpp @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleCloneNode.h" + +namespace luci +{ + +luci::CircleNode *CloneNode::visit(const luci::CircleReshape *node) +{ + auto *cloned = _graph->nodes()->create<luci::CircleReshape>(); + if (cloned != nullptr) + { + uint32_t rank = node->newShape()->rank(); + cloned->newShape()->rank(rank); + for (uint32_t r = 0; r < rank; ++r) + { + cloned->newShape()->dim(r) = node->newShape()->dim(r); + } + } + return cloned; +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleReshape.test.cpp b/compiler/luci/service/src/Nodes/CircleReshape.test.cpp new file mode 100644 index 000000000..ca92b717d --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleReshape.test.cpp @@ -0,0 +1,39 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include <gtest/gtest.h> + +TEST(CloneNodeTest, clone_Reshape) +{ + auto g = loco::make_graph(); + auto node_reshape = g->nodes()->create<luci::CircleReshape>(); + node_reshape->newShape()->rank(2); + node_reshape->newShape()->dim(0) = 3; + node_reshape->newShape()->dim(1) = 4; + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_reshape, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_reshape = dynamic_cast<luci::CircleReshape *>(cloned); + ASSERT_NE(nullptr, cloned_reshape); + ASSERT_EQ(node_reshape->newShape()->rank(), cloned_reshape->newShape()->rank()); + ASSERT_EQ(node_reshape->newShape()->dim(0), cloned_reshape->newShape()->dim(0)); + ASSERT_EQ(node_reshape->newShape()->dim(1), cloned_reshape->newShape()->dim(1)); +} diff --git a/compiler/luci/service/src/Nodes/CircleResizeBilinear.cpp b/compiler/luci/service/src/Nodes/CircleResizeBilinear.cpp new file mode 100644 index 000000000..55d21af45 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleResizeBilinear.cpp @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleCloneNode.h" + +namespace luci +{ + +luci::CircleNode *CloneNode::visit(const luci::CircleResizeBilinear *node) +{ + auto *cloned = _graph->nodes()->create<luci::CircleResizeBilinear>(); + if (cloned != nullptr) + { + cloned->align_corners(node->align_corners()); + cloned->half_pixel_centers(node->half_pixel_centers()); + } + return cloned; +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleResizeBilinear.test.cpp b/compiler/luci/service/src/Nodes/CircleResizeBilinear.test.cpp new file mode 100644 index 000000000..bff71261d --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleResizeBilinear.test.cpp @@ -0,0 +1,73 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include <luci/IR/CircleNodes.h> +#include <luci/Service/CircleShapeInference.h> + +#include <loco/IR/TensorShape.h> + +#include <gtest/gtest.h> + +TEST(ShapeRuleTest, resize_bilinear_simple) +{ + luci::CircleInput input; + luci::CircleConst rb_size; + luci::CircleResizeBilinear rb; + + input.shape({1, 4, 4, 3}); + input.shape_status(luci::ShapeStatus::VALID); + + rb_size.dtype(loco::DataType::S32); + rb_size.rank(1); + rb_size.dim(0).set(2); + rb_size.size<loco::DataType::S32>(2); + rb_size.at<loco::DataType::S32>(0) = 16; + rb_size.at<loco::DataType::S32>(1) = 16; + rb_size.shape_status(luci::ShapeStatus::VALID); + + rb.input(&input); + rb.size(&rb_size); + + loco::TensorShape shape; + luci::sinf::Rule shape_inf_rule; + + ASSERT_TRUE(shape_inf_rule.infer(&rb, shape)); + ASSERT_EQ(4, shape.rank()); + ASSERT_EQ(1, shape.dim(0).value()); + ASSERT_EQ(16, shape.dim(1).value()); + ASSERT_EQ(16, shape.dim(2).value()); + ASSERT_EQ(3, shape.dim(3).value()); +} + +TEST(CloneNodeTest, clone_ResizeBilinear) +{ + auto g = loco::make_graph(); + auto node_rb = g->nodes()->create<luci::CircleResizeBilinear>(); + node_rb->align_corners(true); + node_rb->half_pixel_centers(true); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_rb, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_rb = dynamic_cast<luci::CircleResizeBilinear *>(cloned); + ASSERT_NE(nullptr, cloned_rb); + ASSERT_EQ(node_rb->align_corners(), cloned_rb->align_corners()); + ASSERT_EQ(node_rb->half_pixel_centers(), cloned_rb->half_pixel_centers()); +} diff --git a/compiler/luci/service/src/Nodes/CircleResizeNearestNeighbor.cpp b/compiler/luci/service/src/Nodes/CircleResizeNearestNeighbor.cpp new file mode 100644 index 000000000..5727786a7 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleResizeNearestNeighbor.cpp @@ -0,0 +1,30 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleCloneNode.h" + +namespace luci +{ + +luci::CircleNode *CloneNode::visit(const luci::CircleResizeNearestNeighbor *node) +{ + auto *cloned = _graph->nodes()->create<luci::CircleResizeNearestNeighbor>(); + if (cloned != nullptr) + cloned->align_corners(node->align_corners()); + return cloned; +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleResizeNearestNeighbor.test.cpp b/compiler/luci/service/src/Nodes/CircleResizeNearestNeighbor.test.cpp new file mode 100644 index 000000000..a1d781c65 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleResizeNearestNeighbor.test.cpp @@ -0,0 +1,71 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include <luci/IR/CircleNodes.h> +#include <luci/Service/CircleShapeInference.h> + +#include <loco/IR/TensorShape.h> + +#include <gtest/gtest.h> + +TEST(ShapeRuleTest, resize_nearest_neighbor_simple) +{ + luci::CircleInput input; + luci::CircleConst rnn_size; + luci::CircleResizeNearestNeighbor rnn; + + input.shape({1, 4, 4, 3}); + input.shape_status(luci::ShapeStatus::VALID); + + rnn_size.dtype(loco::DataType::S32); + rnn_size.rank(1); + rnn_size.dim(0).set(2); + rnn_size.size<loco::DataType::S32>(2); + rnn_size.at<loco::DataType::S32>(0) = 16; + rnn_size.at<loco::DataType::S32>(1) = 16; + rnn_size.shape_status(luci::ShapeStatus::VALID); + + rnn.input(&input); + rnn.size(&rnn_size); + + loco::TensorShape shape; + luci::sinf::Rule shape_inf_rule; + + ASSERT_TRUE(shape_inf_rule.infer(&rnn, shape)); + ASSERT_EQ(4, shape.rank()); + ASSERT_EQ(1, shape.dim(0).value()); + ASSERT_EQ(16, shape.dim(1).value()); + ASSERT_EQ(16, shape.dim(2).value()); + ASSERT_EQ(3, shape.dim(3).value()); +} + +TEST(CloneNodeTest, clone_ResizeNearestNeighbor) +{ + auto g = loco::make_graph(); + auto node_rnn = g->nodes()->create<luci::CircleResizeNearestNeighbor>(); + node_rnn->align_corners(true); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_rnn, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_rnn = dynamic_cast<luci::CircleResizeNearestNeighbor *>(cloned); + ASSERT_NE(nullptr, cloned_rnn); + ASSERT_EQ(node_rnn->align_corners(), cloned_rnn->align_corners()); +} diff --git a/compiler/luci/service/src/Nodes/CircleReverseSequence.cpp b/compiler/luci/service/src/Nodes/CircleReverseSequence.cpp new file mode 100644 index 000000000..6e6919b0c --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleReverseSequence.cpp @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleCloneNode.h" + +namespace luci +{ + +luci::CircleNode *CloneNode::visit(const luci::CircleReverseSequence *node) +{ + auto *cloned = _graph->nodes()->create<luci::CircleReverseSequence>(); + if (cloned != nullptr) + { + cloned->seq_axis(node->seq_axis()); + cloned->batch_axis(node->batch_axis()); + } + return cloned; +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleReverseSequence.test.cpp b/compiler/luci/service/src/Nodes/CircleReverseSequence.test.cpp new file mode 100644 index 000000000..a7a8e3949 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleReverseSequence.test.cpp @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include <gtest/gtest.h> + +TEST(CloneNodeTest, clone_ReverseSequence) +{ + auto g = loco::make_graph(); + auto node_rs = g->nodes()->create<luci::CircleReverseSequence>(); + node_rs->seq_axis(1); + node_rs->batch_axis(2); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_rs, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_rs = dynamic_cast<luci::CircleReverseSequence *>(cloned); + ASSERT_NE(nullptr, cloned_rs); + ASSERT_EQ(node_rs->seq_axis(), cloned_rs->seq_axis()); + ASSERT_EQ(node_rs->batch_axis(), cloned_rs->batch_axis()); +} diff --git a/compiler/luci/service/src/Nodes/CircleReverseV2.cpp b/compiler/luci/service/src/Nodes/CircleReverseV2.cpp new file mode 100644 index 000000000..e8fee6c3e --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleReverseV2.cpp @@ -0,0 +1,27 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleCloneNode.h" + +namespace luci +{ + +luci::CircleNode *CloneNode::visit(const luci::CircleReverseV2 *) +{ + return _graph->nodes()->create<luci::CircleReverseV2>(); +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleReverseV2.test.cpp b/compiler/luci/service/src/Nodes/CircleReverseV2.test.cpp new file mode 100644 index 000000000..0e5ff933c --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleReverseV2.test.cpp @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include <gtest/gtest.h> + +TEST(CloneNodeTest, clone_ReverseV2) +{ + auto g = loco::make_graph(); + auto node_rev = g->nodes()->create<luci::CircleReverseV2>(); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_rev, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_rev = dynamic_cast<luci::CircleReverseV2 *>(cloned); + ASSERT_NE(nullptr, cloned_rev); +} diff --git a/compiler/luci/service/src/Nodes/CircleRound.cpp b/compiler/luci/service/src/Nodes/CircleRound.cpp new file mode 100644 index 000000000..2c23f2df6 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleRound.cpp @@ -0,0 +1,27 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleCloneNode.h" + +namespace luci +{ + +luci::CircleNode *CloneNode::visit(const luci::CircleRound *) +{ + return _graph->nodes()->create<luci::CircleRound>(); +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleRound.test.cpp b/compiler/luci/service/src/Nodes/CircleRound.test.cpp new file mode 100644 index 000000000..2c2c3a9d0 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleRound.test.cpp @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include <gtest/gtest.h> + +TEST(CloneNodeTest, clone_Round) +{ + auto g = loco::make_graph(); + auto node_rnd = g->nodes()->create<luci::CircleRound>(); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_rnd, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_rnd = dynamic_cast<luci::CircleRound *>(cloned); + ASSERT_NE(nullptr, cloned_rnd); +} diff --git a/compiler/luci/service/src/Nodes/CircleRsqrt.cpp b/compiler/luci/service/src/Nodes/CircleRsqrt.cpp new file mode 100644 index 000000000..aca702fe1 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleRsqrt.cpp @@ -0,0 +1,27 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleCloneNode.h" + +namespace luci +{ + +luci::CircleNode *CloneNode::visit(const luci::CircleRsqrt *) +{ + return _graph->nodes()->create<luci::CircleRsqrt>(); +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleRsqrt.test.cpp b/compiler/luci/service/src/Nodes/CircleRsqrt.test.cpp new file mode 100644 index 000000000..3e4ced562 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleRsqrt.test.cpp @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include <gtest/gtest.h> + +TEST(CloneNodeTest, clone_Rsqrt) +{ + auto g = loco::make_graph(); + auto node_rsqrt = g->nodes()->create<luci::CircleRsqrt>(); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_rsqrt, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_rsqrt = dynamic_cast<luci::CircleRsqrt *>(cloned); + ASSERT_NE(nullptr, cloned_rsqrt); +} diff --git a/compiler/luci/service/src/Nodes/CircleScatterNd.cpp b/compiler/luci/service/src/Nodes/CircleScatterNd.cpp new file mode 100644 index 000000000..6c477a598 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleScatterNd.cpp @@ -0,0 +1,27 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleCloneNode.h" + +namespace luci +{ + +luci::CircleNode *CloneNode::visit(const luci::CircleScatterNd *) +{ + return _graph->nodes()->create<luci::CircleScatterNd>(); +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleScatterNd.test.cpp b/compiler/luci/service/src/Nodes/CircleScatterNd.test.cpp new file mode 100644 index 000000000..ce63603cc --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleScatterNd.test.cpp @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include <gtest/gtest.h> + +TEST(CloneNodeTest, clone_ScatterNd) +{ + auto g = loco::make_graph(); + auto node_snd = g->nodes()->create<luci::CircleScatterNd>(); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_snd, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_snd = dynamic_cast<luci::CircleScatterNd *>(cloned); + ASSERT_NE(nullptr, cloned_snd); +} diff --git a/compiler/luci/service/src/Nodes/CircleSegmentSum.cpp b/compiler/luci/service/src/Nodes/CircleSegmentSum.cpp new file mode 100644 index 000000000..aa4001f57 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleSegmentSum.cpp @@ -0,0 +1,27 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleCloneNode.h" + +namespace luci +{ + +luci::CircleNode *CloneNode::visit(const luci::CircleSegmentSum *) +{ + return _graph->nodes()->create<luci::CircleSegmentSum>(); +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleSegmentSum.test.cpp b/compiler/luci/service/src/Nodes/CircleSegmentSum.test.cpp new file mode 100644 index 000000000..ff17b0745 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleSegmentSum.test.cpp @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include <gtest/gtest.h> + +TEST(CloneNodeTest, clone_SegmentSum) +{ + auto g = loco::make_graph(); + auto node_ss = g->nodes()->create<luci::CircleSegmentSum>(); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_ss, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_ss = dynamic_cast<luci::CircleSegmentSum *>(cloned); + ASSERT_NE(nullptr, cloned_ss); +} diff --git a/compiler/luci/service/src/Nodes/CircleSelect.cpp b/compiler/luci/service/src/Nodes/CircleSelect.cpp new file mode 100644 index 000000000..71b31d33f --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleSelect.cpp @@ -0,0 +1,27 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleCloneNode.h" + +namespace luci +{ + +luci::CircleNode *CloneNode::visit(const luci::CircleSelect *) +{ + return _graph->nodes()->create<luci::CircleSelect>(); +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleSelect.test.cpp b/compiler/luci/service/src/Nodes/CircleSelect.test.cpp new file mode 100644 index 000000000..e8d631618 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleSelect.test.cpp @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include <gtest/gtest.h> + +TEST(CloneNodeTest, clone_Select) +{ + auto g = loco::make_graph(); + auto node_sel = g->nodes()->create<luci::CircleSelect>(); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_sel, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_sel = dynamic_cast<luci::CircleSelect *>(cloned); + ASSERT_NE(nullptr, cloned_sel); +} diff --git a/compiler/luci/service/src/Nodes/CircleSelectV2.cpp b/compiler/luci/service/src/Nodes/CircleSelectV2.cpp new file mode 100644 index 000000000..07af40c40 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleSelectV2.cpp @@ -0,0 +1,27 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleCloneNode.h" + +namespace luci +{ + +luci::CircleNode *CloneNode::visit(const luci::CircleSelectV2 *) +{ + return _graph->nodes()->create<luci::CircleSelectV2>(); +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleSelectV2.test.cpp b/compiler/luci/service/src/Nodes/CircleSelectV2.test.cpp new file mode 100644 index 000000000..253dba555 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleSelectV2.test.cpp @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include <gtest/gtest.h> + +TEST(CloneNodeTest, clone_SelectV2) +{ + auto g = loco::make_graph(); + auto node_sel = g->nodes()->create<luci::CircleSelectV2>(); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_sel, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_sel = dynamic_cast<luci::CircleSelectV2 *>(cloned); + ASSERT_NE(nullptr, cloned_sel); +} diff --git a/compiler/luci/service/src/Nodes/CircleShape.cpp b/compiler/luci/service/src/Nodes/CircleShape.cpp new file mode 100644 index 000000000..e5b5fa28f --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleShape.cpp @@ -0,0 +1,30 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleCloneNode.h" + +namespace luci +{ + +luci::CircleNode *CloneNode::visit(const luci::CircleShape *node) +{ + auto *cloned = _graph->nodes()->create<luci::CircleShape>(); + if (cloned != nullptr) + cloned->out_type(node->out_type()); + return cloned; +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleShape.test.cpp b/compiler/luci/service/src/Nodes/CircleShape.test.cpp new file mode 100644 index 000000000..ec057bd05 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleShape.test.cpp @@ -0,0 +1,35 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include <gtest/gtest.h> + +TEST(CloneNodeTest, clone_Shape) +{ + auto g = loco::make_graph(); + auto node_shape = g->nodes()->create<luci::CircleShape>(); + node_shape->out_type(loco::DataType::S32); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_shape, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_shape = dynamic_cast<luci::CircleShape *>(cloned); + ASSERT_NE(nullptr, cloned_shape); + ASSERT_EQ(node_shape->out_type(), cloned_shape->out_type()); +} diff --git a/compiler/luci/service/src/Nodes/CircleSin.cpp b/compiler/luci/service/src/Nodes/CircleSin.cpp new file mode 100644 index 000000000..46a07d21d --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleSin.cpp @@ -0,0 +1,27 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleCloneNode.h" + +namespace luci +{ + +luci::CircleNode *CloneNode::visit(const luci::CircleSin *) +{ + return _graph->nodes()->create<luci::CircleSin>(); +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleSin.test.cpp b/compiler/luci/service/src/Nodes/CircleSin.test.cpp new file mode 100644 index 000000000..b072e7e2c --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleSin.test.cpp @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include <gtest/gtest.h> + +TEST(CloneNodeTest, clone_Sin) +{ + auto g = loco::make_graph(); + auto node_sin = g->nodes()->create<luci::CircleSin>(); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_sin, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_sin = dynamic_cast<luci::CircleSin *>(cloned); + ASSERT_NE(nullptr, cloned_sin); +} diff --git a/compiler/luci/service/src/Nodes/CircleSlice.cpp b/compiler/luci/service/src/Nodes/CircleSlice.cpp new file mode 100644 index 000000000..6b2f4a591 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleSlice.cpp @@ -0,0 +1,27 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleCloneNode.h" + +namespace luci +{ + +luci::CircleNode *CloneNode::visit(const luci::CircleSlice *) +{ + return _graph->nodes()->create<luci::CircleSlice>(); +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleSlice.test.cpp b/compiler/luci/service/src/Nodes/CircleSlice.test.cpp new file mode 100644 index 000000000..48ec20304 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleSlice.test.cpp @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include <gtest/gtest.h> + +TEST(CloneNodeTest, clone_Slice) +{ + auto g = loco::make_graph(); + auto node_slice = g->nodes()->create<luci::CircleSlice>(); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_slice, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_slice = dynamic_cast<luci::CircleSlice *>(cloned); + ASSERT_NE(nullptr, cloned_slice); +} diff --git a/compiler/luci/service/src/Nodes/CircleSoftmax.cpp b/compiler/luci/service/src/Nodes/CircleSoftmax.cpp new file mode 100644 index 000000000..359d1000c --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleSoftmax.cpp @@ -0,0 +1,30 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleCloneNode.h" + +namespace luci +{ + +luci::CircleNode *CloneNode::visit(const luci::CircleSoftmax *node) +{ + auto *cloned = _graph->nodes()->create<luci::CircleSoftmax>(); + if (cloned != nullptr) + cloned->beta(node->beta()); + return cloned; +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleSoftmax.test.cpp b/compiler/luci/service/src/Nodes/CircleSoftmax.test.cpp new file mode 100644 index 000000000..c80b44d69 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleSoftmax.test.cpp @@ -0,0 +1,35 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include <gtest/gtest.h> + +TEST(CloneNodeTest, clone_Softmax) +{ + auto g = loco::make_graph(); + auto node_sm = g->nodes()->create<luci::CircleSoftmax>(); + node_sm->beta(2.3f); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_sm, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_sm = dynamic_cast<luci::CircleSoftmax *>(cloned); + ASSERT_NE(nullptr, cloned_sm); + ASSERT_EQ(node_sm->beta(), cloned_sm->beta()); +} diff --git a/compiler/luci/service/src/Nodes/CircleSpaceToBatchND.cpp b/compiler/luci/service/src/Nodes/CircleSpaceToBatchND.cpp new file mode 100644 index 000000000..feb4f3e37 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleSpaceToBatchND.cpp @@ -0,0 +1,27 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleCloneNode.h" + +namespace luci +{ + +luci::CircleNode *CloneNode::visit(const luci::CircleSpaceToBatchND *) +{ + return _graph->nodes()->create<luci::CircleSpaceToBatchND>(); +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleSpaceToBatchND.test.cpp b/compiler/luci/service/src/Nodes/CircleSpaceToBatchND.test.cpp new file mode 100644 index 000000000..eb743795d --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleSpaceToBatchND.test.cpp @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include <gtest/gtest.h> + +TEST(CloneNodeTest, clone_SpaceToBatchND) +{ + auto g = loco::make_graph(); + auto node_s2bnd = g->nodes()->create<luci::CircleSpaceToBatchND>(); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_s2bnd, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_s2bnd = dynamic_cast<luci::CircleSpaceToBatchND *>(cloned); + ASSERT_NE(nullptr, cloned_s2bnd); +} diff --git a/compiler/luci/service/src/Nodes/CircleSpaceToDepth.cpp b/compiler/luci/service/src/Nodes/CircleSpaceToDepth.cpp new file mode 100644 index 000000000..3a82f5c7a --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleSpaceToDepth.cpp @@ -0,0 +1,30 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleCloneNode.h" + +namespace luci +{ + +luci::CircleNode *CloneNode::visit(const luci::CircleSpaceToDepth *node) +{ + auto *cloned = _graph->nodes()->create<luci::CircleSpaceToDepth>(); + if (cloned != nullptr) + cloned->block_size(node->block_size()); + return cloned; +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleSpaceToDepth.test.cpp b/compiler/luci/service/src/Nodes/CircleSpaceToDepth.test.cpp new file mode 100644 index 000000000..fb544e6d7 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleSpaceToDepth.test.cpp @@ -0,0 +1,35 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include <gtest/gtest.h> + +TEST(CloneNodeTest, clone_SpaceToDepth) +{ + auto g = loco::make_graph(); + auto node_s2d = g->nodes()->create<luci::CircleSpaceToDepth>(); + node_s2d->block_size(32); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_s2d, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_s2d = dynamic_cast<luci::CircleSpaceToDepth *>(cloned); + ASSERT_NE(nullptr, cloned_s2d); + ASSERT_EQ(node_s2d->block_size(), cloned_s2d->block_size()); +} diff --git a/compiler/luci/service/src/Nodes/CircleSparseToDense.cpp b/compiler/luci/service/src/Nodes/CircleSparseToDense.cpp new file mode 100644 index 000000000..3dba1a542 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleSparseToDense.cpp @@ -0,0 +1,30 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleCloneNode.h" + +namespace luci +{ + +luci::CircleNode *CloneNode::visit(const luci::CircleSparseToDense *node) +{ + auto *cloned = _graph->nodes()->create<luci::CircleSparseToDense>(); + if (cloned != nullptr) + cloned->validate_indices(node->validate_indices()); + return cloned; +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleSparseToDense.test.cpp b/compiler/luci/service/src/Nodes/CircleSparseToDense.test.cpp new file mode 100644 index 000000000..177a469cd --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleSparseToDense.test.cpp @@ -0,0 +1,35 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include <gtest/gtest.h> + +TEST(CloneNodeTest, clone_SparseToDense) +{ + auto g = loco::make_graph(); + auto node_s2d = g->nodes()->create<luci::CircleSparseToDense>(); + node_s2d->validate_indices(true); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_s2d, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_s2d = dynamic_cast<luci::CircleSparseToDense *>(cloned); + ASSERT_NE(nullptr, cloned_s2d); + ASSERT_EQ(node_s2d->validate_indices(), cloned_s2d->validate_indices()); +} diff --git a/compiler/luci/service/src/Nodes/CircleSplit.cpp b/compiler/luci/service/src/Nodes/CircleSplit.cpp new file mode 100644 index 000000000..e68a24a1f --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleSplit.cpp @@ -0,0 +1,30 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleCloneNode.h" + +namespace luci +{ + +luci::CircleNode *CloneNode::visit(const luci::CircleSplit *node) +{ + auto *cloned = _graph->nodes()->create<luci::CircleSplit>(); + if (cloned != nullptr) + cloned->num_split(node->num_split()); + return cloned; +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleSplit.test.cpp b/compiler/luci/service/src/Nodes/CircleSplit.test.cpp new file mode 100644 index 000000000..9ee26b425 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleSplit.test.cpp @@ -0,0 +1,35 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include <gtest/gtest.h> + +TEST(CloneNodeTest, clone_Split) +{ + auto g = loco::make_graph(); + auto node_split = g->nodes()->create<luci::CircleSplit>(); + node_split->num_split(5); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_split, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_split = dynamic_cast<luci::CircleSplit *>(cloned); + ASSERT_NE(nullptr, cloned_split); + ASSERT_EQ(node_split->num_split(), cloned_split->num_split()); +} diff --git a/compiler/luci/service/src/Nodes/CircleSplitOut.cpp b/compiler/luci/service/src/Nodes/CircleSplitOut.cpp new file mode 100644 index 000000000..024598892 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleSplitOut.cpp @@ -0,0 +1,30 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleCloneNode.h" + +namespace luci +{ + +luci::CircleNode *CloneNode::visit(const luci::CircleSplitOut *node) +{ + auto *cloned = _graph->nodes()->create<luci::CircleSplitOut>(); + if (cloned != nullptr) + cloned->index(node->index()); + return cloned; +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleSplitOut.test.cpp b/compiler/luci/service/src/Nodes/CircleSplitOut.test.cpp new file mode 100644 index 000000000..deec08804 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleSplitOut.test.cpp @@ -0,0 +1,35 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include <gtest/gtest.h> + +TEST(CloneNodeTest, clone_SplitOut) +{ + auto g = loco::make_graph(); + auto node_sout = g->nodes()->create<luci::CircleSplitOut>(); + node_sout->index(1); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_sout, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_sout = dynamic_cast<luci::CircleSplitOut *>(cloned); + ASSERT_NE(nullptr, cloned_sout); + ASSERT_EQ(node_sout->index(), cloned_sout->index()); +} diff --git a/compiler/luci/service/src/Nodes/CircleSplitV.cpp b/compiler/luci/service/src/Nodes/CircleSplitV.cpp new file mode 100644 index 000000000..de6c6cce6 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleSplitV.cpp @@ -0,0 +1,30 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleCloneNode.h" + +namespace luci +{ + +luci::CircleNode *CloneNode::visit(const luci::CircleSplitV *node) +{ + auto *cloned = _graph->nodes()->create<luci::CircleSplitV>(); + if (cloned != nullptr) + cloned->num_split(node->num_split()); + return cloned; +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleSplitV.test.cpp b/compiler/luci/service/src/Nodes/CircleSplitV.test.cpp new file mode 100644 index 000000000..d109a64aa --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleSplitV.test.cpp @@ -0,0 +1,35 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include <gtest/gtest.h> + +TEST(CloneNodeTest, clone_SplitV) +{ + auto g = loco::make_graph(); + auto node_split = g->nodes()->create<luci::CircleSplitV>(); + node_split->num_split(5); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_split, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_split = dynamic_cast<luci::CircleSplitV *>(cloned); + ASSERT_NE(nullptr, cloned_split); + ASSERT_EQ(node_split->num_split(), cloned_split->num_split()); +} diff --git a/compiler/luci/service/src/Nodes/CircleSplitVOut.cpp b/compiler/luci/service/src/Nodes/CircleSplitVOut.cpp new file mode 100644 index 000000000..f40eb0a47 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleSplitVOut.cpp @@ -0,0 +1,30 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleCloneNode.h" + +namespace luci +{ + +luci::CircleNode *CloneNode::visit(const luci::CircleSplitVOut *node) +{ + auto *cloned = _graph->nodes()->create<luci::CircleSplitVOut>(); + if (cloned != nullptr) + cloned->index(node->index()); + return cloned; +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleSplitVOut.test.cpp b/compiler/luci/service/src/Nodes/CircleSplitVOut.test.cpp new file mode 100644 index 000000000..ab5e9d6be --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleSplitVOut.test.cpp @@ -0,0 +1,35 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include <gtest/gtest.h> + +TEST(CloneNodeTest, clone_SplitVOut) +{ + auto g = loco::make_graph(); + auto node_sout = g->nodes()->create<luci::CircleSplitVOut>(); + node_sout->index(1); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_sout, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_sout = dynamic_cast<luci::CircleSplitVOut *>(cloned); + ASSERT_NE(nullptr, cloned_sout); + ASSERT_EQ(node_sout->index(), cloned_sout->index()); +} diff --git a/compiler/luci/service/src/Nodes/CircleSqrt.cpp b/compiler/luci/service/src/Nodes/CircleSqrt.cpp new file mode 100644 index 000000000..a3e63684b --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleSqrt.cpp @@ -0,0 +1,27 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleCloneNode.h" + +namespace luci +{ + +luci::CircleNode *CloneNode::visit(const luci::CircleSqrt *) +{ + return _graph->nodes()->create<luci::CircleSqrt>(); +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleSqrt.test.cpp b/compiler/luci/service/src/Nodes/CircleSqrt.test.cpp new file mode 100644 index 000000000..dbef839d6 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleSqrt.test.cpp @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include <gtest/gtest.h> + +TEST(CloneNodeTest, clone_Sqrt) +{ + auto g = loco::make_graph(); + auto node_sqrt = g->nodes()->create<luci::CircleSqrt>(); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_sqrt, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_sqrt = dynamic_cast<luci::CircleSqrt *>(cloned); + ASSERT_NE(nullptr, cloned_sqrt); +} diff --git a/compiler/luci/service/src/Nodes/CircleSquare.cpp b/compiler/luci/service/src/Nodes/CircleSquare.cpp new file mode 100644 index 000000000..88bbed76c --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleSquare.cpp @@ -0,0 +1,27 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleCloneNode.h" + +namespace luci +{ + +luci::CircleNode *CloneNode::visit(const luci::CircleSquare *) +{ + return _graph->nodes()->create<luci::CircleSquare>(); +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleSquare.test.cpp b/compiler/luci/service/src/Nodes/CircleSquare.test.cpp new file mode 100644 index 000000000..67ac21210 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleSquare.test.cpp @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include <gtest/gtest.h> + +TEST(CloneNodeTest, clone_Square) +{ + auto g = loco::make_graph(); + auto node_squ = g->nodes()->create<luci::CircleSquare>(); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_squ, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_squ = dynamic_cast<luci::CircleSquare *>(cloned); + ASSERT_NE(nullptr, cloned_squ); +} diff --git a/compiler/luci/service/src/Nodes/CircleSquaredDifference.cpp b/compiler/luci/service/src/Nodes/CircleSquaredDifference.cpp new file mode 100644 index 000000000..6becdf1c9 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleSquaredDifference.cpp @@ -0,0 +1,27 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleCloneNode.h" + +namespace luci +{ + +luci::CircleNode *CloneNode::visit(const luci::CircleSquaredDifference *) +{ + return _graph->nodes()->create<luci::CircleSquaredDifference>(); +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleSquaredDifference.test.cpp b/compiler/luci/service/src/Nodes/CircleSquaredDifference.test.cpp new file mode 100644 index 000000000..26099612b --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleSquaredDifference.test.cpp @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include <gtest/gtest.h> + +TEST(CloneNodeTest, clone_SquaredDifference) +{ + auto g = loco::make_graph(); + auto node_sd = g->nodes()->create<luci::CircleSquaredDifference>(); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_sd, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_sd = dynamic_cast<luci::CircleSquaredDifference *>(cloned); + ASSERT_NE(nullptr, cloned_sd); +} diff --git a/compiler/luci/service/src/Nodes/CircleSqueeze.cpp b/compiler/luci/service/src/Nodes/CircleSqueeze.cpp new file mode 100644 index 000000000..02ba5020c --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleSqueeze.cpp @@ -0,0 +1,30 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleCloneNode.h" + +namespace luci +{ + +luci::CircleNode *CloneNode::visit(const luci::CircleSqueeze *node) +{ + auto *cloned = _graph->nodes()->create<luci::CircleSqueeze>(); + if (cloned != nullptr) + cloned->squeeze_dims(node->squeeze_dims()); + return cloned; +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleSqueeze.test.cpp b/compiler/luci/service/src/Nodes/CircleSqueeze.test.cpp new file mode 100644 index 000000000..bc73eafa7 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleSqueeze.test.cpp @@ -0,0 +1,83 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include <luci/IR/CircleNodes.h> +#include <luci/Service/CircleShapeInference.h> + +#include <loco/IR/TensorShape.h> + +#include <gtest/gtest.h> + +TEST(ShapeRuleTest, squeeze_simple) +{ + luci::CircleInput input; + luci::CircleSqueeze squeeze; + + input.shape({1, 4, 3, 1}); + input.shape_status(luci::ShapeStatus::VALID); + + squeeze.input(&input); + squeeze.squeeze_dims({0}); + + loco::TensorShape shape; + luci::sinf::Rule shape_inf_rule; + + ASSERT_TRUE(shape_inf_rule.infer(&squeeze, shape)); + ASSERT_EQ(3, shape.rank()); + ASSERT_EQ(4, shape.dim(0).value()); + ASSERT_EQ(3, shape.dim(1).value()); + ASSERT_EQ(1, shape.dim(2).value()); +} + +TEST(ShapeRuleTest, squeeze_all) +{ + luci::CircleInput input; + luci::CircleSqueeze squeeze; + + input.shape({1, 4, 3, 1}); + input.shape_status(luci::ShapeStatus::VALID); + + squeeze.input(&input); + squeeze.squeeze_dims({}); + + loco::TensorShape shape; + luci::sinf::Rule shape_inf_rule; + + ASSERT_TRUE(shape_inf_rule.infer(&squeeze, shape)); + ASSERT_EQ(2, shape.rank()); + ASSERT_EQ(4, shape.dim(0).value()); + ASSERT_EQ(3, shape.dim(1).value()); +} + +TEST(CloneNodeTest, clone_Squeeze) +{ + auto g = loco::make_graph(); + auto node_squ = g->nodes()->create<luci::CircleSqueeze>(); + node_squ->squeeze_dims({2, 3}); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_squ, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_squ = dynamic_cast<luci::CircleSqueeze *>(cloned); + ASSERT_NE(nullptr, cloned_squ); + ASSERT_EQ(node_squ->squeeze_dims().size(), cloned_squ->squeeze_dims().size()); + for (size_t s = 0; s < node_squ->squeeze_dims().size(); ++s) + ASSERT_EQ(node_squ->squeeze_dims().at(s), cloned_squ->squeeze_dims().at(s)); +} diff --git a/compiler/luci/service/src/Nodes/CircleStridedSlice.cpp b/compiler/luci/service/src/Nodes/CircleStridedSlice.cpp new file mode 100644 index 000000000..c4d199316 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleStridedSlice.cpp @@ -0,0 +1,36 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleCloneNode.h" + +namespace luci +{ + +luci::CircleNode *CloneNode::visit(const luci::CircleStridedSlice *node) +{ + auto *cloned = _graph->nodes()->create<luci::CircleStridedSlice>(); + if (cloned != nullptr) + { + cloned->begin_mask(node->begin_mask()); + cloned->end_mask(node->end_mask()); + cloned->ellipsis_mask(node->ellipsis_mask()); + cloned->new_axis_mask(node->new_axis_mask()); + cloned->shrink_axis_mask(node->shrink_axis_mask()); + } + return cloned; +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleStridedSlice.test.cpp b/compiler/luci/service/src/Nodes/CircleStridedSlice.test.cpp new file mode 100644 index 000000000..d633f3022 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleStridedSlice.test.cpp @@ -0,0 +1,43 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include <gtest/gtest.h> + +TEST(CloneNodeTest, clone_StridedSlice) +{ + auto g = loco::make_graph(); + auto node_ss = g->nodes()->create<luci::CircleStridedSlice>(); + node_ss->begin_mask(1); + node_ss->end_mask(2); + node_ss->ellipsis_mask(3); + node_ss->new_axis_mask(4); + node_ss->shrink_axis_mask(5); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_ss, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_ss = dynamic_cast<luci::CircleStridedSlice *>(cloned); + ASSERT_NE(nullptr, cloned_ss); + ASSERT_EQ(node_ss->begin_mask(), cloned_ss->begin_mask()); + ASSERT_EQ(node_ss->end_mask(), cloned_ss->end_mask()); + ASSERT_EQ(node_ss->ellipsis_mask(), cloned_ss->ellipsis_mask()); + ASSERT_EQ(node_ss->new_axis_mask(), cloned_ss->new_axis_mask()); + ASSERT_EQ(node_ss->shrink_axis_mask(), cloned_ss->shrink_axis_mask()); +} diff --git a/compiler/luci/service/src/Nodes/CircleSub.cpp b/compiler/luci/service/src/Nodes/CircleSub.cpp new file mode 100644 index 000000000..fb4bab19a --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleSub.cpp @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleCloneNode.h" + +namespace luci +{ + +luci::CircleNode *CloneNode::visit(const luci::CircleSub *node) +{ + if (node->fusedActivationFunction() == luci::FusedActFunc::UNDEFINED) + return nullptr; + + auto *cloned = _graph->nodes()->create<luci::CircleSub>(); + if (cloned != nullptr) + cloned->fusedActivationFunction(node->fusedActivationFunction()); + return cloned; +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleSub.test.cpp b/compiler/luci/service/src/Nodes/CircleSub.test.cpp new file mode 100644 index 000000000..e6bd7b8ff --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleSub.test.cpp @@ -0,0 +1,46 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include <gtest/gtest.h> + +TEST(CloneNodeTest, clone_Sub) +{ + auto g = loco::make_graph(); + auto node_sub = g->nodes()->create<luci::CircleSub>(); + node_sub->fusedActivationFunction(luci::FusedActFunc::RELU); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_sub, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_sub = dynamic_cast<luci::CircleSub *>(cloned); + ASSERT_NE(nullptr, cloned_sub); + ASSERT_EQ(node_sub->fusedActivationFunction(), cloned_sub->fusedActivationFunction()); +} + +TEST(CloneNodeTest, clone_Sub_NEG) +{ + auto g = loco::make_graph(); + auto node_sub = g->nodes()->create<luci::CircleSub>(); + node_sub->fusedActivationFunction(luci::FusedActFunc::UNDEFINED); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_sub, gc.get()); + ASSERT_EQ(nullptr, cloned); +} diff --git a/compiler/luci/service/src/Nodes/CircleSum.cpp b/compiler/luci/service/src/Nodes/CircleSum.cpp index 9ef90e8e0..29e6ee5f1 100644 --- a/compiler/luci/service/src/Nodes/CircleSum.cpp +++ b/compiler/luci/service/src/Nodes/CircleSum.cpp @@ -1,11 +1,11 @@ /* - * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -14,15 +14,17 @@ * limitations under the License. */ -#include <luci/Service/CircleShapeSignatureInference.h> +#include "CircleCloneNode.h" namespace luci { -ShapeSignature ssinf::Algorithm::visit(const luci::CircleSum *node) +luci::CircleNode *CloneNode::visit(const luci::CircleSum *node) { - return legalized_signature( - reduced_signature(node->input(), node->reduction_indices(), node->keep_dims())); + auto *cloned = _graph->nodes()->create<luci::CircleSum>(); + if (cloned != nullptr) + cloned->keep_dims(node->keep_dims()); + return cloned; } } // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleSum.test.cpp b/compiler/luci/service/src/Nodes/CircleSum.test.cpp new file mode 100644 index 000000000..aa1b0d128 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleSum.test.cpp @@ -0,0 +1,35 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include <gtest/gtest.h> + +TEST(CloneNodeTest, clone_Sum) +{ + auto g = loco::make_graph(); + auto node_sum = g->nodes()->create<luci::CircleSum>(); + node_sum->keep_dims(true); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_sum, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_sum = dynamic_cast<luci::CircleSum *>(cloned); + ASSERT_NE(nullptr, cloned_sum); + ASSERT_EQ(node_sum->keep_dims(), cloned_sum->keep_dims()); +} diff --git a/compiler/luci/service/src/Nodes/CircleTanh.cpp b/compiler/luci/service/src/Nodes/CircleTanh.cpp new file mode 100644 index 000000000..9cb35932f --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleTanh.cpp @@ -0,0 +1,27 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleCloneNode.h" + +namespace luci +{ + +luci::CircleNode *CloneNode::visit(const luci::CircleTanh *) +{ + return _graph->nodes()->create<luci::CircleTanh>(); +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleTanh.test.cpp b/compiler/luci/service/src/Nodes/CircleTanh.test.cpp new file mode 100644 index 000000000..0215b42ca --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleTanh.test.cpp @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include <gtest/gtest.h> + +TEST(CloneNodeTest, clone_Tanh) +{ + auto g = loco::make_graph(); + auto node_tanh = g->nodes()->create<luci::CircleTanh>(); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_tanh, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_tanh = dynamic_cast<luci::CircleTanh *>(cloned); + ASSERT_NE(nullptr, cloned_tanh); +} diff --git a/compiler/luci/service/src/Nodes/CircleTile.cpp b/compiler/luci/service/src/Nodes/CircleTile.cpp new file mode 100644 index 000000000..21c32e021 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleTile.cpp @@ -0,0 +1,27 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleCloneNode.h" + +namespace luci +{ + +luci::CircleNode *CloneNode::visit(const luci::CircleTile *) +{ + return _graph->nodes()->create<luci::CircleTile>(); +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleTile.test.cpp b/compiler/luci/service/src/Nodes/CircleTile.test.cpp new file mode 100644 index 000000000..089c86ccb --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleTile.test.cpp @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include <gtest/gtest.h> + +TEST(CloneNodeTest, clone_Tile) +{ + auto g = loco::make_graph(); + auto node_tile = g->nodes()->create<luci::CircleTile>(); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_tile, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_tile = dynamic_cast<luci::CircleTile *>(cloned); + ASSERT_NE(nullptr, cloned_tile); +} diff --git a/compiler/luci/service/src/Nodes/CircleTopKV2.cpp b/compiler/luci/service/src/Nodes/CircleTopKV2.cpp new file mode 100644 index 000000000..e940c03dd --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleTopKV2.cpp @@ -0,0 +1,27 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleCloneNode.h" + +namespace luci +{ + +luci::CircleNode *CloneNode::visit(const luci::CircleTopKV2 *) +{ + return _graph->nodes()->create<luci::CircleTopKV2>(); +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleTopKV2.test.cpp b/compiler/luci/service/src/Nodes/CircleTopKV2.test.cpp new file mode 100644 index 000000000..7f68a408d --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleTopKV2.test.cpp @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include <gtest/gtest.h> + +TEST(CloneNodeTest, clone_TopKV2) +{ + auto g = loco::make_graph(); + auto node_top = g->nodes()->create<luci::CircleTopKV2>(); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_top, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_top = dynamic_cast<luci::CircleTopKV2 *>(cloned); + ASSERT_NE(nullptr, cloned_top); +} diff --git a/compiler/luci/service/src/Nodes/CircleTopKV2Out.cpp b/compiler/luci/service/src/Nodes/CircleTopKV2Out.cpp new file mode 100644 index 000000000..5c13f2be1 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleTopKV2Out.cpp @@ -0,0 +1,30 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleCloneNode.h" + +namespace luci +{ + +luci::CircleNode *CloneNode::visit(const luci::CircleTopKV2Out *node) +{ + auto *cloned = _graph->nodes()->create<luci::CircleTopKV2Out>(); + if (cloned != nullptr) + cloned->index(node->index()); + return cloned; +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleTopKV2Out.test.cpp b/compiler/luci/service/src/Nodes/CircleTopKV2Out.test.cpp new file mode 100644 index 000000000..cfba61f10 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleTopKV2Out.test.cpp @@ -0,0 +1,35 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include <gtest/gtest.h> + +TEST(CloneNodeTest, clone_TopKV2Out) +{ + auto g = loco::make_graph(); + auto node_tout = g->nodes()->create<luci::CircleTopKV2Out>(); + node_tout->index(1); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_tout, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_tout = dynamic_cast<luci::CircleTopKV2Out *>(cloned); + ASSERT_NE(nullptr, cloned_tout); + ASSERT_EQ(node_tout->index(), cloned_tout->index()); +} diff --git a/compiler/luci/service/src/Nodes/CircleTranspose.cpp b/compiler/luci/service/src/Nodes/CircleTranspose.cpp new file mode 100644 index 000000000..81db55269 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleTranspose.cpp @@ -0,0 +1,27 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleCloneNode.h" + +namespace luci +{ + +luci::CircleNode *CloneNode::visit(const luci::CircleTranspose *) +{ + return _graph->nodes()->create<luci::CircleTranspose>(); +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleTranspose.test.cpp b/compiler/luci/service/src/Nodes/CircleTranspose.test.cpp new file mode 100644 index 000000000..9447d1a5b --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleTranspose.test.cpp @@ -0,0 +1,69 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include <luci/IR/CircleNodes.h> +#include <luci/Service/CircleShapeInference.h> + +#include <loco/IR/TensorShape.h> + +#include <gtest/gtest.h> + +TEST(ShapeRuleTest, transpose_simple) +{ + luci::CircleInput input; + luci::CircleConst perm; + luci::CircleTranspose transpose; + + input.shape({3, 8, 1}); + input.shape_status(luci::ShapeStatus::VALID); + + perm.dtype(loco::DataType::S32); + perm.rank(1); + perm.dim(0).set(3); + perm.size<loco::DataType::S32>(3); + perm.at<loco::DataType::S32>(0) = 1; + perm.at<loco::DataType::S32>(1) = 2; + perm.at<loco::DataType::S32>(2) = 0; + perm.shape_status(luci::ShapeStatus::VALID); + + transpose.a(&input); + transpose.perm(&perm); + + loco::TensorShape shape; + luci::sinf::Rule shape_inf_rule; + + ASSERT_TRUE(shape_inf_rule.infer(&transpose, shape)); + ASSERT_EQ(3, shape.rank()); + ASSERT_EQ(8, shape.dim(0).value()); + ASSERT_EQ(1, shape.dim(1).value()); + ASSERT_EQ(3, shape.dim(2).value()); +} + +TEST(CloneNodeTest, clone_Transpose) +{ + auto g = loco::make_graph(); + auto node_tr = g->nodes()->create<luci::CircleTranspose>(); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_tr, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_tr = dynamic_cast<luci::CircleTranspose *>(cloned); + ASSERT_NE(nullptr, cloned_tr); +} diff --git a/compiler/luci/service/src/Nodes/CircleTransposeConv.cpp b/compiler/luci/service/src/Nodes/CircleTransposeConv.cpp new file mode 100644 index 000000000..1fe41bdb2 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleTransposeConv.cpp @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleCloneNode.h" + +namespace luci +{ + +luci::CircleNode *CloneNode::visit(const luci::CircleTransposeConv *node) +{ + if (node->padding() == luci::Padding::UNDEFINED) + return nullptr; + + auto *cloned = _graph->nodes()->create<luci::CircleTransposeConv>(); + if (cloned != nullptr) + { + cloned->padding(node->padding()); + cloned->stride()->h(node->stride()->h()); + cloned->stride()->w(node->stride()->w()); + } + return cloned; +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleTransposeConv.test.cpp b/compiler/luci/service/src/Nodes/CircleTransposeConv.test.cpp new file mode 100644 index 000000000..29a656c03 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleTransposeConv.test.cpp @@ -0,0 +1,46 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include <gtest/gtest.h> + +TEST(CloneNodeTest, clone_TransposeConv) +{ + auto g = loco::make_graph(); + auto node_trconv = g->nodes()->create<luci::CircleTransposeConv>(); + node_trconv->padding(luci::Padding::SAME); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_trconv, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_trconv = dynamic_cast<luci::CircleTransposeConv *>(cloned); + ASSERT_NE(nullptr, cloned_trconv); + ASSERT_EQ(node_trconv->padding(), cloned_trconv->padding()); +} + +TEST(CloneNodeTest, clone_TransposeConv_padding_NEG) +{ + auto g = loco::make_graph(); + auto node_trconv = g->nodes()->create<luci::CircleTransposeConv>(); + node_trconv->padding(luci::Padding::UNDEFINED); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_trconv, gc.get()); + ASSERT_EQ(nullptr, cloned); +} diff --git a/compiler/luci/service/src/Nodes/CircleUnidirectionalSequenceLSTM.cpp b/compiler/luci/service/src/Nodes/CircleUnidirectionalSequenceLSTM.cpp new file mode 100644 index 000000000..12205f3b0 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleUnidirectionalSequenceLSTM.cpp @@ -0,0 +1,39 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleCloneNode.h" + +namespace luci +{ + +luci::CircleNode *CloneNode::visit(const luci::CircleUnidirectionalSequenceLSTM *node) +{ + if (node->fusedActivationFunction() == luci::FusedActFunc::UNDEFINED) + return nullptr; + + auto *cloned = _graph->nodes()->create<luci::CircleUnidirectionalSequenceLSTM>(); + if (cloned != nullptr) + { + cloned->fusedActivationFunction(node->fusedActivationFunction()); + cloned->cell_clip(node->cell_clip()); + cloned->proj_clip(node->proj_clip()); + cloned->time_major(node->time_major()); + cloned->asymmetric_quantize_inputs(node->asymmetric_quantize_inputs()); + } + return cloned; +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleUnidirectionalSequenceLSTM.test.cpp b/compiler/luci/service/src/Nodes/CircleUnidirectionalSequenceLSTM.test.cpp new file mode 100644 index 000000000..c3816ab27 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleUnidirectionalSequenceLSTM.test.cpp @@ -0,0 +1,54 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include <gtest/gtest.h> + +TEST(CloneNodeTest, clone_UnidirectionalSequenceLSTM) +{ + auto g = loco::make_graph(); + auto node_uslstm = g->nodes()->create<luci::CircleUnidirectionalSequenceLSTM>(); + node_uslstm->fusedActivationFunction(luci::FusedActFunc::RELU); + node_uslstm->cell_clip(1.1f); + node_uslstm->proj_clip(2.2f); + node_uslstm->time_major(true); + node_uslstm->asymmetric_quantize_inputs(true); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_uslstm, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_uslstm = dynamic_cast<luci::CircleUnidirectionalSequenceLSTM *>(cloned); + ASSERT_NE(nullptr, cloned_uslstm); + ASSERT_EQ(node_uslstm->fusedActivationFunction(), cloned_uslstm->fusedActivationFunction()); + ASSERT_EQ(node_uslstm->cell_clip(), cloned_uslstm->cell_clip()); + ASSERT_EQ(node_uslstm->proj_clip(), cloned_uslstm->proj_clip()); + ASSERT_EQ(node_uslstm->time_major(), cloned_uslstm->time_major()); + ASSERT_EQ(node_uslstm->asymmetric_quantize_inputs(), cloned_uslstm->asymmetric_quantize_inputs()); +} + +TEST(CloneNodeTest, clone_UnidirectionalSequenceLSTM_NEG) +{ + auto g = loco::make_graph(); + auto node_uslstm = g->nodes()->create<luci::CircleUnidirectionalSequenceLSTM>(); + node_uslstm->fusedActivationFunction(luci::FusedActFunc::UNDEFINED); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_uslstm, gc.get()); + ASSERT_EQ(nullptr, cloned); +} diff --git a/compiler/luci/service/src/Nodes/CircleUnique.cpp b/compiler/luci/service/src/Nodes/CircleUnique.cpp new file mode 100644 index 000000000..bde2ea0dc --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleUnique.cpp @@ -0,0 +1,30 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleCloneNode.h" + +namespace luci +{ + +luci::CircleNode *CloneNode::visit(const luci::CircleUnique *node) +{ + auto *cloned = _graph->nodes()->create<luci::CircleUnique>(); + if (cloned != nullptr) + cloned->idx_out_type(node->idx_out_type()); + return cloned; +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleUnique.test.cpp b/compiler/luci/service/src/Nodes/CircleUnique.test.cpp new file mode 100644 index 000000000..a8ff9eade --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleUnique.test.cpp @@ -0,0 +1,35 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include <gtest/gtest.h> + +TEST(CloneNodeTest, clone_Unique) +{ + auto g = loco::make_graph(); + auto node_uniq = g->nodes()->create<luci::CircleUnique>(); + node_uniq->idx_out_type(loco::DataType::S32); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_uniq, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_uniq = dynamic_cast<luci::CircleUnique *>(cloned); + ASSERT_NE(nullptr, cloned_uniq); + ASSERT_EQ(node_uniq->idx_out_type(), cloned_uniq->idx_out_type()); +} diff --git a/compiler/luci/service/src/Nodes/CircleUniqueOut.cpp b/compiler/luci/service/src/Nodes/CircleUniqueOut.cpp new file mode 100644 index 000000000..30093f9db --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleUniqueOut.cpp @@ -0,0 +1,30 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleCloneNode.h" + +namespace luci +{ + +luci::CircleNode *CloneNode::visit(const luci::CircleUniqueOut *node) +{ + auto *cloned = _graph->nodes()->create<luci::CircleUniqueOut>(); + if (cloned != nullptr) + cloned->index(node->index()); + return cloned; +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleUniqueOut.test.cpp b/compiler/luci/service/src/Nodes/CircleUniqueOut.test.cpp new file mode 100644 index 000000000..780ad4b78 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleUniqueOut.test.cpp @@ -0,0 +1,35 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include <gtest/gtest.h> + +TEST(CloneNodeTest, clone_UniqueOut) +{ + auto g = loco::make_graph(); + auto node_uout = g->nodes()->create<luci::CircleUniqueOut>(); + node_uout->index(1); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_uout, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_uout = dynamic_cast<luci::CircleUniqueOut *>(cloned); + ASSERT_NE(nullptr, cloned_uout); + ASSERT_EQ(node_uout->index(), cloned_uout->index()); +} diff --git a/compiler/luci/service/src/Nodes/CircleUnpack.cpp b/compiler/luci/service/src/Nodes/CircleUnpack.cpp new file mode 100644 index 000000000..f9d61c426 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleUnpack.cpp @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleCloneNode.h" + +namespace luci +{ + +luci::CircleNode *CloneNode::visit(const luci::CircleUnpack *node) +{ + auto *cloned = _graph->nodes()->create<luci::CircleUnpack>(); + if (cloned != nullptr) + { + cloned->num(node->num()); + cloned->axis(node->axis()); + } + return cloned; +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleUnpack.test.cpp b/compiler/luci/service/src/Nodes/CircleUnpack.test.cpp new file mode 100644 index 000000000..6559a9276 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleUnpack.test.cpp @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include <gtest/gtest.h> + +TEST(CloneNodeTest, clone_Unpack) +{ + auto g = loco::make_graph(); + auto node_unp = g->nodes()->create<luci::CircleUnpack>(); + node_unp->num(1); + node_unp->axis(2); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_unp, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_unp = dynamic_cast<luci::CircleUnpack *>(cloned); + ASSERT_NE(nullptr, cloned_unp); + ASSERT_EQ(node_unp->num(), cloned_unp->num()); + ASSERT_EQ(node_unp->axis(), cloned_unp->axis()); +} diff --git a/compiler/luci/service/src/Nodes/CircleUnpackOut.cpp b/compiler/luci/service/src/Nodes/CircleUnpackOut.cpp new file mode 100644 index 000000000..342d5daca --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleUnpackOut.cpp @@ -0,0 +1,30 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleCloneNode.h" + +namespace luci +{ + +luci::CircleNode *CloneNode::visit(const luci::CircleUnpackOut *node) +{ + auto *cloned = _graph->nodes()->create<luci::CircleUnpackOut>(); + if (cloned != nullptr) + cloned->index(node->index()); + return cloned; +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleUnpackOut.test.cpp b/compiler/luci/service/src/Nodes/CircleUnpackOut.test.cpp new file mode 100644 index 000000000..ec9bb974e --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleUnpackOut.test.cpp @@ -0,0 +1,35 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include <gtest/gtest.h> + +TEST(CloneNodeTest, clone_UnpackOut) +{ + auto g = loco::make_graph(); + auto node_uout = g->nodes()->create<luci::CircleUnpackOut>(); + node_uout->index(1); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_uout, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_uout = dynamic_cast<luci::CircleUnpackOut *>(cloned); + ASSERT_NE(nullptr, cloned_uout); + ASSERT_EQ(node_uout->index(), cloned_uout->index()); +} diff --git a/compiler/luci/service/src/Nodes/CircleWhere.cpp b/compiler/luci/service/src/Nodes/CircleWhere.cpp new file mode 100644 index 000000000..73f4b64ac --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleWhere.cpp @@ -0,0 +1,27 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleCloneNode.h" + +namespace luci +{ + +luci::CircleNode *CloneNode::visit(const luci::CircleWhere *) +{ + return _graph->nodes()->create<luci::CircleWhere>(); +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleWhere.test.cpp b/compiler/luci/service/src/Nodes/CircleWhere.test.cpp new file mode 100644 index 000000000..352719d85 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleWhere.test.cpp @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include <gtest/gtest.h> + +TEST(CloneNodeTest, clone_Where) +{ + auto g = loco::make_graph(); + auto node_wh = g->nodes()->create<luci::CircleWhere>(); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_wh, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_wh = dynamic_cast<luci::CircleWhere *>(cloned); + ASSERT_NE(nullptr, cloned_wh); +} diff --git a/compiler/luci/service/src/Nodes/CircleZerosLike.cpp b/compiler/luci/service/src/Nodes/CircleZerosLike.cpp new file mode 100644 index 000000000..2ee455857 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleZerosLike.cpp @@ -0,0 +1,27 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleCloneNode.h" + +namespace luci +{ + +luci::CircleNode *CloneNode::visit(const luci::CircleZerosLike *) +{ + return _graph->nodes()->create<luci::CircleZerosLike>(); +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleZerosLike.test.cpp b/compiler/luci/service/src/Nodes/CircleZerosLike.test.cpp new file mode 100644 index 000000000..6e0a4b3be --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleZerosLike.test.cpp @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include <gtest/gtest.h> + +TEST(CloneNodeTest, clone_ZerosLike) +{ + auto g = loco::make_graph(); + auto node_zl = g->nodes()->create<luci::CircleZerosLike>(); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_zl, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_zl = dynamic_cast<luci::CircleZerosLike *>(cloned); + ASSERT_NE(nullptr, cloned_zl); +} diff --git a/compiler/luci/service/src/ShapeDescription.cpp b/compiler/luci/service/src/ShapeDescription.cpp index 01a638f8f..adfb7e342 100644 --- a/compiler/luci/service/src/ShapeDescription.cpp +++ b/compiler/luci/service/src/ShapeDescription.cpp @@ -31,7 +31,7 @@ ShapeDescription to_shape_description(const luci::CircleNode *circle_node) res._dims.resize(circle_node->rank()); for (uint32_t i = 0; i < circle_node->rank(); ++i) - res._dims.at(i) = circle_node->dim(i).value(); + res._dims.at(i) = circle_node->dim(i).known() ? circle_node->dim(i).value() : -1; return res; } @@ -53,95 +53,12 @@ ShapeDescription to_shape_description(const loco::TensorShape &shape) return res; } -ShapeDescription to_shape_description(const loco::FeatureShape &shape) -{ - ShapeDescription res; - - res._rank_known = true; - - // T/F Lite encodes a feature map as a NHWC tensor - res._dims.resize(4); - res._dims.at(0) = shape.count().value(); - res._dims.at(1) = shape.height().value(); - res._dims.at(2) = shape.width().value(); - res._dims.at(3) = shape.depth().value(); - - return res; -} - -ShapeDescription to_shape_description(const loco::FilterShape &shape) -{ - ShapeDescription res; - - res._rank_known = true; - - // T/F Lite encodes a convolution filter as a NHWC tensor - res._dims.resize(4); - res._dims.at(0) = shape.count().value(); - res._dims.at(1) = shape.height().value(); - res._dims.at(2) = shape.width().value(); - res._dims.at(3) = shape.depth().value(); - - return res; -} - -ShapeDescription to_shape_description(const loco::DepthwiseFilterShape &shape) -{ - ShapeDescription res; - - res._rank_known = true; - - // T/F Lite encodes a depthwise convolution filter as a [1, H, W, C*M] tensor - res._dims.resize(4); - res._dims.at(0) = 1; - res._dims.at(1) = shape.height().value(); - res._dims.at(2) = shape.width().value(); - res._dims.at(3) = shape.depth().value() * shape.multiplier().value(); - - return res; -} - -ShapeDescription to_shape_description(const loco::BiasShape &shape) -{ - ShapeDescription res; - - res._rank_known = true; - - res._dims.resize(1); - res._dims.at(0) = shape.length().value(); - - return res; -} - -ShapeDescription to_shape_description(const loco::MatrixShape &shape) -{ - ShapeDescription res; - - res._rank_known = true; - - res._dims.resize(2); - res._dims.at(0) = shape.height().value(); - res._dims.at(1) = shape.width().value(); - - return res; -} - ShapeDescription to_shape_description(const loco::NodeShape &shape) { switch (shape.domain()) { case loco::Domain::Tensor: return to_shape_description(shape.as<loco::TensorShape>()); - case loco::Domain::Feature: - return to_shape_description(shape.as<loco::FeatureShape>()); - case loco::Domain::Filter: - return to_shape_description(shape.as<loco::FilterShape>()); - case loco::Domain::DepthwiseFilter: - return to_shape_description(shape.as<loco::DepthwiseFilterShape>()); - case loco::Domain::Bias: - return to_shape_description(shape.as<loco::BiasShape>()); - case loco::Domain::Matrix: - return to_shape_description(shape.as<loco::MatrixShape>()); default: break; } diff --git a/compiler/luci/service/src/ShapeDescription.test.cpp b/compiler/luci/service/src/ShapeDescription.test.cpp new file mode 100644 index 000000000..6e53aac75 --- /dev/null +++ b/compiler/luci/service/src/ShapeDescription.test.cpp @@ -0,0 +1,56 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/ShapeDescription.h" + +#include <luci/IR/CircleNode.h> +#include <luci/IR/Nodes/CircleConst.h> + +#include <gtest/gtest.h> + +TEST(ShapeDescriptionTest, CircleNode) +{ + // Use CircleConst as CircleNode + luci::CircleConst circle_const; + circle_const.shape({1, 2, 3, 4}); + + auto sd = luci::to_shape_description(&circle_const); + + ASSERT_EQ(4, sd._dims.size()); + ASSERT_EQ(1, sd._dims.at(0)); + ASSERT_TRUE(sd._rank_known); +} + +TEST(ShapeDescriptionTest, TensorShape) +{ + loco::TensorShape tensor_shape{1, 2, 3, 4}; + loco::NodeShape node_shape(tensor_shape); + + auto sd = luci::to_shape_description(node_shape); + + ASSERT_EQ(4, sd._dims.size()); + ASSERT_EQ(1, sd._dims.at(0)); + ASSERT_TRUE(sd._rank_known); +} + +TEST(ShapeDescriptionTest, BiasShape_NEG) +{ + loco::BiasShape bias_shape; + bias_shape.length() = 1; + loco::NodeShape node_shape(bias_shape); + + EXPECT_THROW(luci::to_shape_description(node_shape), std::exception); +} diff --git a/compiler/luci/service/src/ShapeInfer_StridedSlice.cpp b/compiler/luci/service/src/ShapeInfer_StridedSlice.cpp index 341201148..c5864f938 100644 --- a/compiler/luci/service/src/ShapeInfer_StridedSlice.cpp +++ b/compiler/luci/service/src/ShapeInfer_StridedSlice.cpp @@ -17,12 +17,12 @@ #include "ShapeInfer_StridedSlice.h" #include "Check.h" +#include "CircleShapeInferenceHelper.h" #include <luci/IR/CircleNode.h> #include <loco/IR/DataType.h> #include <loco/IR/NodeShape.h> #include <oops/InternalExn.h> -#include <loco/Service/ShapeInference.h> #include <cmath> #include <cstdint> @@ -245,7 +245,7 @@ loco::TensorShape infer_output_shape(const CircleStridedSlice *node) assert(node->new_axis_mask() == 0); auto op_params = BuildStridedSliceParams(node); - loco::TensorShape input_shape = loco::shape_get(input_node).as<loco::TensorShape>(); + loco::TensorShape input_shape = luci::shape_get(input_node).as<loco::TensorShape>(); uint32_t num_input_axes = input_shape.rank(); assert(begin_node->size<S32>() <= num_input_axes); diff --git a/compiler/luci/service/src/Validate.cpp b/compiler/luci/service/src/Validate.cpp index 3f732b6fe..7ed14c356 100644 --- a/compiler/luci/service/src/Validate.cpp +++ b/compiler/luci/service/src/Validate.cpp @@ -20,10 +20,9 @@ #include <luci/Log.h> #include <loco/IR/NodeShape.h> -#include <loco/Service/ShapeInference.h> -#include <loco/Service/TypeInference.h> #include <cassert> +#include <unordered_map> #include <vector> namespace @@ -36,7 +35,11 @@ std::ostream &operator<<(std::ostream &os, const loco::TensorShape &tensor_shape { if (r) os << ","; - os << tensor_shape.dim(r).value(); + + if (tensor_shape.dim(r).known()) + os << tensor_shape.dim(r).value(); + else + os << "?"; } os << "]"; return os; @@ -49,7 +52,11 @@ std::ostream &operator<<(std::ostream &os, const luci::CircleNode *circle_node) { if (r) os << ","; - os << circle_node->dim(r).value(); + + if (circle_node->dim(r).known()) + os << circle_node->dim(r).value(); + else + os << "?"; } os << "]"; return os; @@ -99,10 +106,24 @@ bool validate_shape_dtype(loco::Graph *g) auto go_tensor_shape = graph_out->shape(); assert(go_tensor_shape); + // NOTE Even if shape of graph output is [] (which means "shape inference was impossible") + // but shape of CircleNode is not, it can be valid case because shape inference + // algorithm of CircleNode may be upgraded than before. The opposite is possible either. + // If such cases are appeared, following validation code should be fixed. bool is_shape_valid = (circle_node->rank() == go_tensor_shape->rank()); for (uint32_t i = 0; is_shape_valid && i < circle_node->rank(); ++i) - if (circle_node->dim(i).value() != go_tensor_shape->dim(i).value()) + { + if (!circle_node->dim(i).known() || !go_tensor_shape->dim(i).known()) + { + // If at least one of two dimensions is unknown, + // the unknown dimension can accept any value. + INFO(l) << "Unknown dimension is matched with known dimension" << std::endl; + } + else if (circle_node->dim(i).value() != go_tensor_shape->dim(i).value()) + { is_shape_valid = false; + } + } if (is_shape_valid == false) { @@ -124,72 +145,62 @@ bool validate_shape_dtype(loco::Graph *g) return true; } -bool validate_shape_signature(loco::Graph *g) -{ - LOGGER(l); - - for (auto node : loco::postorder_traversal(loco::output_nodes(g))) - { - auto circle_node = loco::must_cast<luci::CircleNode *>(node); - const auto shape_signature = circle_node->shape_signature(); +} // namespace - if (shape_signature.rank() == 0) - continue; +namespace luci +{ - // Rank of shape and shape signature should be same - if (circle_node->rank() != shape_signature.rank()) - { - INFO(l) << "[luci] Rank of shape signature for " << circle_node->name() << " do not match" - << std::endl; - return false; - } +bool validate(loco::Graph *g) +{ + if (!loco::valid(g)) + return false; - bool has_unknown = false; + if (!validate_shape_dtype(g)) + return false; - // If shape siganture is not -1, dimension value should be same - for (uint32_t d = 0; d < shape_signature.rank(); ++d) - { - if (shape_signature.dim(d) != -1 && - shape_signature.dim(d) != (int32_t)(circle_node->dim(d).value())) - { - INFO(l) << "[luci] Dimension " << d << "of shape signature for " << circle_node->name() - << " do not match" << std::endl; - return false; - } + // TODO add more validation - if (shape_signature.dim(d) == -1) - has_unknown = true; - } + return true; +} - // Shape signature should have at least one -1 value. - if (!has_unknown) - { - INFO(l) << "[luci] Shape signature in " << circle_node->name() - << " do not have unknown dimension" << std::endl; +bool validate_name(loco::Graph *g) +{ + auto nodes = g->nodes(); + for (uint32_t n = 0; n < nodes->size(); ++n) + { + auto node = loco::must_cast<luci::CircleNode *>(nodes->at(n)); + auto name = node->name(); + if (name.empty()) return false; - } } return true; } -} // namespace - -namespace luci +bool validate_unique_name(luci::Module *m) { + std::unordered_map<std::string, bool> names_col; -bool validate(loco::Graph *g) -{ - if (!loco::valid(g)) - return false; - - if (!validate_shape_dtype(g)) - return false; - - if (!validate_shape_signature(g)) - return false; + for (size_t g = 0; g < m->size(); ++g) + { + auto graph = m->graph(g); + auto nodes = graph->nodes(); + for (uint32_t n = 0; n < nodes->size(); ++n) + { + auto node = loco::must_cast<luci::CircleNode *>(nodes->at(n)); + // skip CircleOutput as it may have same name with from() node + auto output = dynamic_cast<luci::CircleOutput *>(node); + if (output != nullptr) + continue; + + auto name = node->name(); + auto it = names_col.find(name); + if (it != names_col.end()) + return false; - // TODO add more validation + names_col[name] = true; + } + } return true; } diff --git a/compiler/luci/service/src/Validate.test.cpp b/compiler/luci/service/src/Validate.test.cpp new file mode 100644 index 000000000..8ce6d895b --- /dev/null +++ b/compiler/luci/service/src/Validate.test.cpp @@ -0,0 +1,139 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/Validate.h" + +#include <luci/test/TestIOGraph.h> + +#include <luci/IR/Nodes/CircleAdd.h> +#include <luci/IR/Nodes/CircleSqrt.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class SqrtGraphlet +{ +public: + SqrtGraphlet() = default; + +public: + void init(loco::Graph *g, const ShapeU32 input_shape) + { + _sqrt = g->nodes()->create<luci::CircleSqrt>(); + _sqrt->dtype(loco::DataType::S32); + _sqrt->name("sqrt"); + } + +protected: + luci::CircleSqrt *_sqrt = nullptr; +}; + +class SqrtGraph : public TestIOGraph, public SqrtGraphlet +{ +public: + SqrtGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIOGraph::init(shape, shape); + SqrtGraphlet::init(g(), shape); + + _sqrt->x(input()); + + output()->from(_sqrt); + + // set output name to _sqrt: CircleOutput may have duplicate name + output()->name(_sqrt->name()); + } +}; + +class Sqrt2xGraphlet +{ +public: + Sqrt2xGraphlet() = default; + +public: + void init(loco::Graph *g, const ShapeU32 input_shape) + { + _sqrt1 = g->nodes()->create<luci::CircleSqrt>(); + _sqrt1->dtype(loco::DataType::S32); + _sqrt1->name("sqrt"); + + _sqrt2 = g->nodes()->create<luci::CircleSqrt>(); + _sqrt2->dtype(loco::DataType::S32); + _sqrt2->name("sqrt"); + } + +protected: + luci::CircleSqrt *_sqrt1 = nullptr; + luci::CircleSqrt *_sqrt2 = nullptr; +}; + +class Sqrt2xGraph : public TestIOGraph, public Sqrt2xGraphlet +{ +public: + Sqrt2xGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIOGraph::init(shape, shape); + Sqrt2xGraphlet::init(g(), shape); + + _sqrt1->x(input()); + + _sqrt2->x(_sqrt1); + + output()->from(_sqrt2); + } +}; + +} // namespace + +TEST(ValidateTest, non_empty_name) +{ + SqrtGraph g; + g.init({3, 3}); + + ASSERT_TRUE(luci::validate_name(g.g())); +} + +TEST(ValidateTest, unique_name) +{ + luci::Module module; + + SqrtGraph g; + g.init({3, 3}); + g.transfer_to(&module); + + ASSERT_TRUE(luci::validate_unique_name(&module)); +} + +TEST(ValidateTest, unique_name_NEG) +{ + luci::Module module; + + Sqrt2xGraph g; + g.init({3, 3}); + g.transfer_to(&module); + + ASSERT_FALSE(luci::validate_unique_name(&module)); +} diff --git a/compiler/luci/tester/CMakeLists.txt b/compiler/luci/tester/CMakeLists.txt index 3ac06ef3a..13aab11e7 100644 --- a/compiler/luci/tester/CMakeLists.txt +++ b/compiler/luci/tester/CMakeLists.txt @@ -6,6 +6,7 @@ TargetRequire_Return(${REQUIRED_TARGETS}) set(SRCS_READ_TESTER src/ReadTester.cpp + src/ReadModule.cpp ) add_executable(luci_readtester "${SRCS_READ_TESTER}") @@ -18,6 +19,7 @@ target_link_libraries(luci_readtester PRIVATE safemain) set(SRCS_WRITE_TESTER src/WriteTester.cpp + src/ReadModule.cpp ) add_executable(luci_writetester "${SRCS_WRITE_TESTER}") @@ -28,3 +30,22 @@ target_link_libraries(luci_writetester PRIVATE luci_export) target_link_libraries(luci_writetester PRIVATE foder) target_link_libraries(luci_writetester PRIVATE oops) target_link_libraries(luci_writetester PRIVATE safemain) + +if(NOT ENABLE_TEST) + return() +endif(NOT ENABLE_TEST) + +nnas_find_package(GTest REQUIRED) + +GTest_AddTest(luci_readtester_test src/ReadTester.test.cpp ${SRCS_READ_TESTER}) +target_link_libraries(luci_readtester_test luci_import) +target_link_libraries(luci_readtester_test luci_service) +target_link_libraries(luci_readtester_test luci_pass) +target_link_libraries(luci_readtester_test foder) + +GTest_AddTest(luci_writetester_test src/WriteTester.test.cpp ${SRCS_WRITE_TESTER}) +target_link_libraries(luci_writetester_test luci_import) +target_link_libraries(luci_writetester_test luci_service) +target_link_libraries(luci_writetester_test luci_pass) +target_link_libraries(luci_writetester_test luci_export) +target_link_libraries(luci_writetester_test foder) diff --git a/compiler/luci/tester/src/ReadModule.cpp b/compiler/luci/tester/src/ReadModule.cpp new file mode 100644 index 000000000..87c1233f0 --- /dev/null +++ b/compiler/luci/tester/src/ReadModule.cpp @@ -0,0 +1,65 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ReadModule.h" + +#include <luci/Pass/CircleShapeInferencePass.h> +#include <luci/Pass/CircleTypeInferencePass.h> +#include <luci/Service/Validate.h> + +#include <logo/Phase.h> + +#include <iostream> +#include <string> +#include <vector> + +std::unique_ptr<luci::Module> ReadModule(std::string &input_path) +{ + // Load model from the file + foder::FileLoader file_loader{input_path}; + std::vector<char> model_data = file_loader.load(); + const circle::Model *circle_model = circle::GetModel(model_data.data()); + if (circle_model == nullptr) + { + std::cerr << "ERROR: Failed to load circle '" << input_path << "'" << std::endl; + return nullptr; + } + + luci::Importer importer; + auto module = importer.importModule(circle_model); + assert(module->size() > 0); + + for (size_t g = 0; g < module->size(); ++g) + { + auto graph = module->graph(g); + if (graph == nullptr) + return nullptr; + + { + logo::Phase phase; + + phase.emplace_back(std::make_unique<luci::CircleShapeInferencePass>()); + phase.emplace_back(std::make_unique<luci::CircleTypeInferencePass>()); + + logo::PhaseRunner<logo::PhaseStrategy::Saturate> phase_runner{graph}; + phase_runner.run(phase); + } + + if (!luci::validate(graph)) + return nullptr; + } + return module; +} diff --git a/compiler/luci/tester/src/ReadModule.h b/compiler/luci/tester/src/ReadModule.h new file mode 100644 index 000000000..dfa9bad6b --- /dev/null +++ b/compiler/luci/tester/src/ReadModule.h @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_TESTER_READ_MODULE_H__ +#define __LUCI_TESTER_READ_MODULE_H__ + +#include <luci/Importer.h> +#include <foder/FileLoader.h> + +#include <memory> +#include <string> + +std::unique_ptr<luci::Module> ReadModule(std::string &input_path); + +#endif // __LUCI_TESTER_READ_MODULE_H__ diff --git a/compiler/luci/tester/src/ReadTester.cpp b/compiler/luci/tester/src/ReadTester.cpp index f270a232c..864343e43 100644 --- a/compiler/luci/tester/src/ReadTester.cpp +++ b/compiler/luci/tester/src/ReadTester.cpp @@ -14,18 +14,9 @@ * limitations under the License. */ -#include <foder/FileLoader.h> - -#include <luci/Importer.h> -#include <luci/Service/Validate.h> -#include <luci/Pass/ShapeInferencePass.h> -#include <luci/Pass/TypeInferencePass.h> - -// Following passes will be removed after refactoring is finished -#include <luci/Pass/MigrateLegacyShapeDtypePass.h> +#include "ReadModule.h" #include <iostream> -#include <map> #include <string> namespace @@ -68,45 +59,9 @@ int entry(int argc, char **argv) std::cout << "[INFO] Circle is '" << input_path << "'" << std::endl; - // Load model from the file - foder::FileLoader file_loader{input_path}; - std::vector<char> model_data = file_loader.load(); - const circle::Model *circle_model = circle::GetModel(model_data.data()); - if (circle_model == nullptr) - { - std::cerr << "ERROR: Failed to load circle '" << input_path << "'" << std::endl; + auto module = ReadModule(input_path); + if (module == nullptr) return EXIT_FAILURE; - } - - luci::Importer importer; - auto module = importer.importModule(circle_model); - assert(module->size() > 0); - for (size_t g = 0; g < module->size(); ++g) - { - auto graph = module->graph(g); - if (graph == nullptr) - return 255; - - { - luci::ShapeInferencePass pass; - while (pass.run(graph) == true) - ; - } - { - luci::TypeInferencePass pass; - while (pass.run(graph) == true) - ; - } - { - // This pass will be removed after refactoring is finished - luci::MigrateLegacyShapeDtypePass pass; - while (pass.run(graph) == true) - ; - } - - if (!luci::validate(graph)) - return 255; - } return 0; } diff --git a/compiler/luci/tester/src/ReadTester.test.cpp b/compiler/luci/tester/src/ReadTester.test.cpp new file mode 100644 index 000000000..f3850d517 --- /dev/null +++ b/compiler/luci/tester/src/ReadTester.test.cpp @@ -0,0 +1,43 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <gtest/gtest.h> + +// From ReadTester.cpp +int entry(int argc, char **argv); + +TEST(ReadTesterTest, invalid_argc_NEG) +{ + char argv_1[20]; + strcpy(argv_1, "ReadTesterTest"); + + int argc = 1; + char *argv[] = {argv_1}; + + ASSERT_NE(0, entry(argc, argv)); +} + +TEST(ReadTesterTest, invalid_file_NEG) +{ + char argv_1[20], argv_2[20]; + strcpy(argv_1, "ReadTesterTest"); + strcpy(argv_2, "not_a_file"); + + int argc = 2; + char *argv[] = {argv_1, argv_2}; + + EXPECT_THROW(entry(argc, argv), std::runtime_error); +} diff --git a/compiler/luci/tester/src/WriteTester.cpp b/compiler/luci/tester/src/WriteTester.cpp index 9a6e8de05..0d3a1efa2 100644 --- a/compiler/luci/tester/src/WriteTester.cpp +++ b/compiler/luci/tester/src/WriteTester.cpp @@ -14,21 +14,13 @@ * limitations under the License. */ -#include <foder/FileLoader.h> +#include "ReadModule.h" -#include <luci/Importer.h> -#include <luci/Pass/ShapeInferencePass.h> -#include <luci/Pass/TypeInferencePass.h> -#include <luci/Service/Validate.h> #include <luci/CircleExporter.h> #include <oops/InternalExn.h> -// Following passes will be removed after refactoring is finished -#include <luci/Pass/MigrateLegacyShapeDtypePass.h> - #include <fstream> #include <iostream> -#include <map> #include <string> namespace @@ -51,12 +43,12 @@ struct CircleExpContract : public luci::CircleExporter::Contract { public: CircleExpContract(loco::Graph *graph, const std::string &filename) - : _graph(graph), _filepath(filename) + : _graph(graph), _filepath(filename) { // NOTHING TO DO } CircleExpContract(luci::Module *module, const std::string &filename) - : _module(module), _filepath(filename) + : _module(module), _filepath(filename) { // NOTHING TO DO } @@ -111,47 +103,9 @@ int entry(int argc, char **argv) std::cout << "[INFO] Circle from '" << input_path << "' to '" << output_path << "'" << std::endl; - // Load model from the file - foder::FileLoader file_loader{input_path}; - std::vector<char> model_data = file_loader.load(); - const circle::Model *circle_model = circle::GetModel(model_data.data()); - if (circle_model == nullptr) - { - std::cerr << "ERROR: Failed to load circle '" << input_path << "'" << std::endl; + auto module = ReadModule(input_path); + if (module == nullptr) return EXIT_FAILURE; - } - - // Import from input Circle file - luci::Importer importer; - auto module = importer.importModule(circle_model); - assert(module->size() > 0); - - for (size_t g = 0; g < module->size(); ++g) - { - auto graph = module->graph(g); - if (graph == nullptr) - return 255; - - { - luci::ShapeInferencePass pass; - while (pass.run(graph) == true) - ; - } - { - luci::TypeInferencePass pass; - while (pass.run(graph) == true) - ; - } - { - // This pass will be removed after refactoring is finished - luci::MigrateLegacyShapeDtypePass pass; - while (pass.run(graph) == true) - ; - } - - if (!luci::validate(graph)) - return 255; - } // Export to output Circle file luci::CircleExporter exporter; diff --git a/compiler/luci/tester/src/WriteTester.test.cpp b/compiler/luci/tester/src/WriteTester.test.cpp new file mode 100644 index 000000000..9d34c5f98 --- /dev/null +++ b/compiler/luci/tester/src/WriteTester.test.cpp @@ -0,0 +1,44 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <gtest/gtest.h> + +// From WriteTester.cpp +int entry(int argc, char **argv); + +TEST(WriteTesterTest, invalid_argc_NEG) +{ + char argv_1[20]; + strcpy(argv_1, "WriteTesterTest"); + + int argc = 1; + char *argv[] = {argv_1}; + + ASSERT_NE(0, entry(argc, argv)); +} + +TEST(WriteTesterTest, invalid_file_NEG) +{ + char argv_1[20], argv_2[20], argv_3[20]; + strcpy(argv_1, "WriteTesterTest"); + strcpy(argv_2, "not_a_file"); + strcpy(argv_3, "not_a_file"); + + int argc = 3; + char *argv[] = {argv_1, argv_2, argv_3}; + + EXPECT_THROW(entry(argc, argv), std::runtime_error); +} diff --git a/compiler/luci/testhelper/CMakeLists.txt b/compiler/luci/testhelper/CMakeLists.txt new file mode 100644 index 000000000..86aa66225 --- /dev/null +++ b/compiler/luci/testhelper/CMakeLists.txt @@ -0,0 +1,25 @@ +if(NOT ENABLE_TEST) + return() +endif(NOT ENABLE_TEST) + +nnas_find_package(GTest REQUIRED) + +# NOTE we are using "*.test.cpp" NOT to be included in static analyzer tools + +# testhelper library itself +set(HELPER_SOURCE + src/TestShape.test.cpp + ) + +add_library(luci_testhelper STATIC ${HELPER_SOURCE}) +target_include_directories(luci_testhelper PRIVATE src) +target_include_directories(luci_testhelper PUBLIC include) +target_link_libraries(luci_testhelper luci_lang) + +# test for testhelper library +set(TESTER_SOURCE + src/TestIOGraph.test.cpp + ) + +GTest_AddTest(luci_testhelper_test ${TESTER_SOURCE}) +target_link_libraries(luci_testhelper_test luci_testhelper) diff --git a/compiler/luci/testhelper/README.md b/compiler/luci/testhelper/README.md new file mode 100644 index 000000000..6bdb92aa4 --- /dev/null +++ b/compiler/luci/testhelper/README.md @@ -0,0 +1,3 @@ +# luci-testhelper + +_luci-testhelper_ provides Helper classes for unit testing diff --git a/compiler/luci/testhelper/include/luci/test/TestIOGraph.h b/compiler/luci/testhelper/include/luci/test/TestIOGraph.h new file mode 100644 index 000000000..ae04f4dbc --- /dev/null +++ b/compiler/luci/testhelper/include/luci/test/TestIOGraph.h @@ -0,0 +1,198 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_TESTHELPER_TEST_IO_GRAPH_H__ +#define __LUCI_TESTHELPER_TEST_IO_GRAPH_H__ + +#include "TestShape.h" + +#include <luci/IR/CircleNodes.h> +#include <luci/IR/Module.h> + +#include <memory> +#include <stdexcept> + +namespace luci +{ +namespace test +{ + +/** + * @brief Graphlet with Inputs and loco::Graph for multiple inputs + * @note Every Graph will have Input(s) and Output(s) + * We put loco::Graph only in IsGraphlet not to declare separate + * class for loco::Graph + */ +template <unsigned N> class TestIsGraphlet +{ +public: + TestIsGraphlet() + { + for (uint32_t n = 0; n < N; ++n) + { + _graph_inputs[n] = nullptr; + _inputs[n] = nullptr; + } + _g = loco::make_graph(); + } + +public: + virtual void init(loco::Graph *g, const std::initializer_list<ShapeU32> shape_in) + { + if (shape_in.size() != N) + throw std::runtime_error("Failed to init TestIsGraphlet"); + + auto shpin = shape_in.begin(); + for (uint32_t n = 0; n < N; ++n) + { + _graph_inputs[n] = g->inputs()->create(); + + _inputs[n] = g->nodes()->create<luci::CircleInput>(); + _inputs[n]->shape(*shpin); + _inputs[n]->shape_status(luci::ShapeStatus::VALID); + _inputs[n]->dtype(loco::DataType::FLOAT32); + _inputs[n]->name("input_" + std::to_string(n)); + + _inputs[n]->index(_graph_inputs[n]->index()); + + auto input_shape = std::make_unique<loco::TensorShape>(); + set_shape_vector(input_shape.get(), *shpin); + _graph_inputs[n]->shape(std::move(input_shape)); + _graph_inputs[n]->dtype(loco::DataType::FLOAT32); + + shpin++; + } + } + +public: + loco::Graph *g(void) { return _g.get(); } + luci::CircleInput *input(int idx) { return _inputs[idx]; } + uint32_t num_inputs(void) { return N; } + +public: + void transfer_to(luci::Module *module) + { + // WARNING: after g is transfered, _graph_inputs, _inputs + // and _graph_outputs, _outputs in TestOsGraphlet will be invalid. + // arrays are not cleared as this is just helpers to unit tests + module->add(std::move(_g)); + } + +protected: + std::unique_ptr<loco::Graph> _g; + std::array<loco::GraphInput *, N> _graph_inputs; + std::array<luci::CircleInput *, N> _inputs; +}; + +/** + * @brief Graphlet with one Input + */ +class TestIGraphlet : public TestIsGraphlet<1> +{ +public: + virtual void init(loco::Graph *g, const ShapeU32 shape_in) + { + TestIsGraphlet<1>::init(g, {shape_in}); + } + + luci::CircleInput *input() { return _inputs[0]; } +}; + +/** + * @brief Graphlet with Outputs for multiple outputs + */ +template <unsigned N> class TestOsGraphlet +{ +public: + TestOsGraphlet() + { + for (uint32_t n = 0; n < N; ++n) + { + _graph_outputs[n] = nullptr; + _outputs[n] = nullptr; + } + } + +public: + virtual void init(loco::Graph *g, const std::initializer_list<ShapeU32> shape_out) + { + if (shape_out.size() != N) + throw std::runtime_error("Failed to init TestOsGraphlet"); + + auto shpout = shape_out.begin(); + for (uint32_t n = 0; n < N; ++n) + { + _graph_outputs[n] = g->outputs()->create(); + + _outputs[n] = g->nodes()->create<luci::CircleOutput>(); + _outputs[n]->shape(*shpout); + _outputs[n]->shape_status(luci::ShapeStatus::VALID); + _outputs[n]->dtype(loco::DataType::FLOAT32); + _outputs[n]->name("output_" + std::to_string(n)); + + _outputs[n]->index(_graph_outputs[n]->index()); + + auto output_shape = std::make_unique<loco::TensorShape>(); + set_shape_vector(output_shape.get(), *shpout); + _graph_outputs[n]->shape(std::move(output_shape)); + _graph_outputs[n]->dtype(loco::DataType::FLOAT32); + + shpout++; + } + } + +public: + luci::CircleOutput *output(int idx) { return _outputs[idx]; } + +protected: + std::array<loco::GraphOutput *, N> _graph_outputs; + std::array<luci::CircleOutput *, N> _outputs; +}; + +/** + * @brief Graphlet with one Output + */ +class TestOGraphlet : public TestOsGraphlet<1> +{ +public: + virtual void init(loco::Graph *g, const ShapeU32 shape_out) + { + TestOsGraphlet<1>::init(g, {shape_out}); + } + + luci::CircleOutput *output() { return _outputs[0]; } +}; + +/** + * @brief Graph with Input and Output + */ +class TestIOGraph : public TestIGraphlet, public TestOGraphlet +{ +public: + TestIOGraph() = default; + +public: + virtual void init(const ShapeU32 shape_in, const ShapeU32 shape_out) + { + TestIGraphlet::init(g(), shape_in); + TestOGraphlet::init(g(), shape_out); + } +}; + +} // namespace test +} // namespace luci + +#endif // __LUCI_TESTHELPER_TEST_IO_GRAPH_H__ diff --git a/compiler/luci/testhelper/include/luci/test/TestShape.h b/compiler/luci/testhelper/include/luci/test/TestShape.h new file mode 100644 index 000000000..1a5adf7d6 --- /dev/null +++ b/compiler/luci/testhelper/include/luci/test/TestShape.h @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_TESTHELPER_TEST_SHAPE_H__ +#define __LUCI_TESTHELPER_TEST_SHAPE_H__ + +#include <luci/IR/CircleNode.h> + +#include <initializer_list> + +namespace luci +{ +namespace test +{ + +using ShapeU32 = std::initializer_list<uint32_t>; +using ShapeI32 = std::initializer_list<int32_t>; + +void set_shape_vector(loco::TensorShape *shape, const ShapeU32 &values); +void set_shape_vector(luci::CircleConst *const_node, const ShapeI32 &values); + +uint32_t num_elements(const ShapeU32 shape); + +} // namespace test +} // namespace luci + +#endif // __LUCI_TESTHELPER_TEST_SHAPE_H__ diff --git a/compiler/luci/testhelper/src/TestIOGraph.test.cpp b/compiler/luci/testhelper/src/TestIOGraph.test.cpp new file mode 100644 index 000000000..8a7d1e060 --- /dev/null +++ b/compiler/luci/testhelper/src/TestIOGraph.test.cpp @@ -0,0 +1,182 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/test/TestIOGraph.h" + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class SqrtGraphlet +{ +public: + SqrtGraphlet() = default; + + void init(loco::Graph *g) + { + _sqrt = g->nodes()->create<luci::CircleSqrt>(); + _sqrt->name("sqrt"); + } + +protected: + luci::CircleSqrt *_sqrt = nullptr; +}; + +class AddGraphlet +{ +public: + AddGraphlet() = default; + + void init(loco::Graph *g) + { + _add = g->nodes()->create<luci::CircleAdd>(); + _add->name("add"); + } + +protected: + luci::CircleAdd *_add = nullptr; +}; + +class ConvGraphlet +{ +public: + ConvGraphlet() = default; + + void init(loco::Graph *g) + { + _conv = g->nodes()->create<luci::CircleConv2D>(); + _conv->name("conv"); + } + +protected: + luci::CircleConv2D *_conv = nullptr; +}; + +} // namespace + +namespace +{ + +class TestOfTestIOGraph : public TestIOGraph, public SqrtGraphlet +{ +public: + TestOfTestIOGraph() = default; + +public: + void init(void) + { + TestIOGraph::init({1}, {1}); + SqrtGraphlet::init(g()); + + _sqrt->x(input()); + + output()->from(_sqrt); + } +}; + +class TestOfTestI2OGraph : public TestIsGraphlet<2>, public TestOGraphlet, public AddGraphlet +{ +public: + TestOfTestI2OGraph() = default; + +public: + void init(void) + { + TestIsGraphlet<2>::init(g(), {{2, 3}, {2, 3}}); + TestOsGraphlet<1>::init(g(), {{2, 3}}); + AddGraphlet::init(g()); + + _add->x(input(0)); + _add->y(input(1)); + + output()->from(_add); + } +}; + +class TestOfTestI3OGraph : public TestIsGraphlet<3>, public TestOGraphlet, public ConvGraphlet +{ +public: + TestOfTestI3OGraph() = default; + +public: + void init(void) + { + TestIsGraphlet<3>::init(g(), {{2, 3, 3, 4}, {1, 1}, {4}}); + TestOsGraphlet<1>::init(g(), {{2, 3, 3, 4}}); + ConvGraphlet::init(g()); + + _conv->input(input(0)); + _conv->filter(input(1)); + _conv->bias(input(2)); + + output()->from(_conv); + } +}; + +class FailOfTestI3OGraph : public TestIsGraphlet<3>, public TestOGraphlet, public ConvGraphlet +{ +public: + FailOfTestI3OGraph() = default; + +public: + void init(void) + { + TestIsGraphlet<3>::init(g(), {{2, 3, 3, 4}, {1, 1}}); + TestOsGraphlet<1>::init(g(), {{2, 3, 3, 4}}); + ConvGraphlet::init(g()); + + _conv->input(input(0)); + _conv->filter(input(1)); + _conv->bias(input(2)); + + output()->from(_conv); + } +}; + +} // namespace + +TEST(TestIOGraphTest, IOGraph_init) +{ + TestOfTestIOGraph tg; + tg.init(); + + SUCCEED(); +} + +TEST(TestIOGraphTest, I2OGraph_init) +{ + TestOfTestI2OGraph tg; + tg.init(); + + SUCCEED(); +} + +TEST(TestIOGraphTest, I3OGraph_init) +{ + TestOfTestI3OGraph tg; + tg.init(); + + SUCCEED(); +} + +TEST(TestIOGraphTest, I3OGraph_input_number_mismatch_NEG) +{ + FailOfTestI3OGraph fg; + EXPECT_THROW(fg.init(), std::runtime_error); +} diff --git a/compiler/luci/testhelper/src/TestShape.test.cpp b/compiler/luci/testhelper/src/TestShape.test.cpp new file mode 100644 index 000000000..9838c6182 --- /dev/null +++ b/compiler/luci/testhelper/src/TestShape.test.cpp @@ -0,0 +1,57 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/test/TestShape.h" + +/** + * @note This file does not hold any test cases but provides methods for tests + */ + +namespace luci +{ +namespace test +{ + +void set_shape_vector(loco::TensorShape *shape, const ShapeU32 &values) +{ + uint32_t r = 0; + shape->rank(values.size()); + for (auto v : values) + shape->dim(r++).set(v); +} + +void set_shape_vector(luci::CircleConst *const_node, const ShapeI32 &values) +{ + const_node->rank(1); + const_node->dim(0).set(values.size()); + const_node->shape_status(luci::ShapeStatus::VALID); + const_node->dtype(loco::DataType::S32); + const_node->size<loco::DataType::S32>(values.size()); + uint32_t idx = 0; + for (auto val : values) + const_node->at<loco::DataType::S32>(idx++) = val; +} + +uint32_t num_elements(const ShapeU32 shape) +{ + uint32_t result = 1; + for (auto val : shape) + result = result * val; + return result; +} + +} // namespace test +} // namespace luci diff --git a/compiler/luci/tests/test.lst b/compiler/luci/tests/test.lst index 897d41983..a278fa256 100644 --- a/compiler/luci/tests/test.lst +++ b/compiler/luci/tests/test.lst @@ -51,6 +51,8 @@ addread(ExpandDims_000) addread(ExpandDims_001) addread(ExpandDims_002) addread(ExpandDims_003) +addread(ExpandDims_004) +addread(FakeQuant_000) addread(Fill_000) addread(Fill_001) addread(Floor_000) @@ -151,6 +153,7 @@ addread(SelectV2_002) addread(Shape_000) addread(Sin_000) addread(Slice_000) +addread(Slice_001) addread(Softmax_000) addread(Softmax_U8_000) addread(SpaceToBatchND_000) @@ -166,6 +169,7 @@ addread(Sqrt_000) addread(Square_000) addread(SquaredDifference_000) addread(Squeeze_000) +addread(Squeeze_001) addread(StridedSlice_000) addread(StridedSlice_001) addread(StridedSlice_002) @@ -268,6 +272,8 @@ addwrite(ExpandDims_000) addwrite(ExpandDims_001) addwrite(ExpandDims_002) addwrite(ExpandDims_003) +addwrite(ExpandDims_004) +addwrite(FakeQuant_000) addwrite(Fill_000) addwrite(Fill_001) addwrite(Floor_000) @@ -367,6 +373,7 @@ addwrite(SelectV2_002) addwrite(Shape_000) addwrite(Sin_000) addwrite(Slice_000) +addwrite(Slice_001) addwrite(Softmax_000) addwrite(Softmax_U8_000) addwrite(SpaceToBatchND_000) @@ -382,6 +389,7 @@ addwrite(Sqrt_000) addwrite(Square_000) addwrite(SquaredDifference_000) addwrite(Squeeze_000) +addwrite(Squeeze_001) addwrite(StridedSlice_000) addwrite(StridedSlice_001) addwrite(StridedSlice_002) |