diff options
Diffstat (limited to 'compiler/luci/import')
132 files changed, 1147 insertions, 1077 deletions
diff --git a/compiler/luci/import/CMakeLists.txt b/compiler/luci/import/CMakeLists.txt index 2ae00b837..642751ca6 100644 --- a/compiler/luci/import/CMakeLists.txt +++ b/compiler/luci/import/CMakeLists.txt @@ -6,6 +6,7 @@ add_library(luci_import SHARED ${SOURCES}) target_include_directories(luci_import PRIVATE src) target_include_directories(luci_import PUBLIC include) target_link_libraries(luci_import PUBLIC luci_lang) +target_link_libraries(luci_import PUBLIC luci_profile) target_link_libraries(luci_import PUBLIC mio_circle) target_link_libraries(luci_import PRIVATE luci_env) target_link_libraries(luci_import PRIVATE luci_log) diff --git a/compiler/luci/import/include/luci/Import/CircleReader.h b/compiler/luci/import/include/luci/Import/CircleReader.h index 8e210dd77..b9697fb86 100644 --- a/compiler/luci/import/include/luci/Import/CircleReader.h +++ b/compiler/luci/import/include/luci/Import/CircleReader.h @@ -23,7 +23,6 @@ #include <luci/IR/AttrPadding.h> #include <luci/IR/CircleNode.h> #include <luci/IR/CircleQuantParam.h> -#include <luci/IR/CircleShapeSignature.h> #include <luci/IR/SparsityParam.h> #include <loco.h> @@ -64,6 +63,7 @@ private: using CircleTensors_t = std::vector<std::unique_ptr<circle::TensorT>>; using CircleOperators_t = std::vector<std::unique_ptr<circle::OperatorT>>; using CircleOperatorCodes_t = std::vector<std::unique_ptr<circle::OperatorCodeT>>; + using CircleMetadata_t = std::vector<std::unique_ptr<circle::MetadataT>>; using CircleSubGraphsPtr_t = flatbuffers::Vector<flatbuffers::Offset<circle::SubGraph>>; using CircleTensorsPtr_t = flatbuffers::Vector<flatbuffers::Offset<circle::Tensor>>; @@ -79,6 +79,8 @@ public: const std::vector<int32_t> &inputs() const { return _current_subgraph->inputs; } const std::vector<int32_t> &outputs() const { return _current_subgraph->outputs; } const std::string &name() const { return _current_subgraph->name; } + const circle::DataFormat &data_format() const { return _current_subgraph->data_format; } + const CircleMetadata_t &metadata() const { return _model->metadata; } const CircleTensorsPtr_t *tensors_ptr() const { return _tensors_ptr; } diff --git a/compiler/luci/import/include/luci/Import/GraphBuilder.h b/compiler/luci/import/include/luci/Import/GraphBuilder.h index 548264dac..0db612652 100644 --- a/compiler/luci/import/include/luci/Import/GraphBuilder.h +++ b/compiler/luci/import/include/luci/Import/GraphBuilder.h @@ -33,7 +33,13 @@ class GraphBuilder : public GraphBuilderBase public: virtual ~GraphBuilder() = default; - void build(const circle::OperatorT &op, GraphBuilderContext *context) const final; + // common validate method to check number of inputs and single output + bool validate(const ValidateArgs &args, size_t input_cnt) const + { + return (args.op.inputs.size() == input_cnt && args.op.outputs.size() == 1); + } + + CircleNode *build(const circle::OperatorT &op, GraphBuilderContext *context) const final; private: virtual CircleNode *build_node(const circle::OperatorT &op, diff --git a/compiler/luci/import/include/luci/Import/GraphBuilderBase.h b/compiler/luci/import/include/luci/Import/GraphBuilderBase.h index a0cd008e0..ddd4445cd 100644 --- a/compiler/luci/import/include/luci/Import/GraphBuilderBase.h +++ b/compiler/luci/import/include/luci/Import/GraphBuilderBase.h @@ -19,6 +19,8 @@ #include "GraphBuilderContext.h" +#include <luci/IR/CircleNode.h> + #include <mio/circle/schema_generated.h> namespace luci @@ -38,7 +40,7 @@ struct GraphBuilderBase }; virtual bool validate(const ValidateArgs &) const = 0; - virtual void build(const circle::OperatorT &op, GraphBuilderContext *context) const = 0; + virtual CircleNode *build(const circle::OperatorT &op, GraphBuilderContext *context) const = 0; virtual ~GraphBuilderBase() = default; }; diff --git a/compiler/luci/import/include/luci/Import/GraphBuilderContext.h b/compiler/luci/import/include/luci/Import/GraphBuilderContext.h index 72e237abc..1673df43d 100644 --- a/compiler/luci/import/include/luci/Import/GraphBuilderContext.h +++ b/compiler/luci/import/include/luci/Import/GraphBuilderContext.h @@ -71,7 +71,7 @@ class GraphBuilderContext public: GraphBuilderContext(loco::Graph *g, CircleReader *reader, IndexNodeFinder *nodefinder, IndexTensorOutputs *tensoroutputs) - : _g(g), _reader(reader), _indexnodefinder(nodefinder), _indextensoroutputs(tensoroutputs) + : _g(g), _reader(reader), _indexnodefinder(nodefinder), _indextensoroutputs(tensoroutputs) { // DO NOTHING } diff --git a/compiler/luci/import/include/luci/Import/GraphBuilderMultiOutput.h b/compiler/luci/import/include/luci/Import/GraphBuilderMultiOutput.h new file mode 100644 index 000000000..6e8791b62 --- /dev/null +++ b/compiler/luci/import/include/luci/Import/GraphBuilderMultiOutput.h @@ -0,0 +1,67 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_IMPORT_GRAPH_BUILDER_MULTI_OUTPUT_H__ +#define __LUCI_IMPORT_GRAPH_BUILDER_MULTI_OUTPUT_H__ + +#include "GraphBuilderContext.h" +#include "GraphBuilderBase.h" + +#include <mio/circle/schema_generated.h> + +namespace luci +{ + +/** + * @brief Base of general multiple outputs graph builder(e.g., CircleIfGraphBuilder) + */ +class GraphBuilderMultiOutput : public GraphBuilderBase +{ +public: + virtual ~GraphBuilderMultiOutput() = default; + + CircleNode *build(const circle::OperatorT &op, GraphBuilderContext *context) const final; + +protected: + struct BuildNodeArgs + { + BuildNodeArgs(const circle::OperatorT &o, GraphBuilderContext *c, + const std::vector<CircleNode *> &i) + : op(o), context(c), input_nodes(i) + { + } + + const circle::OperatorT &op; + GraphBuilderContext *context; + const std::vector<CircleNode *> &input_nodes; + }; + + struct BuildOutArgs + { + BuildOutArgs(CircleNode *nd, uint32_t n) : node(nd), index(n) {} + + CircleNode *node; + uint32_t index; + }; + +private: + virtual CircleNode *build_node(const BuildNodeArgs &) const = 0; + virtual CircleNode *build_out(const BuildOutArgs &) const = 0; +}; + +} // namespace luci + +#endif // __LUCI_IMPORT_GRAPH_BUILDER_MULTI_OUTPUT_H__ diff --git a/compiler/luci/import/include/luci/Import/Nodes.h b/compiler/luci/import/include/luci/Import/Nodes.h index 28741064e..b084c7dbc 100644 --- a/compiler/luci/import/include/luci/Import/Nodes.h +++ b/compiler/luci/import/include/luci/Import/Nodes.h @@ -27,6 +27,7 @@ #include "Nodes/CircleBatchToSpaceND.h" #include "Nodes/CircleBCQFullyConnected.h" #include "Nodes/CircleBCQGather.h" +#include "Nodes/CircleBidirectionalSequenceLSTM.h" #include "Nodes/CircleCast.h" #include "Nodes/CircleCeil.h" #include "Nodes/CircleConcatenation.h" @@ -42,6 +43,7 @@ #include "Nodes/CircleEqual.h" #include "Nodes/CircleExp.h" #include "Nodes/CircleExpandDims.h" +#include "Nodes/CircleFakeQuant.h" #include "Nodes/CircleFill.h" #include "Nodes/CircleFloor.h" #include "Nodes/CircleFloorDiv.h" diff --git a/compiler/luci/import/include/luci/Import/Nodes/CircleBidirectionalSequenceLSTM.h b/compiler/luci/import/include/luci/Import/Nodes/CircleBidirectionalSequenceLSTM.h new file mode 100644 index 000000000..491517268 --- /dev/null +++ b/compiler/luci/import/include/luci/Import/Nodes/CircleBidirectionalSequenceLSTM.h @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_IMPORT_OP_CIRCLE_BIDIRECTIONALSEQUENCE_LSTM_H__ +#define __LUCI_IMPORT_OP_CIRCLE_BIDIRECTIONALSEQUENCE_LSTM_H__ + +#include "luci/Import/GraphBuilderMultiOutput.h" + +namespace luci +{ + +class CircleBidirectionalSequenceLSTMGraphBuilder : public GraphBuilderMultiOutput +{ +public: + bool validate(const ValidateArgs &args) const final; + +private: + CircleNode *build_node(const BuildNodeArgs &) const final; + CircleNode *build_out(const BuildOutArgs &) const final; +}; + +} // namespace luci + +#endif // __LUCI_IMPORT_OP_CIRCLE_BIDIRECTIONALSEQUENCE_LSTM_H__ diff --git a/compiler/luci/import/include/luci/Import/Nodes/CircleCustom.h b/compiler/luci/import/include/luci/Import/Nodes/CircleCustom.h index 65745be4b..f0d7e303d 100644 --- a/compiler/luci/import/include/luci/Import/Nodes/CircleCustom.h +++ b/compiler/luci/import/include/luci/Import/Nodes/CircleCustom.h @@ -17,17 +17,19 @@ #ifndef __LUCI_IMPORT_OP_CIRCLE_CUSTOM_H__ #define __LUCI_IMPORT_OP_CIRCLE_CUSTOM_H__ -#include "luci/Import/GraphBuilder.h" +#include "luci/Import/GraphBuilderMultiOutput.h" namespace luci { -class CircleCustomGraphBuilder : public GraphBuilderBase +class CircleCustomGraphBuilder : public GraphBuilderMultiOutput { public: bool validate(const ValidateArgs &args) const final; - void build(const circle::OperatorT &op, GraphBuilderContext *context) const final; +private: + CircleNode *build_node(const BuildNodeArgs &) const final; + CircleNode *build_out(const BuildOutArgs &) const final; }; } // namespace luci diff --git a/compiler/luci/import/include/luci/Import/Nodes/CircleFakeQuant.h b/compiler/luci/import/include/luci/Import/Nodes/CircleFakeQuant.h new file mode 100644 index 000000000..9d9f7b07b --- /dev/null +++ b/compiler/luci/import/include/luci/Import/Nodes/CircleFakeQuant.h @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_IMPORT_OP_CIRCLE_FAKE_QUANT_H__ +#define __LUCI_IMPORT_OP_CIRCLE_FAKE_QUANT_H__ + +#include "luci/Import/GraphBuilder.h" + +namespace luci +{ + +class CircleFakeQuantGraphBuilder : public GraphBuilder +{ +public: + bool validate(const ValidateArgs &args) const final; + +private: + CircleNode *build_node(const circle::OperatorT &op, const std::vector<CircleNode *> &inputs, + loco::Graph *graph) const final; +}; + +} // namespace luci + +#endif // __LUCI_IMPORT_OP_CIRCLE_FAKE_QUANT_H__ diff --git a/compiler/luci/import/include/luci/Import/Nodes/CircleIf.h b/compiler/luci/import/include/luci/Import/Nodes/CircleIf.h index 8faf09cae..94052f5be 100644 --- a/compiler/luci/import/include/luci/Import/Nodes/CircleIf.h +++ b/compiler/luci/import/include/luci/Import/Nodes/CircleIf.h @@ -17,17 +17,19 @@ #ifndef __LUCI_IMPORT_OP_CIRCLE_IF_H__ #define __LUCI_IMPORT_OP_CIRCLE_IF_H__ -#include "luci/Import/GraphBuilderBase.h" +#include "luci/Import/GraphBuilderMultiOutput.h" namespace luci { -class CircleIfGraphBuilder : public GraphBuilderBase +class CircleIfGraphBuilder : public GraphBuilderMultiOutput { public: bool validate(const ValidateArgs &args) const final; - void build(const circle::OperatorT &op, GraphBuilderContext *context) const final; +private: + CircleNode *build_node(const BuildNodeArgs &) const final; + CircleNode *build_out(const BuildOutArgs &) const final; }; } // namespace luci diff --git a/compiler/luci/import/include/luci/Import/Nodes/CircleNonMaxSuppressionV4.h b/compiler/luci/import/include/luci/Import/Nodes/CircleNonMaxSuppressionV4.h index f193aae35..4e8388b3e 100644 --- a/compiler/luci/import/include/luci/Import/Nodes/CircleNonMaxSuppressionV4.h +++ b/compiler/luci/import/include/luci/Import/Nodes/CircleNonMaxSuppressionV4.h @@ -17,17 +17,19 @@ #ifndef __LUCI_IMPORT_OP_CIRCLE_NON_MAX_SUPPRESSION_V4_H__ #define __LUCI_IMPORT_OP_CIRCLE_NON_MAX_SUPPRESSION_V4_H__ -#include "luci/Import/GraphBuilderBase.h" +#include "luci/Import/GraphBuilderMultiOutput.h" namespace luci { -class CircleNonMaxSuppressionV4GraphBuilder : public GraphBuilderBase +class CircleNonMaxSuppressionV4GraphBuilder : public GraphBuilderMultiOutput { public: bool validate(const ValidateArgs &args) const final; - void build(const circle::OperatorT &op, GraphBuilderContext *context) const final; +private: + CircleNode *build_node(const BuildNodeArgs &) const final; + CircleNode *build_out(const BuildOutArgs &) const final; }; } // namespace luci diff --git a/compiler/luci/import/include/luci/Import/Nodes/CircleNonMaxSuppressionV5.h b/compiler/luci/import/include/luci/Import/Nodes/CircleNonMaxSuppressionV5.h index 62be0758e..4120a30eb 100644 --- a/compiler/luci/import/include/luci/Import/Nodes/CircleNonMaxSuppressionV5.h +++ b/compiler/luci/import/include/luci/Import/Nodes/CircleNonMaxSuppressionV5.h @@ -17,17 +17,19 @@ #ifndef __LUCI_IMPORT_OP_CIRCLE_NON_MAX_SUPPRESSION_V5_H__ #define __LUCI_IMPORT_OP_CIRCLE_NON_MAX_SUPPRESSION_V5_H__ -#include "luci/Import/GraphBuilderBase.h" +#include "luci/Import/GraphBuilderMultiOutput.h" namespace luci { -class CircleNonMaxSuppressionV5GraphBuilder : public GraphBuilderBase +class CircleNonMaxSuppressionV5GraphBuilder : public GraphBuilderMultiOutput { public: bool validate(const ValidateArgs &args) const final; - void build(const circle::OperatorT &op, GraphBuilderContext *context) const final; +private: + CircleNode *build_node(const BuildNodeArgs &) const final; + CircleNode *build_out(const BuildOutArgs &) const final; }; } // namespace luci diff --git a/compiler/luci/import/include/luci/Import/Nodes/CircleSplit.h b/compiler/luci/import/include/luci/Import/Nodes/CircleSplit.h index 3395e40fd..5b45c9a9e 100644 --- a/compiler/luci/import/include/luci/Import/Nodes/CircleSplit.h +++ b/compiler/luci/import/include/luci/Import/Nodes/CircleSplit.h @@ -17,17 +17,19 @@ #ifndef __LUCI_IMPORT_OP_CIRCLE_SPLIT_H__ #define __LUCI_IMPORT_OP_CIRCLE_SPLIT_H__ -#include "luci/Import/GraphBuilderBase.h" +#include "luci/Import/GraphBuilderMultiOutput.h" namespace luci { -class CircleSplitGraphBuilder : public GraphBuilderBase +class CircleSplitGraphBuilder : public GraphBuilderMultiOutput { public: bool validate(const ValidateArgs &args) const final; - void build(const circle::OperatorT &op, GraphBuilderContext *context) const final; +private: + CircleNode *build_node(const BuildNodeArgs &) const final; + CircleNode *build_out(const BuildOutArgs &) const final; }; } // namespace luci diff --git a/compiler/luci/import/include/luci/Import/Nodes/CircleSplitV.h b/compiler/luci/import/include/luci/Import/Nodes/CircleSplitV.h index 3e53df362..de712f90c 100644 --- a/compiler/luci/import/include/luci/Import/Nodes/CircleSplitV.h +++ b/compiler/luci/import/include/luci/Import/Nodes/CircleSplitV.h @@ -17,17 +17,19 @@ #ifndef __LUCI_IMPORT_OP_CIRCLE_SPLIT_V_H__ #define __LUCI_IMPORT_OP_CIRCLE_SPLIT_V_H__ -#include "luci/Import/GraphBuilderBase.h" +#include "luci/Import/GraphBuilderMultiOutput.h" namespace luci { -class CircleSplitVGraphBuilder : public GraphBuilderBase +class CircleSplitVGraphBuilder : public GraphBuilderMultiOutput { public: bool validate(const ValidateArgs &args) const final; - void build(const circle::OperatorT &op, GraphBuilderContext *context) const final; +private: + CircleNode *build_node(const BuildNodeArgs &) const final; + CircleNode *build_out(const BuildOutArgs &) const final; }; } // namespace luci diff --git a/compiler/luci/import/include/luci/Import/Nodes/CircleTopKV2.h b/compiler/luci/import/include/luci/Import/Nodes/CircleTopKV2.h index 8ec3f3311..b4ad97130 100644 --- a/compiler/luci/import/include/luci/Import/Nodes/CircleTopKV2.h +++ b/compiler/luci/import/include/luci/Import/Nodes/CircleTopKV2.h @@ -17,17 +17,19 @@ #ifndef __LUCI_IMPORT_OP_CIRCLE_TOPK_V2_H__ #define __LUCI_IMPORT_OP_CIRCLE_TOPK_V2_H__ -#include "luci/Import/GraphBuilderBase.h" +#include "luci/Import/GraphBuilderMultiOutput.h" namespace luci { -class CircleTopKV2GraphBuilder : public GraphBuilderBase +class CircleTopKV2GraphBuilder : public GraphBuilderMultiOutput { public: bool validate(const ValidateArgs &args) const final; - void build(const circle::OperatorT &op, GraphBuilderContext *context) const final; +private: + CircleNode *build_node(const BuildNodeArgs &) const final; + CircleNode *build_out(const BuildOutArgs &) const final; }; } // namespace luci diff --git a/compiler/luci/import/include/luci/Import/Nodes/CircleUnique.h b/compiler/luci/import/include/luci/Import/Nodes/CircleUnique.h index ed5b5035d..40e75ec73 100644 --- a/compiler/luci/import/include/luci/Import/Nodes/CircleUnique.h +++ b/compiler/luci/import/include/luci/Import/Nodes/CircleUnique.h @@ -17,17 +17,19 @@ #ifndef __LUCI_IMPORT_OP_CIRCLE_UNIQUE_H__ #define __LUCI_IMPORT_OP_CIRCLE_UNIQUE_H__ -#include "luci/Import/GraphBuilderBase.h" +#include "luci/Import/GraphBuilderMultiOutput.h" namespace luci { -class CircleUniqueGraphBuilder : public GraphBuilderBase +class CircleUniqueGraphBuilder : public GraphBuilderMultiOutput { public: bool validate(const ValidateArgs &args) const final; - void build(const circle::OperatorT &op, GraphBuilderContext *context) const final; +private: + CircleNode *build_node(const BuildNodeArgs &) const final; + CircleNode *build_out(const BuildOutArgs &) const final; }; } // namespace luci diff --git a/compiler/luci/import/include/luci/Import/Nodes/CircleUnpack.h b/compiler/luci/import/include/luci/Import/Nodes/CircleUnpack.h index f1a21de22..0b623655f 100644 --- a/compiler/luci/import/include/luci/Import/Nodes/CircleUnpack.h +++ b/compiler/luci/import/include/luci/Import/Nodes/CircleUnpack.h @@ -17,17 +17,19 @@ #ifndef __LUCI_IMPORT_OP_CIRCLE_UNPACK_H__ #define __LUCI_IMPORT_OP_CIRCLE_UNPACK_H__ -#include "luci/Import/GraphBuilderBase.h" +#include "luci/Import/GraphBuilderMultiOutput.h" namespace luci { -class CircleUnpackGraphBuilder : public GraphBuilderBase +class CircleUnpackGraphBuilder : public GraphBuilderMultiOutput { public: bool validate(const ValidateArgs &args) const final; - void build(const circle::OperatorT &op, GraphBuilderContext *context) const final; +private: + CircleNode *build_node(const BuildNodeArgs &) const final; + CircleNode *build_out(const BuildOutArgs &) const final; }; } // namespace luci diff --git a/compiler/luci/import/include/luci/Import/Nodes/CircleWhile.h b/compiler/luci/import/include/luci/Import/Nodes/CircleWhile.h index 68c56b3c6..69d23f823 100644 --- a/compiler/luci/import/include/luci/Import/Nodes/CircleWhile.h +++ b/compiler/luci/import/include/luci/Import/Nodes/CircleWhile.h @@ -27,7 +27,7 @@ class CircleWhileGraphBuilder : public GraphBuilderBase public: bool validate(const ValidateArgs &args) const final; - void build(const circle::OperatorT &op, GraphBuilderContext *context) const final; + CircleNode *build(const circle::OperatorT &op, GraphBuilderContext *context) const final; }; } // namespace luci diff --git a/compiler/luci/import/src/CircleImportMetadata.cpp b/compiler/luci/import/src/CircleImportMetadata.cpp new file mode 100644 index 000000000..f68f3301a --- /dev/null +++ b/compiler/luci/import/src/CircleImportMetadata.cpp @@ -0,0 +1,185 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleImportMetadata.h" + +#include <vector> + +namespace +{ + +uint32_t read_u32(const std::vector<uint8_t> &buffer, uint32_t idx) +{ + uint32_t val = 0; + val += (buffer.at(idx + 0) << 0 * 8); + val += (buffer.at(idx + 1) << 1 * 8); + val += (buffer.at(idx + 2) << 2 * 8); + val += (buffer.at(idx + 3) << 3 * 8); + return val; +} + +} // namespace + +namespace +{ + +// 'source_table' is decoded to std::map<uint32_t, std::string> format. +const std::map<uint32_t, std::string> +decoded_source_table(const std::vector<uint8_t> &source_table_data) +{ + std::map<uint32_t, std::string> source_id_name_map; + uint32_t idx = 0; + + if (source_table_data.size() < 4) + throw std::runtime_error("Source table decode error : invalid entry number"); + + uint32_t entry_number = read_u32(source_table_data, idx); + idx += sizeof(uint32_t); + + while (idx < source_table_data.size()) + { + if (idx + 2 * sizeof(uint32_t) > source_table_data.size()) + throw std::runtime_error("Source table decode error : invalid entry item"); + + uint32_t id = read_u32(source_table_data, idx); + idx += sizeof(uint32_t); + + uint32_t length = read_u32(source_table_data, idx); + idx += sizeof(uint32_t); + + if (idx + sizeof(char) * length > source_table_data.size()) + throw std::runtime_error("Source table decode error : invalid entry data"); + + // The last character of name is '\0'. + // However, as std::string do not use '\0' for finding the end of string, + // we ignore the character and do not include it in the string. + std::string origin_name; + for (uint32_t j = 0; j < length - 1; ++j) + origin_name += source_table_data.at(idx + j); + assert(source_table_data.at(idx + length - 1) == '\0'); + idx += sizeof(char) * length; + + if (source_id_name_map.insert({id, origin_name}).second == false) + throw std::runtime_error("Source table decode error : duplicated origin ID"); + } + + if (idx != source_table_data.size()) + throw std::runtime_error("Source table decode error : data size invalid"); + + if (source_id_name_map.size() != entry_number) + throw std::runtime_error("Source table decode error : result size mismatch"); + + return source_id_name_map; +} + +// 'op_table' is decoded to std::map<uint32_t, std::set<uint32_t>> format. +const std::map<uint32_t, std::set<uint32_t>> +decoded_op_table(const std::vector<uint8_t> &op_table_data) +{ + std::map<uint32_t, std::set<uint32_t>> node_source_ids_map; + uint32_t idx = 0; + + if (op_table_data.size() < 4) + throw std::runtime_error("Op table decode error : invalid entry number"); + + uint32_t entry_number = read_u32(op_table_data, idx); + idx += sizeof(uint32_t); + + while (idx < op_table_data.size()) + { + if (idx + 2 * sizeof(uint32_t) > op_table_data.size()) + throw std::runtime_error("Op table decode error : invalid entry item"); + + uint32_t id = read_u32(op_table_data, idx); + idx += sizeof(uint32_t); + + uint32_t node_num = read_u32(op_table_data, idx); + idx += sizeof(uint32_t); + + if (idx + sizeof(uint32_t) * node_num > op_table_data.size()) + throw std::runtime_error("Source table decode error : invalid entry data"); + + std::set<uint32_t> source_ids; + for (uint32_t j = 0; j < node_num; ++j) + { + uint32_t origin = read_u32(op_table_data, idx); + idx += sizeof(uint32_t); + + source_ids.insert(origin); + } + + if (node_source_ids_map.insert({id, source_ids}).second == false) + throw std::runtime_error("Op table decode error : duplicated origin ID"); + } + + if (idx != op_table_data.size()) + throw std::runtime_error("Op table decode error : data size invalid"); + + if (node_source_ids_map.size() != entry_number) + throw std::runtime_error("Op table decode error : entry number invalid"); + + return node_source_ids_map; +} + +} // namespace + +namespace luci +{ + +CircleImportMetadata::CircleImportMetadata(const luci::CircleReader &reader) +{ + const auto &metadata = reader.metadata(); + for (uint32_t i = 0; i < metadata.size(); ++i) + { + const circle::MetadataT &meta = *metadata[i]; + + assert(meta.buffer < reader.buffers().size()); + const std::vector<uint8_t> &buffer = reader.buffers()[meta.buffer]->data; + + if (meta.name.compare("ONE_op_table") == 0) + _op_table = decoded_op_table(buffer); + else if (meta.name.compare("ONE_source_table") == 0) + _source_table = decoded_source_table(buffer); + } +} + +const OriginTable CircleImportMetadata::origin_table(void) +{ + OriginTable origin_table; + + if (_op_table.size() > 0 && _source_table.size() > 0) + { + for (auto &kv : _op_table) + { + const auto node_id = kv.first; + const auto &source_ids = kv.second; + + std::vector<std::shared_ptr<CircleNodeOrigin>> origins; + for (auto source_id : source_ids) + { + const auto source_name = _source_table.at(source_id); + origins.push_back(single_origin(source_id, source_name)); + } + + auto origin = composite_origin(origins); + origin_table.emplace(node_id, origin); + } + } + + return origin_table; +} + +} // namespace luci diff --git a/compiler/luci/import/src/CircleImportMetadata.h b/compiler/luci/import/src/CircleImportMetadata.h new file mode 100644 index 000000000..80176db94 --- /dev/null +++ b/compiler/luci/import/src/CircleImportMetadata.h @@ -0,0 +1,56 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_CIRCLE_IMPORT_METADATA_H__ +#define __LUCI_CIRCLE_IMPORT_METADATA_H__ + +#include "luci/Import/CircleReader.h" + +#include <luci/Profile/CircleNodeOrigin.h> + +#include <map> +#include <set> +#include <string> + +namespace luci +{ + +using OriginTable = std::map<uint32_t, std::shared_ptr<CircleNodeOrigin>>; + +class CircleImportMetadata +{ +public: + CircleImportMetadata() = delete; + + CircleImportMetadata(const luci::CircleReader &reader); + +public: + /** + * @brief Create origin table using _source_table and _op_table in CircleImportMetadata + * @note For creating origin table, both _op_table and _source_table should exist. + * If one of them does not exist, empty table is returned. + */ + const OriginTable origin_table(void); + +private: + // Decoded metadata is stored + std::map<uint32_t, std::string> _source_table; + std::map<uint32_t, std::set<uint32_t>> _op_table; +}; + +} // namespace luci + +#endif // __LUCI_CIRCLE_IMPORT_METADATA_H__ diff --git a/compiler/luci/import/src/CircleReader.cpp b/compiler/luci/import/src/CircleReader.cpp index b33c920b1..861c1bbe3 100644 --- a/compiler/luci/import/src/CircleReader.cpp +++ b/compiler/luci/import/src/CircleReader.cpp @@ -190,19 +190,19 @@ luci_sparse_index_vector(const circle::SparseIndexVectorUnion &sparse_index_vect case circle::SparseIndexVector_Int32Vector: { const auto const_vec_ptr = - static_cast<const void *>(&(sparse_index_vector.AsInt32Vector()->values)); + static_cast<const void *>(&(sparse_index_vector.AsInt32Vector()->values)); return SparseIndexVector{SparseIndexVectorType::I32, const_vec_ptr}; } case circle::SparseIndexVector_Uint16Vector: { const auto const_vec_ptr = - static_cast<const void *>(&(sparse_index_vector.AsUint16Vector()->values)); + static_cast<const void *>(&(sparse_index_vector.AsUint16Vector()->values)); return SparseIndexVector{SparseIndexVectorType::U16, const_vec_ptr}; } case circle::SparseIndexVector_Uint8Vector: { const auto const_vec_ptr = - static_cast<const void *>(&(sparse_index_vector.AsUint8Vector()->values)); + static_cast<const void *>(&(sparse_index_vector.AsUint8Vector()->values)); return SparseIndexVector{SparseIndexVectorType::U8, const_vec_ptr}; } default: @@ -262,15 +262,19 @@ void copy_tensor_attributes(const circle::TensorT &tensor, CircleNode *node) node->name(tensor_name(tensor)); node->dtype(luci_datatype(tensor.type)); + assert(tensor.shape_signature.size() == 0 || + tensor.shape_signature.size() == tensor.shape.size()); + std::vector<int32_t> dims = tensor.shape; // in NHWC node->rank(dims.size()); for (uint32_t r = 0; r < dims.size(); ++r) { - node->dim(r) = loco::Dimension(dims[r]); + if (tensor.shape_signature.size() > 0 && tensor.shape_signature.at(r) == -1) + node->dim(r).unset(); + else + node->dim(r).set(dims[r]); } - node->shape_signature(tensor.shape_signature); - const auto *quantization = tensor.quantization.get(); if (quantization != nullptr) { diff --git a/compiler/luci/import/src/GraphBuilder.cpp b/compiler/luci/import/src/GraphBuilder.cpp index 80a9f986a..356501c2f 100644 --- a/compiler/luci/import/src/GraphBuilder.cpp +++ b/compiler/luci/import/src/GraphBuilder.cpp @@ -21,7 +21,7 @@ namespace luci { -void GraphBuilder::build(const circle::OperatorT &op, GraphBuilderContext *context) const +CircleNode *GraphBuilder::build(const circle::OperatorT &op, GraphBuilderContext *context) const { LOGGER(l); @@ -47,7 +47,11 @@ void GraphBuilder::build(const circle::OperatorT &op, GraphBuilderContext *conte else { // If there is no tensor, insert CircleOutputExclude. - input_nodes.push_back(context->graph()->nodes()->create<luci::CircleOutputExclude>()); + auto *node = context->graph()->nodes()->create<luci::CircleOutputExclude>(); + // CircleOutputExclude doesn't need a type, but since all nodes must have a type, + // a dummy type is inserted. + node->dtype(loco::DataType::FLOAT32); + input_nodes.push_back(node); } } @@ -73,6 +77,8 @@ void GraphBuilder::build(const circle::OperatorT &op, GraphBuilderContext *conte { context->nodefinder()->enroll(outputs[0], node); } + + return node; } } // namespace luci diff --git a/compiler/luci/import/src/GraphBuilderMultiOutput.cpp b/compiler/luci/import/src/GraphBuilderMultiOutput.cpp new file mode 100644 index 000000000..9b42e997e --- /dev/null +++ b/compiler/luci/import/src/GraphBuilderMultiOutput.cpp @@ -0,0 +1,97 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Import/GraphBuilderMultiOutput.h" + +#include <luci/Log.h> + +namespace luci +{ + +CircleNode *GraphBuilderMultiOutput::build(const circle::OperatorT &op, + GraphBuilderContext *context) const +{ + LOGGER(l); + + assert(context != nullptr); + + const std::vector<int32_t> &inputs = op.inputs; + const std::vector<int32_t> &outputs = op.outputs; + const auto &tensors = context->reader()->tensors(); + const auto &opcodes = context->reader()->opcodes(); + auto tensors_ptr = context->reader()->tensors_ptr(); + assert(tensors_ptr != nullptr); + + std::vector<CircleNode *> input_nodes; + for (const int32_t input_tensor_index : inputs) + { + if (input_tensor_index >= 0) + { + auto input = context->nodefinder()->node(input_tensor_index); + if (input == nullptr) + INFO(l) << "[luci] Warning: input node is null " << input_tensor_index << std::endl; + input_nodes.push_back(input); + } + else + { + // If there is no tensor, insert CircleOutputExclude. + auto *node = context->graph()->nodes()->create<luci::CircleOutputExclude>(); + // CircleOutputExclude doesn't need a type, but since all nodes must have a type, + // a dummy type is inserted. + node->dtype(loco::DataType::FLOAT32); + input_nodes.push_back(node); + } + } + + BuildNodeArgs bna(op, context, input_nodes); + auto *node = build_node(bna); + + uint32_t output_count = outputs.size(); + assert(output_count > 0); + { + // Let's use attributes from output 0 for this node + const circle::TensorT &output_tensor = *tensors[outputs[0]]; + node->name(tensor_name(output_tensor)); + node->dtype(luci_datatype(output_tensor.type)); + + // mark operator version + node->op_version(opcodes[op.opcode_index].get()->version); + + // NOTE We don't set quantization for multiple output nodes but to virtual outputs + } + + // Create virtual outputs of Virtual Output node(s) + for (uint32_t n = 0; n < output_count; ++n) + { + const circle::TensorT &output_tensor = *tensors[outputs[n]]; + + BuildOutArgs boa(node, n); + auto *nodeout = build_out(boa); + + copy_tensor_attributes(output_tensor, nodeout); + // mark shape_status + if (tensors_ptr->Get(outputs[n])->shape() == nullptr) + nodeout->shape_status(ShapeStatus::NOSHAPE); + else + nodeout->shape_status(ShapeStatus::VALID); + + context->nodefinder()->enroll(outputs[n], nodeout); + } + + return node; +} + +} // namespace luci diff --git a/compiler/luci/import/src/GraphBuilderRegistry.cpp b/compiler/luci/import/src/GraphBuilderRegistry.cpp index d598d30f4..7f98aab78 100644 --- a/compiler/luci/import/src/GraphBuilderRegistry.cpp +++ b/compiler/luci/import/src/GraphBuilderRegistry.cpp @@ -37,6 +37,7 @@ GraphBuilderRegistry::GraphBuilderRegistry() CIRCLE_NODE(BATCH_TO_SPACE_ND, CircleBatchToSpaceNDGraphBuilder); // 37 CIRCLE_NODE(BCQ_FULLY_CONNECTED, CircleBCQFullyConnectedGraphBuilder); // 253 CIRCLE_NODE(BCQ_GATHER, CircleBCQGatherGraphBuilder); // 252 + CIRCLE_NODE(BIDIRECTIONAL_SEQUENCE_LSTM, CircleBidirectionalSequenceLSTMGraphBuilder); // 52 CIRCLE_NODE(CAST, CircleCastGraphBuilder); // 53 CIRCLE_NODE(CEIL, CircleCeilGraphBuilder); // 104 CIRCLE_NODE(CUSTOM, CircleCustomGraphBuilder); // 32 @@ -51,6 +52,7 @@ GraphBuilderRegistry::GraphBuilderRegistry() CIRCLE_NODE(EQUAL, CircleEqualGraphBuilder); // 71 CIRCLE_NODE(EXP, CircleExpGraphBuilder); // 47 CIRCLE_NODE(EXPAND_DIMS, CircleExpandDimsGraphBuilder); // 70 + CIRCLE_NODE(FAKE_QUANT, CircleFakeQuantGraphBuilder); // 80 CIRCLE_NODE(FILL, CircleFillGraphBuilder); // 94 CIRCLE_NODE(FLOOR, CircleFloorGraphBuilder); // 8 CIRCLE_NODE(FLOOR_DIV, CircleFloorDivGraphBuilder); // 90 @@ -155,9 +157,7 @@ GraphBuilderRegistry::GraphBuilderRegistry() // BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN = 35, // BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN = 46, // BuiltinOperator_DELEGATE = 51, - // BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM = 52, // BuiltinOperator_ARG_MAX = 56, - // BuiltinOperator_FAKE_QUANT = 80, // BuiltinOperator_QUANTIZE = 114, // BuiltinOperator_HARD_SWISH = 117, // BuiltinOperator_DENSIFY = 124, diff --git a/compiler/luci/import/src/Importer.cpp b/compiler/luci/import/src/Importer.cpp index ab89f3587..193afffcb 100644 --- a/compiler/luci/import/src/Importer.cpp +++ b/compiler/luci/import/src/Importer.cpp @@ -15,6 +15,7 @@ */ #include "luci/Importer.h" +#include "CircleImportMetadata.h" #include "PostImport.h" #include "luci/Import/GraphBuilder.h" @@ -25,6 +26,8 @@ #include <luci/IR/Module.h> #include <luci/IR/CircleNodes.h> +#include <luci/Profile/CircleNodeID.h> +#include <luci/Profile/CircleNodeOrigin.h> #include <luci/Log.h> #include <luci/LogHelper.h> @@ -50,6 +53,7 @@ void convert_graph(const luci::GraphBuilderSource &source, luci::CircleReader &r const auto &tensors = reader.tensors(); auto tensors_ptr = reader.tensors_ptr(); assert(tensors_ptr != nullptr); + auto circle_metadata = std::make_unique<luci::CircleImportMetadata>(reader); // build a cache to identify if a tensor is output of an operator // if this is set, we should not create a CircleConst for this tensor @@ -96,12 +100,20 @@ void convert_graph(const luci::GraphBuilderSource &source, luci::CircleReader &r // Data type graph_input->dtype(input_node->dtype()); + assert(tensor.shape_signature.size() == 0 || + tensor.shape_signature.size() == tensor.shape.size()); + // Shape of GraphInput auto input_shape = std::make_unique<loco::TensorShape>(); const std::vector<int32_t> &input_dims = tensor.shape; // in NHWC input_shape->rank(input_dims.size()); for (uint32_t r = 0; r < input_dims.size(); ++r) - input_shape->dim(r) = loco::Dimension(input_dims[r]); + { + if (tensor.shape_signature.size() > 0 && tensor.shape_signature.at(r) == -1) + input_shape->dim(r).unset(); + else + input_shape->dim(r).set(input_dims[r]); + } graph_input->shape(std::move(input_shape)); } @@ -117,6 +129,7 @@ void convert_graph(const luci::GraphBuilderSource &source, luci::CircleReader &r // Note that operators in model are stored in execution order. This means that when importing // an operator, its input operators have already been imported. We exploit this fact to set up // node's inputs right after creating the node. + auto origin_table = circle_metadata->origin_table(); for (uint32_t i = 0; i < operators.size(); ++i) { const circle::OperatorT &op = *operators[i]; @@ -130,7 +143,12 @@ void convert_graph(const luci::GraphBuilderSource &source, luci::CircleReader &r throw oops::UserExn("Invalid operator", reader.opcode_name(op)); } - builder->build(op, &gb_context); + auto built_op = builder->build(op, &gb_context); + set_node_id(built_op, i); + if (origin_table.find(i) != origin_table.end()) + add_origin(built_op, origin_table.at(i)); + else + add_origin(built_op, luci::single_origin(i, built_op->name())); } else { @@ -169,19 +187,28 @@ void convert_graph(const luci::GraphBuilderSource &source, luci::CircleReader &r // set the graph output name and node object auto graph_output = graph->outputs()->create(); std::string tname = luci::tensor_name(tensor); - graph_output->name("output_" + tname); + assert(tname.length() > 0); + graph_output->name(tname); luci::copy_tensor_attributes(tensor, output_node); // Set GraphInputOutputIndex for graph output_node->index(graph_output->index()); + assert(tensor.shape_signature.size() == 0 || + tensor.shape_signature.size() == tensor.shape.size()); + // Shape of Output auto output_shape = std::make_unique<loco::TensorShape>(); const std::vector<int32_t> &output_dims = tensor.shape; // in NHWC output_shape->rank(output_dims.size()); for (uint32_t r = 0; r < output_dims.size(); ++r) - output_shape->dim(r) = loco::Dimension(output_dims[r]); + { + if (tensor.shape_signature.size() > 0 && tensor.shape_signature.at(r) == -1) + output_shape->dim(r).unset(); + else + output_shape->dim(r).set(output_dims[r]); + } graph_output->shape(std::move(output_shape)); // Data type diff --git a/compiler/luci/import/src/Nodes/CircleAbs.cpp b/compiler/luci/import/src/Nodes/CircleAbs.cpp index 3556dc7fa..2a1601a21 100644 --- a/compiler/luci/import/src/Nodes/CircleAbs.cpp +++ b/compiler/luci/import/src/Nodes/CircleAbs.cpp @@ -24,11 +24,8 @@ namespace luci { bool CircleAbsGraphBuilder::validate(const ValidateArgs &args) const { - if (args.op.inputs.size() != 1) - return false; - // TODO Support type check - return true; + return GraphBuilder::validate(args, 1); } CircleNode *CircleAbsGraphBuilder::build_node(const circle::OperatorT &, diff --git a/compiler/luci/import/src/Nodes/CircleAdd.cpp b/compiler/luci/import/src/Nodes/CircleAdd.cpp index b767d4af2..94cbdf081 100644 --- a/compiler/luci/import/src/Nodes/CircleAdd.cpp +++ b/compiler/luci/import/src/Nodes/CircleAdd.cpp @@ -25,10 +25,7 @@ namespace luci bool CircleAddGraphBuilder::validate(const ValidateArgs &args) const { - if (args.op.inputs.size() != 2) - return false; - - return true; + return GraphBuilder::validate(args, 2); } CircleNode *CircleAddGraphBuilder::build_node(const circle::OperatorT &op, diff --git a/compiler/luci/import/src/Nodes/CircleArgMax.cpp b/compiler/luci/import/src/Nodes/CircleArgMax.cpp index 10e8516f4..fd8a84289 100644 --- a/compiler/luci/import/src/Nodes/CircleArgMax.cpp +++ b/compiler/luci/import/src/Nodes/CircleArgMax.cpp @@ -25,10 +25,7 @@ namespace luci bool CircleArgMaxGraphBuilder::validate(const ValidateArgs &args) const { - if (args.op.inputs.size() != 2) - return false; - - return true; + return GraphBuilder::validate(args, 2); } CircleNode *CircleArgMaxGraphBuilder::build_node(const circle::OperatorT &op, diff --git a/compiler/luci/import/src/Nodes/CircleArgMin.cpp b/compiler/luci/import/src/Nodes/CircleArgMin.cpp index 5ff534dbb..63ca8db03 100644 --- a/compiler/luci/import/src/Nodes/CircleArgMin.cpp +++ b/compiler/luci/import/src/Nodes/CircleArgMin.cpp @@ -25,10 +25,7 @@ namespace luci bool CircleArgMinGraphBuilder::validate(const ValidateArgs &args) const { - if (args.op.inputs.size() != 2) - return false; - - return true; + return GraphBuilder::validate(args, 2); } CircleNode *CircleArgMinGraphBuilder::build_node(const circle::OperatorT &op, diff --git a/compiler/luci/import/src/Nodes/CircleAveragePool2D.cpp b/compiler/luci/import/src/Nodes/CircleAveragePool2D.cpp index ad011f71f..a351cf5e7 100644 --- a/compiler/luci/import/src/Nodes/CircleAveragePool2D.cpp +++ b/compiler/luci/import/src/Nodes/CircleAveragePool2D.cpp @@ -23,10 +23,7 @@ namespace luci bool CircleAveragePool2DGraphBuilder::validate(const ValidateArgs &args) const { - if (args.op.inputs.size() != 1) - return false; - - return true; + return GraphBuilder::validate(args, 1); } CircleNode *CircleAveragePool2DGraphBuilder::build_node(const circle::OperatorT &op, diff --git a/compiler/luci/import/src/Nodes/CircleBCQFullyConnected.cpp b/compiler/luci/import/src/Nodes/CircleBCQFullyConnected.cpp index 16ecebd5c..4c86399ce 100644 --- a/compiler/luci/import/src/Nodes/CircleBCQFullyConnected.cpp +++ b/compiler/luci/import/src/Nodes/CircleBCQFullyConnected.cpp @@ -25,10 +25,7 @@ namespace luci bool CircleBCQFullyConnectedGraphBuilder::validate(const ValidateArgs &args) const { - if (args.op.inputs.size() != 5) - return false; - - return true; + return GraphBuilder::validate(args, 5); } CircleNode *CircleBCQFullyConnectedGraphBuilder::build_node(const circle::OperatorT &op, @@ -43,15 +40,6 @@ CircleNode *CircleBCQFullyConnectedGraphBuilder::build_node(const circle::Operat node->bias(inputs.at(3)); node->weights_clusters(inputs.at(4)); - // TODO Find and move to appropriate place for setting optional input - if (auto bias = dynamic_cast<luci::CircleOutputExclude *>(node->bias())) - { - // bias is not used for type inference, but node itself should have a type - bias->dtype(loco::DataType::FLOAT32); - - // bias is not used for shape inference - } - const auto *options = op.builtin_options.AsBCQFullyConnectedOptions(); node->weights_hidden_size(options->weights_hidden_size); node->fusedActivationFunction(luci_actfunc(options->fused_activation_function)); diff --git a/compiler/luci/import/src/Nodes/CircleBCQGather.cpp b/compiler/luci/import/src/Nodes/CircleBCQGather.cpp index 464f1ac18..ee1358197 100644 --- a/compiler/luci/import/src/Nodes/CircleBCQGather.cpp +++ b/compiler/luci/import/src/Nodes/CircleBCQGather.cpp @@ -25,10 +25,7 @@ namespace luci bool CircleBCQGatherGraphBuilder::validate(const ValidateArgs &args) const { - if (args.op.inputs.size() != 4) - return false; - - return true; + return GraphBuilder::validate(args, 4); } CircleNode *CircleBCQGatherGraphBuilder::build_node(const circle::OperatorT &op, diff --git a/compiler/luci/import/src/Nodes/CircleBatchMatMul.cpp b/compiler/luci/import/src/Nodes/CircleBatchMatMul.cpp index 330775691..390719061 100644 --- a/compiler/luci/import/src/Nodes/CircleBatchMatMul.cpp +++ b/compiler/luci/import/src/Nodes/CircleBatchMatMul.cpp @@ -23,10 +23,7 @@ namespace luci bool CircleBatchMatMulGraphBuilder::validate(const ValidateArgs &args) const { - if (args.op.inputs.size() != 2) - return false; - - return true; + return GraphBuilder::validate(args, 2); } CircleNode *CircleBatchMatMulGraphBuilder::build_node(const circle::OperatorT &op, diff --git a/compiler/luci/import/src/Nodes/CircleBidirectionalSequenceLSTM.cpp b/compiler/luci/import/src/Nodes/CircleBidirectionalSequenceLSTM.cpp new file mode 100644 index 000000000..f8bdcff72 --- /dev/null +++ b/compiler/luci/import/src/Nodes/CircleBidirectionalSequenceLSTM.cpp @@ -0,0 +1,112 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Import/Nodes/CircleBidirectionalSequenceLSTM.h" + +#include <luci/IR/Nodes/CircleBidirectionalSequenceLSTM.h> +#include <luci/IR/Nodes/CircleBidirectionalSequenceLSTMOut.h> + +#include <loco.h> + +namespace luci +{ + +bool CircleBidirectionalSequenceLSTMGraphBuilder::validate(const ValidateArgs &args) const +{ + if (args.op.inputs.size() != 48) + return false; + if (args.op.outputs.size() != 2) + return false; + + return true; +} + +CircleNode *CircleBidirectionalSequenceLSTMGraphBuilder::build_node(const BuildNodeArgs &bna) const +{ + auto *node = bna.context->graph()->nodes()->create<CircleBidirectionalSequenceLSTM>(); + auto &inputs = bna.input_nodes; + node->input(inputs.at(0)); + node->fw_input_to_input_weights(inputs.at(1)); // Optional + node->fw_input_to_cell_weights(inputs.at(2)); + node->fw_input_to_forget_weights(inputs.at(3)); + node->fw_input_to_output_weights(inputs.at(4)); + node->fw_recurrent_to_input_weights(inputs.at(5)); // Optional + node->fw_recurrent_to_cell_weights(inputs.at(6)); + node->fw_recurrent_to_forget_weights(inputs.at(7)); + node->fw_recurrent_to_output_weights(inputs.at(8)); + node->fw_cell_to_input_weights(inputs.at(9)); // Optional + node->fw_cell_to_forget_weights(inputs.at(10)); // Optional + node->fw_cell_to_output_weights(inputs.at(11)); // Optional + node->fw_input_gate_bias(inputs.at(12)); // Optional + node->fw_forget_gate_bias(inputs.at(13)); + node->fw_cell_gate_bias(inputs.at(14)); + node->fw_output_gate_bias(inputs.at(15)); + node->fw_projection_weights(inputs.at(16)); // Optional + node->fw_projection_bias(inputs.at(17)); // Optional + node->bw_input_to_input_weights(inputs.at(18)); // Optional + node->bw_input_to_cell_weights(inputs.at(19)); + node->bw_input_to_forget_weights(inputs.at(20)); + node->bw_input_to_output_weights(inputs.at(21)); + node->bw_recurrent_to_input_weights(inputs.at(22)); // Optional + node->bw_recurrent_to_cell_weights(inputs.at(23)); + node->bw_recurrent_to_forget_weights(inputs.at(24)); + node->bw_recurrent_to_output_weights(inputs.at(25)); + node->bw_cell_to_input_weights(inputs.at(26)); // Optional + node->bw_cell_to_forget_weights(inputs.at(27)); // Optional + node->bw_cell_to_output_weights(inputs.at(28)); // Optional + node->bw_input_gate_bias(inputs.at(29)); // Optional + node->bw_forget_gate_bias(inputs.at(30)); + node->bw_cell_gate_bias(inputs.at(31)); + node->bw_output_gate_bias(inputs.at(32)); + node->bw_projection_weights(inputs.at(33)); // Optional + node->bw_projection_bias(inputs.at(34)); // Optional + node->fw_activation_state(inputs.at(35)); + node->fw_cell_state(inputs.at(36)); + node->bw_activation_state(inputs.at(37)); + node->bw_cell_state(inputs.at(38)); + + node->auxillary_input(inputs.at(39)); // Optional + node->fw_auxillary_input_to_input_weights(inputs.at(40)); // Optional + node->fw_auxillary_input_to_forget_weights(inputs.at(41)); // Optional + node->fw_auxillary_input_to_cell_weights(inputs.at(42)); // Optional + node->fw_auxillary_input_to_output_weights(inputs.at(43)); // Optional + node->bw_auxillary_input_to_input_weights(inputs.at(44)); // Optional + node->bw_auxillary_input_to_forget_weights(inputs.at(45)); // Optional + node->bw_auxillary_input_to_cell_weights(inputs.at(46)); // Optional + node->bw_auxillary_input_to_output_weights(inputs.at(47)); // Optional + + const auto *options = bna.op.builtin_options.AsBidirectionalSequenceLSTMOptions(); + node->fusedActivationFunction(luci_actfunc(options->fused_activation_function)); + node->cell_clip(options->cell_clip); + node->proj_clip(options->proj_clip); + node->merge_outputs(options->merge_outputs); + node->time_major(options->time_major); + node->asymmetric_quantize_inputs(options->asymmetric_quantize_inputs); + + return node; +} + +CircleNode *CircleBidirectionalSequenceLSTMGraphBuilder::build_out(const BuildOutArgs &boa) const +{ + auto *nodeout = boa.node->graph()->nodes()->create<CircleBidirectionalSequenceLSTMOut>(); + + nodeout->input(boa.node); + nodeout->index(boa.index); + + return nodeout; +} + +} // namespace luci diff --git a/compiler/luci/import/src/Nodes/CircleCast.cpp b/compiler/luci/import/src/Nodes/CircleCast.cpp index 7bdb63044..3e8c08bfa 100644 --- a/compiler/luci/import/src/Nodes/CircleCast.cpp +++ b/compiler/luci/import/src/Nodes/CircleCast.cpp @@ -30,14 +30,13 @@ bool CircleCastGraphBuilder::validate(const ValidateArgs &args) const { LOGGER(l); + if (!GraphBuilder::validate(args, 1)) + return false; + auto settings = luci::UserSettings::settings(); const auto &inputs = args.op.inputs; const auto &outputs = args.op.outputs; - if (inputs.size() != 1) - return false; - if (outputs.size() != 1) - return false; // NOTE real models do have type mismatch const auto *options = args.op.builtin_options.AsCastOptions(); diff --git a/compiler/luci/import/src/Nodes/CircleCeil.cpp b/compiler/luci/import/src/Nodes/CircleCeil.cpp index 2e1aaa295..d439f41cd 100644 --- a/compiler/luci/import/src/Nodes/CircleCeil.cpp +++ b/compiler/luci/import/src/Nodes/CircleCeil.cpp @@ -25,16 +25,8 @@ namespace luci bool CircleCeilGraphBuilder::validate(const ValidateArgs &args) const { - const auto &inputs = args.op.inputs; - const auto &outputs = args.op.outputs; - if (inputs.size() != 1) - return false; - if (outputs.size() != 1) - return false; - // TODO dtype check - - return true; + return GraphBuilder::validate(args, 1); } CircleNode *CircleCeilGraphBuilder::build_node(const circle::OperatorT &, diff --git a/compiler/luci/import/src/Nodes/CircleConv2D.cpp b/compiler/luci/import/src/Nodes/CircleConv2D.cpp index 9516ef16a..8cbecdc00 100644 --- a/compiler/luci/import/src/Nodes/CircleConv2D.cpp +++ b/compiler/luci/import/src/Nodes/CircleConv2D.cpp @@ -28,10 +28,7 @@ namespace luci bool CircleConv2DGraphBuilder::validate(const ValidateArgs &args) const { // Circle Conv2D may not have a bias but we won't support this - if (args.op.inputs.size() != 3) - return false; - - return true; + return GraphBuilder::validate(args, 3); } CircleNode *CircleConv2DGraphBuilder::build_node(const circle::OperatorT &op, diff --git a/compiler/luci/import/src/Nodes/CircleCos.cpp b/compiler/luci/import/src/Nodes/CircleCos.cpp index 27d60c62c..9705202ee 100644 --- a/compiler/luci/import/src/Nodes/CircleCos.cpp +++ b/compiler/luci/import/src/Nodes/CircleCos.cpp @@ -25,10 +25,7 @@ namespace luci bool CircleCosGraphBuilder::validate(const ValidateArgs &args) const { - if (args.op.inputs.size() != 1) - return false; - - return true; + return GraphBuilder::validate(args, 1); } CircleNode *CircleCosGraphBuilder::build_node(const circle::OperatorT &, diff --git a/compiler/luci/import/src/Nodes/CircleCustom.cpp b/compiler/luci/import/src/Nodes/CircleCustom.cpp index d541ee87b..01ac3e2a0 100644 --- a/compiler/luci/import/src/Nodes/CircleCustom.cpp +++ b/compiler/luci/import/src/Nodes/CircleCustom.cpp @@ -27,62 +27,39 @@ bool CircleCustomGraphBuilder::validate(const ValidateArgs &) const return true; } -void CircleCustomGraphBuilder::build(const circle::OperatorT &op, - GraphBuilderContext *context) const +CircleNode *CircleCustomGraphBuilder::build_node(const BuildNodeArgs &bna) const { - assert(context != nullptr); + uint32_t input_count = bna.op.inputs.size(); + uint32_t output_count = bna.op.outputs.size(); - auto graph = context->graph(); + auto *node = bna.context->graph()->nodes()->create<CircleCustom>(input_count, output_count); - const std::vector<int32_t> &inputs = op.inputs; - const std::vector<int32_t> &outputs = op.outputs; - const auto &tensors = context->reader()->tensors(); - auto tensors_ptr = context->reader()->tensors_ptr(); - assert(tensors_ptr != nullptr); + for (uint32_t idx = 0; idx < input_count; ++idx) + { + node->inputs(idx, bna.input_nodes[idx]); + } - // Create CircleCustom - const auto &opcodes = context->reader()->opcodes(); - const uint32_t opcode_index = op.opcode_index; + const auto &opcodes = bna.context->reader()->opcodes(); + const uint32_t opcode_index = bna.op.opcode_index; const circle::OperatorCodeT &opcode = *opcodes[opcode_index]; - auto *node = graph->nodes()->create<CircleCustom>(inputs.size()); - uint32_t input_idx = 0; - for (const int32_t input_tensor_index : inputs) - { - node->inputs(input_idx++, context->nodefinder()->node(input_tensor_index)); - } - node->custom_options(std::vector<uint8_t>{op.custom_options.begin(), op.custom_options.end()}); + node->custom_options( + std::vector<uint8_t>{bna.op.custom_options.begin(), bna.op.custom_options.end()}); node->custom_code(opcode.custom_code); - // Operator version of custom is always 1, so do nothing - uint32_t output_count = outputs.size(); + // NOTE Operator version of custom is always 1 - assert(output_count > 0); - { - // Let's use attributes from output 0 for this node - const circle::TensorT &output_tensor = *tensors[outputs[0]]; - node->name(tensor_name(output_tensor)); - node->dtype(luci_datatype(output_tensor.type)); - } - - // Create virtual outputs of Custom - for (uint32_t n = 0; n < output_count; ++n) - { - const circle::TensorT &output_tensor = *tensors[outputs[n]]; + return node; +} - auto *nodeout = graph->nodes()->create<CircleCustomOut>(); - copy_tensor_attributes(output_tensor, nodeout); - // mark shape_status - if (tensors_ptr->Get(outputs[n])->shape() == nullptr) - nodeout->shape_status(ShapeStatus::NOSHAPE); - else - nodeout->shape_status(ShapeStatus::VALID); +CircleNode *CircleCustomGraphBuilder::build_out(const BuildOutArgs &boa) const +{ + auto *nodeout = boa.node->graph()->nodes()->create<CircleCustomOut>(); - nodeout->input(node); - nodeout->index(n); + nodeout->input(boa.node); + nodeout->index(boa.index); - context->nodefinder()->enroll(outputs[n], nodeout); - } + return nodeout; } } // namespace luci diff --git a/compiler/luci/import/src/Nodes/CircleDepthToSpace.cpp b/compiler/luci/import/src/Nodes/CircleDepthToSpace.cpp index 49d31bb99..49eb30a83 100644 --- a/compiler/luci/import/src/Nodes/CircleDepthToSpace.cpp +++ b/compiler/luci/import/src/Nodes/CircleDepthToSpace.cpp @@ -27,17 +27,13 @@ namespace luci bool CircleDepthToSpaceGraphBuilder::validate(const ValidateArgs &args) const { + if (!GraphBuilder::validate(args, 1)) + return false; + const auto &inputs = args.op.inputs; const auto &outputs = args.op.outputs; const auto *options = args.op.builtin_options.AsDepthToSpaceOptions(); - - if (inputs.size() != 1) - return false; - - if (outputs.size() != 1) - return false; - const auto &tensors = args.reader.tensors(); if (tensors[outputs[0]]->type != tensors[inputs.at(0)]->type) diff --git a/compiler/luci/import/src/Nodes/CircleDepthwiseConv2D.cpp b/compiler/luci/import/src/Nodes/CircleDepthwiseConv2D.cpp index 53f85f2f5..727487c6a 100644 --- a/compiler/luci/import/src/Nodes/CircleDepthwiseConv2D.cpp +++ b/compiler/luci/import/src/Nodes/CircleDepthwiseConv2D.cpp @@ -32,6 +32,32 @@ bool CircleDepthwiseConv2DGraphBuilder::validate(const ValidateArgs &args) const if (args.op.outputs.size() != 1) return false; + const auto &tensors = args.reader.tensors(); + + // input shape + const auto &input = tensors.at(args.op.inputs.at(0)); + const auto &input_shape = input->shape; + + // input shape must be rank 4 + if (input_shape.size() != 4) + return false; + + // filter shape + const auto &filter = tensors.at(args.op.inputs.at(1)); + const auto &filter_shape = filter->shape; + + // filter shape must be rank 4 + if (filter_shape.size() != 4) + return false; + + // multiplier + const auto *options = args.op.builtin_options.AsDepthwiseConv2DOptions(); + const auto &multiplier = options->depth_multiplier; + + // filter represents as [1, H, W, C*M] where M is multiplier. + if (filter_shape.at(3) != input_shape.at(3) * multiplier) + return false; + return true; } diff --git a/compiler/luci/import/src/Nodes/CircleDequantize.cpp b/compiler/luci/import/src/Nodes/CircleDequantize.cpp index 1936da97c..3db546bd0 100644 --- a/compiler/luci/import/src/Nodes/CircleDequantize.cpp +++ b/compiler/luci/import/src/Nodes/CircleDequantize.cpp @@ -25,10 +25,7 @@ namespace luci bool CircleDequantizeGraphBuilder::validate(const ValidateArgs &args) const { - if (args.op.inputs.size() != 1) - return false; - - return true; + return GraphBuilder::validate(args, 1); } CircleNode *CircleDequantizeGraphBuilder::build_node(const circle::OperatorT &, diff --git a/compiler/luci/import/src/Nodes/CircleDiv.cpp b/compiler/luci/import/src/Nodes/CircleDiv.cpp index 615c224d7..7ea1afd95 100644 --- a/compiler/luci/import/src/Nodes/CircleDiv.cpp +++ b/compiler/luci/import/src/Nodes/CircleDiv.cpp @@ -23,13 +23,7 @@ namespace luci bool CircleDivGraphBuilder::validate(const ValidateArgs &args) const { - if (args.op.inputs.size() != 2) - return false; - - if (args.op.outputs.size() != 1) - return false; - - return true; + return GraphBuilder::validate(args, 2); } CircleNode *CircleDivGraphBuilder::build_node(const circle::OperatorT &op, diff --git a/compiler/luci/import/src/Nodes/CircleElu.cpp b/compiler/luci/import/src/Nodes/CircleElu.cpp index 919e95ee4..461da9517 100644 --- a/compiler/luci/import/src/Nodes/CircleElu.cpp +++ b/compiler/luci/import/src/Nodes/CircleElu.cpp @@ -25,14 +25,11 @@ namespace luci bool CircleEluGraphBuilder::validate(const ValidateArgs &args) const { - const auto &inputs = args.op.inputs; - const auto &outputs = args.op.outputs; - - if (inputs.size() != 1) + if (!GraphBuilder::validate(args, 1)) return false; - if (outputs.size() != 1) - return false; + const auto &inputs = args.op.inputs; + const auto &outputs = args.op.outputs; const auto &tensors = args.reader.tensors(); const auto &tensor = tensors.at(inputs.at(0)); diff --git a/compiler/luci/import/src/Nodes/CircleEqual.cpp b/compiler/luci/import/src/Nodes/CircleEqual.cpp index 1db33b8ac..4909692b4 100644 --- a/compiler/luci/import/src/Nodes/CircleEqual.cpp +++ b/compiler/luci/import/src/Nodes/CircleEqual.cpp @@ -25,13 +25,10 @@ namespace luci bool CircleEqualGraphBuilder::validate(const ValidateArgs &args) const { - const auto &inputs = args.op.inputs; - - if (inputs.size() != 2) - { + if (!GraphBuilder::validate(args, 2)) return false; - } + const auto &inputs = args.op.inputs; const auto &tensors = args.reader.tensors(); return tensors[inputs.at(0)]->type == tensors[inputs.at(1)]->type; diff --git a/compiler/luci/import/src/Nodes/CircleExp.cpp b/compiler/luci/import/src/Nodes/CircleExp.cpp index 2c031d6b3..64f18fbd4 100644 --- a/compiler/luci/import/src/Nodes/CircleExp.cpp +++ b/compiler/luci/import/src/Nodes/CircleExp.cpp @@ -25,10 +25,10 @@ namespace luci bool CircleExpGraphBuilder::validate(const ValidateArgs &args) const { - const auto &inputs = args.op.inputs; - if (inputs.size() != 1) + if (!GraphBuilder::validate(args, 1)) return false; + const auto &inputs = args.op.inputs; // input type check const auto &tensors = args.reader.tensors(); const auto &tensor = tensors.at(inputs.at(0)); diff --git a/compiler/luci/import/src/Nodes/CircleExpandDims.cpp b/compiler/luci/import/src/Nodes/CircleExpandDims.cpp index ab537c710..ee0fbdc7e 100644 --- a/compiler/luci/import/src/Nodes/CircleExpandDims.cpp +++ b/compiler/luci/import/src/Nodes/CircleExpandDims.cpp @@ -25,13 +25,10 @@ namespace luci bool CircleExpandDimsGraphBuilder::validate(const ValidateArgs &args) const { - const auto &inputs = args.op.inputs; - - if (inputs.size() != 2) - { + if (!GraphBuilder::validate(args, 2)) return false; - } + const auto &inputs = args.op.inputs; const auto &tensors = args.reader.tensors(); return tensors[inputs.at(1)]->type == circle::TensorType_INT32; diff --git a/compiler/luci/import/src/Nodes/CircleFakeQuant.cpp b/compiler/luci/import/src/Nodes/CircleFakeQuant.cpp new file mode 100644 index 000000000..7cf40b225 --- /dev/null +++ b/compiler/luci/import/src/Nodes/CircleFakeQuant.cpp @@ -0,0 +1,49 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Import/Nodes/CircleFakeQuant.h" + +#include <luci/IR/Nodes/CircleFullyConnected.h> +#include <luci/IR/Nodes/CircleOutput.h> + +#include <loco.h> +#include <oops/UserExn.h> + +namespace luci +{ + +bool CircleFakeQuantGraphBuilder::validate(const ValidateArgs &args) const +{ + return GraphBuilder::validate(args, 1); +} + +CircleNode *CircleFakeQuantGraphBuilder::build_node(const circle::OperatorT &op, + const std::vector<CircleNode *> &inputs, + loco::Graph *graph) const +{ + auto *node = graph->nodes()->create<CircleFakeQuant>(); + node->inputs(inputs.at(0)); + + const auto *options = op.builtin_options.AsFakeQuantOptions(); + node->min(options->min); + node->max(options->max); + node->num_bits(options->num_bits); + node->narrow_range(options->narrow_range); + + return node; +} + +} // namespace luci diff --git a/compiler/luci/import/src/Nodes/CircleFill.cpp b/compiler/luci/import/src/Nodes/CircleFill.cpp index 95d5b876b..9aacddcbe 100644 --- a/compiler/luci/import/src/Nodes/CircleFill.cpp +++ b/compiler/luci/import/src/Nodes/CircleFill.cpp @@ -23,13 +23,7 @@ namespace luci bool CircleFillGraphBuilder::validate(const ValidateArgs &args) const { - if (args.op.inputs.size() != 2) - return false; - - if (args.op.outputs.size() != 1) - return false; - - return true; + return GraphBuilder::validate(args, 2); } CircleNode *CircleFillGraphBuilder::build_node(const circle::OperatorT &op, diff --git a/compiler/luci/import/src/Nodes/CircleFloor.cpp b/compiler/luci/import/src/Nodes/CircleFloor.cpp index ce756b3b1..9651259c7 100644 --- a/compiler/luci/import/src/Nodes/CircleFloor.cpp +++ b/compiler/luci/import/src/Nodes/CircleFloor.cpp @@ -25,16 +25,8 @@ namespace luci bool CircleFloorGraphBuilder::validate(const ValidateArgs &args) const { - const auto &inputs = args.op.inputs; - const auto &outputs = args.op.outputs; - if (inputs.size() != 1) - return false; - if (outputs.size() != 1) - return false; - // TODO dtype check - - return true; + return GraphBuilder::validate(args, 1); } CircleNode *CircleFloorGraphBuilder::build_node(const circle::OperatorT &, diff --git a/compiler/luci/import/src/Nodes/CircleFloorDiv.cpp b/compiler/luci/import/src/Nodes/CircleFloorDiv.cpp index 55f385d60..ce329326a 100644 --- a/compiler/luci/import/src/Nodes/CircleFloorDiv.cpp +++ b/compiler/luci/import/src/Nodes/CircleFloorDiv.cpp @@ -25,19 +25,11 @@ namespace luci bool CircleFloorDivGraphBuilder::validate(const ValidateArgs &args) const { - const auto &inputs = args.op.inputs; - const auto &outputs = args.op.outputs; - - if (inputs.size() != 2) - { + if (!GraphBuilder::validate(args, 2)) return false; - } - - if (outputs.size() != 1) - { - return false; - } + const auto &inputs = args.op.inputs; + const auto &outputs = args.op.outputs; const auto &tensors = args.reader.tensors(); const auto &tensor_in_0 = tensors.at(inputs.at(0)); const auto &tensor_in_1 = tensors.at(inputs.at(1)); diff --git a/compiler/luci/import/src/Nodes/CircleFloorMod.cpp b/compiler/luci/import/src/Nodes/CircleFloorMod.cpp index 2101e417e..d8420a43c 100644 --- a/compiler/luci/import/src/Nodes/CircleFloorMod.cpp +++ b/compiler/luci/import/src/Nodes/CircleFloorMod.cpp @@ -25,13 +25,10 @@ namespace luci bool CircleFloorModGraphBuilder::validate(const ValidateArgs &args) const { - const auto &inputs = args.op.inputs; - const auto &outputs = args.op.outputs; - if (inputs.size() != 2) - return false; - if (outputs.size() != 1) + if (!GraphBuilder::validate(args, 2)) return false; + const auto &inputs = args.op.inputs; const auto &tensors = args.reader.tensors(); const auto &tensor_in_0 = tensors.at(inputs.at(0)); const auto &tensor_in_1 = tensors.at(inputs.at(1)); diff --git a/compiler/luci/import/src/Nodes/CircleFullyConnected.cpp b/compiler/luci/import/src/Nodes/CircleFullyConnected.cpp index 17293ad7a..58750d79a 100644 --- a/compiler/luci/import/src/Nodes/CircleFullyConnected.cpp +++ b/compiler/luci/import/src/Nodes/CircleFullyConnected.cpp @@ -27,10 +27,7 @@ namespace luci bool CircleFullyConnectedGraphBuilder::validate(const ValidateArgs &args) const { - if (args.op.inputs.size() != 3) - return false; - - return true; + return GraphBuilder::validate(args, 3); } CircleNode *CircleFullyConnectedGraphBuilder::build_node(const circle::OperatorT &op, @@ -42,15 +39,6 @@ CircleNode *CircleFullyConnectedGraphBuilder::build_node(const circle::OperatorT node->weights(inputs.at(1)); node->bias(inputs.at(2)); // bias is optional - // TODO Find and move to appropriate place for setting optional input - if (auto bias = dynamic_cast<luci::CircleOutputExclude *>(node->bias())) - { - // bias is not used for type inference, but node itself should have a type - bias->dtype(loco::DataType::FLOAT32); - - // bias is not used for shape inference - } - const auto *options = op.builtin_options.AsFullyConnectedOptions(); node->fusedActivationFunction(luci_actfunc(options->fused_activation_function)); node->weights_format(luci_weights_format(options->weights_format)); diff --git a/compiler/luci/import/src/Nodes/CircleGather.cpp b/compiler/luci/import/src/Nodes/CircleGather.cpp index 75447a38a..8317a3340 100644 --- a/compiler/luci/import/src/Nodes/CircleGather.cpp +++ b/compiler/luci/import/src/Nodes/CircleGather.cpp @@ -26,18 +26,14 @@ namespace luci bool CircleGatherGraphBuilder::validate(const ValidateArgs &args) const { + if (!GraphBuilder::validate(args, 2)) + return false; + const auto &inputs = args.op.inputs; - const auto &outputs = args.op.outputs; const auto *options = args.op.builtin_options.AsGatherOptions(); int32_t axis = options->axis; - if (inputs.size() != 2) - return false; - - if (outputs.size() != 1) - return false; - if (axis < 0) axis += inputs.size(); diff --git a/compiler/luci/import/src/Nodes/CircleGatherNd.cpp b/compiler/luci/import/src/Nodes/CircleGatherNd.cpp index 981adbf63..a4bb26a10 100644 --- a/compiler/luci/import/src/Nodes/CircleGatherNd.cpp +++ b/compiler/luci/import/src/Nodes/CircleGatherNd.cpp @@ -27,15 +27,10 @@ namespace luci bool CircleGatherNdGraphBuilder::validate(const ValidateArgs &args) const { - const auto &inputs = args.op.inputs; - const auto &outputs = args.op.outputs; - - if (inputs.size() != 2) - return false; - - if (outputs.size() != 1) + if (!GraphBuilder::validate(args, 2)) return false; + const auto &inputs = args.op.inputs; auto &indices_tensor = args.reader.tensors()[inputs.at(1)]; if (!(indices_tensor->type == circle::TensorType::TensorType_INT32 || diff --git a/compiler/luci/import/src/Nodes/CircleGreater.cpp b/compiler/luci/import/src/Nodes/CircleGreater.cpp index 1ad0467e4..f9c00346c 100644 --- a/compiler/luci/import/src/Nodes/CircleGreater.cpp +++ b/compiler/luci/import/src/Nodes/CircleGreater.cpp @@ -30,17 +30,13 @@ bool CircleGreaterGraphBuilder::validate(const ValidateArgs &args) const { LOGGER(l); + if (!GraphBuilder::validate(args, 2)) + return false; + auto settings = luci::UserSettings::settings(); const auto &inputs = args.op.inputs; const auto &outputs = args.op.outputs; - - if (inputs.size() != 2) - return false; - - if (outputs.size() != 1) - return false; - const auto &tensors = args.reader.tensors(); if (tensors[inputs.at(0)]->type != tensors[inputs.at(1)]->type) diff --git a/compiler/luci/import/src/Nodes/CircleGreaterEqual.cpp b/compiler/luci/import/src/Nodes/CircleGreaterEqual.cpp index 0ac63b017..e20038fd9 100644 --- a/compiler/luci/import/src/Nodes/CircleGreaterEqual.cpp +++ b/compiler/luci/import/src/Nodes/CircleGreaterEqual.cpp @@ -25,19 +25,11 @@ namespace luci bool CircleGreaterEqualGraphBuilder::validate(const ValidateArgs &args) const { - const auto &inputs = args.op.inputs; - const auto &outputs = args.op.outputs; - - if (inputs.size() != 2) - { + if (!GraphBuilder::validate(args, 2)) return false; - } - - if (outputs.size() != 1) - { - return false; - } + const auto &inputs = args.op.inputs; + const auto &outputs = args.op.outputs; const auto &tensors = args.reader.tensors(); if (tensors[inputs.at(0)]->type != tensors[inputs.at(1)]->type) diff --git a/compiler/luci/import/src/Nodes/CircleIf.cpp b/compiler/luci/import/src/Nodes/CircleIf.cpp index db9ffe1cd..ffdbf0b79 100644 --- a/compiler/luci/import/src/Nodes/CircleIf.cpp +++ b/compiler/luci/import/src/Nodes/CircleIf.cpp @@ -70,69 +70,34 @@ bool CircleIfGraphBuilder::validate(const ValidateArgs &args) const * \- CircleIfOut --- Node --- */ -void CircleIfGraphBuilder::build(const circle::OperatorT &op, GraphBuilderContext *context) const +CircleNode *CircleIfGraphBuilder::build_node(const BuildNodeArgs &bna) const { - assert(context != nullptr); + uint32_t input_count = bna.op.inputs.size() - 1; + uint32_t output_count = bna.op.outputs.size(); - auto graph = context->graph(); + auto *node = bna.context->graph()->nodes()->create<CircleIf>(input_count, output_count); - const std::vector<int32_t> &inputs = op.inputs; - const std::vector<int32_t> &outputs = op.outputs; - const auto &tensors = context->reader()->tensors(); - const auto &opcodes = context->reader()->opcodes(); - auto tensors_ptr = context->reader()->tensors_ptr(); - assert(tensors_ptr != nullptr); - - std::vector<CircleNode *> input_nodes; - for (const int32_t input_tensor_index : inputs) - { - input_nodes.push_back(context->nodefinder()->node(input_tensor_index)); - } - - uint32_t input_count = inputs.size() - 1; - uint32_t output_count = outputs.size(); - - // Create CircleIf - CircleIf *node = graph->nodes()->create<CircleIf>(input_count, output_count); - - node->cond(input_nodes[0]); + node->cond(bna.input_nodes[0]); for (uint32_t idx = 0; idx < input_count; ++idx) { - node->input(idx, input_nodes[idx + 1]); + node->input(idx, bna.input_nodes[idx + 1]); } - const auto *options = op.builtin_options.AsIfOptions(); + const auto *options = bna.op.builtin_options.AsIfOptions(); node->then_branch(options->then_subgraph_index); node->else_branch(options->else_subgraph_index); - assert(outputs.size() > 0); - { - // Lets use name of output 0 as If name - const circle::TensorT &output_tensor = *tensors[outputs[0]]; - node->name(tensor_name(output_tensor)); - node->op_version(opcodes[op.opcode_index].get()->version); - - // NOTE We don't set quantization for If itself but to virtual outputs - } - - // Create virtual outputs of If - for (uint32_t n = 0; n < output_count; ++n) - { - const circle::TensorT &output_tensor = *tensors[outputs[n]]; + return node; +} - auto *nodeout = graph->nodes()->create<CircleIfOut>(); - copy_tensor_attributes(output_tensor, nodeout); - // mark shape_status - if (tensors_ptr->Get(outputs[n])->shape() == nullptr) - nodeout->shape_status(ShapeStatus::NOSHAPE); - else - nodeout->shape_status(ShapeStatus::VALID); +CircleNode *CircleIfGraphBuilder::build_out(const BuildOutArgs &boa) const +{ + auto *nodeout = boa.node->graph()->nodes()->create<CircleIfOut>(); - nodeout->input(node); - nodeout->index(n); + nodeout->input(boa.node); + nodeout->index(boa.index); - context->nodefinder()->enroll(outputs[n], nodeout); - } + return nodeout; } } // namespace luci diff --git a/compiler/luci/import/src/Nodes/CircleInstanceNorm.cpp b/compiler/luci/import/src/Nodes/CircleInstanceNorm.cpp index 6349fd3b7..977b53406 100644 --- a/compiler/luci/import/src/Nodes/CircleInstanceNorm.cpp +++ b/compiler/luci/import/src/Nodes/CircleInstanceNorm.cpp @@ -25,12 +25,8 @@ namespace luci bool CircleInstanceNormGraphBuilder::validate(const ValidateArgs &args) const { - if (args.op.inputs.size() != 3) - return false; - // TODO check dtypes - - return true; + return GraphBuilder::validate(args, 3); } CircleNode *CircleInstanceNormGraphBuilder::build_node(const circle::OperatorT &op, diff --git a/compiler/luci/import/src/Nodes/CircleL2Normalize.cpp b/compiler/luci/import/src/Nodes/CircleL2Normalize.cpp index e4fdc200c..7e1faedfb 100644 --- a/compiler/luci/import/src/Nodes/CircleL2Normalize.cpp +++ b/compiler/luci/import/src/Nodes/CircleL2Normalize.cpp @@ -25,20 +25,7 @@ namespace luci bool CircleL2NormalizeGraphBuilder::validate(const ValidateArgs &args) const { - const auto &inputs = args.op.inputs; - const auto &outputs = args.op.outputs; - - if (inputs.size() != 1) - { - return false; - } - - if (outputs.size() != 1) - { - return false; - } - - return true; + return GraphBuilder::validate(args, 1); } CircleNode *CircleL2NormalizeGraphBuilder::build_node(const circle::OperatorT &op, diff --git a/compiler/luci/import/src/Nodes/CircleL2Pool2D.cpp b/compiler/luci/import/src/Nodes/CircleL2Pool2D.cpp index 202d9d6fb..849c7c5ed 100644 --- a/compiler/luci/import/src/Nodes/CircleL2Pool2D.cpp +++ b/compiler/luci/import/src/Nodes/CircleL2Pool2D.cpp @@ -25,12 +25,8 @@ namespace luci bool CircleL2Pool2DGraphBuilder::validate(const ValidateArgs &args) const { - if (args.op.inputs.size() != 1) - return false; - // TODO check dtypes - - return true; + return GraphBuilder::validate(args, 1); } CircleNode *CircleL2Pool2DGraphBuilder::build_node(const circle::OperatorT &op, diff --git a/compiler/luci/import/src/Nodes/CircleLeakyRelu.cpp b/compiler/luci/import/src/Nodes/CircleLeakyRelu.cpp index ad4979f39..880fa6428 100644 --- a/compiler/luci/import/src/Nodes/CircleLeakyRelu.cpp +++ b/compiler/luci/import/src/Nodes/CircleLeakyRelu.cpp @@ -25,13 +25,7 @@ namespace luci bool CircleLeakyReluGraphBuilder::validate(const ValidateArgs &args) const { - if (args.op.inputs.size() != 1) - return false; - - if (args.op.outputs.size() != 1) - return false; - - return true; + return GraphBuilder::validate(args, 1); } CircleNode *CircleLeakyReluGraphBuilder::build_node(const circle::OperatorT &op, diff --git a/compiler/luci/import/src/Nodes/CircleLess.cpp b/compiler/luci/import/src/Nodes/CircleLess.cpp index 506036908..f9b99bebe 100644 --- a/compiler/luci/import/src/Nodes/CircleLess.cpp +++ b/compiler/luci/import/src/Nodes/CircleLess.cpp @@ -25,19 +25,11 @@ namespace luci bool CircleLessGraphBuilder::validate(const ValidateArgs &args) const { - const auto &inputs = args.op.inputs; - const auto &outputs = args.op.outputs; - - if (inputs.size() != 2) - { + if (!GraphBuilder::validate(args, 2)) return false; - } - - if (outputs.size() != 1) - { - return false; - } + const auto &inputs = args.op.inputs; + const auto &outputs = args.op.outputs; const auto &tensors = args.reader.tensors(); const auto &tensor = tensors.at(inputs.at(0)); diff --git a/compiler/luci/import/src/Nodes/CircleLessEqual.cpp b/compiler/luci/import/src/Nodes/CircleLessEqual.cpp index 9b4f934a5..bb1712137 100644 --- a/compiler/luci/import/src/Nodes/CircleLessEqual.cpp +++ b/compiler/luci/import/src/Nodes/CircleLessEqual.cpp @@ -25,19 +25,11 @@ namespace luci bool CircleLessEqualGraphBuilder::validate(const ValidateArgs &args) const { - const auto &inputs = args.op.inputs; - const auto &outputs = args.op.outputs; - - if (inputs.size() != 2) - { + if (!GraphBuilder::validate(args, 2)) return false; - } - - if (outputs.size() != 1) - { - return false; - } + const auto &inputs = args.op.inputs; + const auto &outputs = args.op.outputs; const auto &tensors = args.reader.tensors(); if (tensors[inputs.at(0)]->type != tensors[inputs.at(1)]->type) diff --git a/compiler/luci/import/src/Nodes/CircleLocalResponseNormalization.cpp b/compiler/luci/import/src/Nodes/CircleLocalResponseNormalization.cpp index 0e32f62de..d03c47d12 100644 --- a/compiler/luci/import/src/Nodes/CircleLocalResponseNormalization.cpp +++ b/compiler/luci/import/src/Nodes/CircleLocalResponseNormalization.cpp @@ -25,16 +25,12 @@ namespace luci bool CircleLocalResponseNormalizationGraphBuilder::validate(const ValidateArgs &args) const { - if (args.op.inputs.size() != 1) - return false; - // TODO do attribute checks - - return true; + return GraphBuilder::validate(args, 1); } CircleNode *CircleLocalResponseNormalizationGraphBuilder::build_node( - const circle::OperatorT &op, const std::vector<CircleNode *> &inputs, loco::Graph *graph) const + const circle::OperatorT &op, const std::vector<CircleNode *> &inputs, loco::Graph *graph) const { auto *node = graph->nodes()->create<CircleLocalResponseNormalization>(); node->input(inputs.at(0)); diff --git a/compiler/luci/import/src/Nodes/CircleLog.cpp b/compiler/luci/import/src/Nodes/CircleLog.cpp index 346fc43bb..26b575070 100644 --- a/compiler/luci/import/src/Nodes/CircleLog.cpp +++ b/compiler/luci/import/src/Nodes/CircleLog.cpp @@ -25,12 +25,10 @@ namespace luci bool CircleLogGraphBuilder::validate(const ValidateArgs &args) const { - const auto &inputs = args.op.inputs; - if (inputs.size() != 1) - return false; - if (args.op.outputs.size() != 1) + if (!GraphBuilder::validate(args, 1)) return false; + const auto &inputs = args.op.inputs; // input type check // Must be one of bfloat16, half, float32, float64, complex64, complex128. // Currently circle supports half(float16), float32, float64, complex64. diff --git a/compiler/luci/import/src/Nodes/CircleLogSoftmax.cpp b/compiler/luci/import/src/Nodes/CircleLogSoftmax.cpp index ef69e868a..4361db691 100644 --- a/compiler/luci/import/src/Nodes/CircleLogSoftmax.cpp +++ b/compiler/luci/import/src/Nodes/CircleLogSoftmax.cpp @@ -25,12 +25,8 @@ namespace luci bool CircleLogSoftmaxGraphBuilder::validate(const ValidateArgs &args) const { - if (args.op.inputs.size() != 1) - return false; - // TODO do attribute checks - - return true; + return GraphBuilder::validate(args, 1); } CircleNode *CircleLogSoftmaxGraphBuilder::build_node(const circle::OperatorT &, diff --git a/compiler/luci/import/src/Nodes/CircleLogicalAnd.cpp b/compiler/luci/import/src/Nodes/CircleLogicalAnd.cpp index 7844da0f6..b13fc2735 100644 --- a/compiler/luci/import/src/Nodes/CircleLogicalAnd.cpp +++ b/compiler/luci/import/src/Nodes/CircleLogicalAnd.cpp @@ -25,11 +25,11 @@ namespace luci bool CircleLogicalAndGraphBuilder::validate(const ValidateArgs &args) const { - // Only BOOL type is allowed for inputs - const auto &inputs = args.op.inputs; - if (inputs.size() != 2) + if (!GraphBuilder::validate(args, 2)) return false; + // Only BOOL type is allowed for inputs + const auto &inputs = args.op.inputs; const auto &tensors = args.reader.tensors(); for (auto input : inputs) { diff --git a/compiler/luci/import/src/Nodes/CircleLogicalNot.cpp b/compiler/luci/import/src/Nodes/CircleLogicalNot.cpp index 3758642e4..f68218349 100644 --- a/compiler/luci/import/src/Nodes/CircleLogicalNot.cpp +++ b/compiler/luci/import/src/Nodes/CircleLogicalNot.cpp @@ -25,7 +25,7 @@ namespace luci bool CircleLogicalNotGraphBuilder::validate(const ValidateArgs &args) const { - if (args.op.inputs.size() != 1) + if (!GraphBuilder::validate(args, 1)) return false; // Only BOOL type is allowed for the input diff --git a/compiler/luci/import/src/Nodes/CircleLogicalOr.cpp b/compiler/luci/import/src/Nodes/CircleLogicalOr.cpp index 1b87e6f9c..8c9023dd3 100644 --- a/compiler/luci/import/src/Nodes/CircleLogicalOr.cpp +++ b/compiler/luci/import/src/Nodes/CircleLogicalOr.cpp @@ -25,7 +25,7 @@ namespace luci bool CircleLogicalOrGraphBuilder::validate(const ValidateArgs &args) const { - if (args.op.inputs.size() != 2) + if (!GraphBuilder::validate(args, 2)) return false; // Only BOOL type is allowed for inputs diff --git a/compiler/luci/import/src/Nodes/CircleLogistic.cpp b/compiler/luci/import/src/Nodes/CircleLogistic.cpp index 9606e19cd..0f92a9bb4 100644 --- a/compiler/luci/import/src/Nodes/CircleLogistic.cpp +++ b/compiler/luci/import/src/Nodes/CircleLogistic.cpp @@ -25,13 +25,11 @@ namespace luci bool CircleLogisticGraphBuilder::validate(const ValidateArgs &args) const { - const auto &inputs = args.op.inputs; - if (inputs.size() != 1) - return false; - const auto &outputs = args.op.outputs; - if (outputs.size() != 1) + if (!GraphBuilder::validate(args, 1)) return false; + const auto &inputs = args.op.inputs; + const auto &outputs = args.op.outputs; const auto &tensors = args.reader.tensors(); if (tensors.at(inputs.at(0))->type != tensors.at(outputs[0])->type) return false; diff --git a/compiler/luci/import/src/Nodes/CircleMatrixDiag.cpp b/compiler/luci/import/src/Nodes/CircleMatrixDiag.cpp index a4a21a8b7..590a07f2d 100644 --- a/compiler/luci/import/src/Nodes/CircleMatrixDiag.cpp +++ b/compiler/luci/import/src/Nodes/CircleMatrixDiag.cpp @@ -25,15 +25,11 @@ namespace luci bool CircleMatrixDiagGraphBuilder::validate(const ValidateArgs &args) const { - const auto &inputs = args.op.inputs; - const auto &outputs = args.op.outputs; - - if (inputs.size() != 1) - return false; - - if (outputs.size() != 1) + if (!GraphBuilder::validate(args, 1)) return false; + const auto &inputs = args.op.inputs; + const auto &outputs = args.op.outputs; const auto &tensors = args.reader.tensors(); const auto &tensor = tensors.at(inputs.at(0)); diff --git a/compiler/luci/import/src/Nodes/CircleMatrixSetDiag.cpp b/compiler/luci/import/src/Nodes/CircleMatrixSetDiag.cpp index cf0313149..edd7d2ae2 100644 --- a/compiler/luci/import/src/Nodes/CircleMatrixSetDiag.cpp +++ b/compiler/luci/import/src/Nodes/CircleMatrixSetDiag.cpp @@ -25,15 +25,11 @@ namespace luci bool CircleMatrixSetDiagGraphBuilder::validate(const ValidateArgs &args) const { - const auto &inputs = args.op.inputs; - const auto &outputs = args.op.outputs; - - if (inputs.size() != 2) - return false; - - if (outputs.size() != 1) + if (!GraphBuilder::validate(args, 2)) return false; + const auto &inputs = args.op.inputs; + const auto &outputs = args.op.outputs; const auto &tensors = args.reader.tensors(); const auto &tensor = tensors.at(inputs.at(0)); diff --git a/compiler/luci/import/src/Nodes/CircleMaxPool2D.cpp b/compiler/luci/import/src/Nodes/CircleMaxPool2D.cpp index 4bca0f40b..5c03fff18 100644 --- a/compiler/luci/import/src/Nodes/CircleMaxPool2D.cpp +++ b/compiler/luci/import/src/Nodes/CircleMaxPool2D.cpp @@ -25,10 +25,7 @@ namespace luci bool CircleMaxPool2DGraphBuilder::validate(const ValidateArgs &args) const { - if (args.op.inputs.size() != 1) - return false; - - return true; + return GraphBuilder::validate(args, 1); } CircleNode *CircleMaxPool2DGraphBuilder::build_node(const circle::OperatorT &op, diff --git a/compiler/luci/import/src/Nodes/CircleMean.cpp b/compiler/luci/import/src/Nodes/CircleMean.cpp index d8fa9a53d..7882f17fc 100644 --- a/compiler/luci/import/src/Nodes/CircleMean.cpp +++ b/compiler/luci/import/src/Nodes/CircleMean.cpp @@ -23,10 +23,7 @@ namespace luci bool CircleMeanGraphBuilder::validate(const ValidateArgs &args) const { - if (args.op.inputs.size() != 2) - return false; - - return true; + return GraphBuilder::validate(args, 2); } CircleNode *CircleMeanGraphBuilder::build_node(const circle::OperatorT &op, diff --git a/compiler/luci/import/src/Nodes/CircleMirrorPad.cpp b/compiler/luci/import/src/Nodes/CircleMirrorPad.cpp index e0ddd4c11..e40ce2249 100644 --- a/compiler/luci/import/src/Nodes/CircleMirrorPad.cpp +++ b/compiler/luci/import/src/Nodes/CircleMirrorPad.cpp @@ -25,12 +25,8 @@ namespace luci bool CircleMirrorPadGraphBuilder::validate(const ValidateArgs &args) const { - if (args.op.inputs.size() != 2) - return false; - // TODO check others - - return true; + return GraphBuilder::validate(args, 2); } CircleNode *CircleMirrorPadGraphBuilder::build_node(const circle::OperatorT &op, diff --git a/compiler/luci/import/src/Nodes/CircleMul.cpp b/compiler/luci/import/src/Nodes/CircleMul.cpp index e3c4a7ee5..28421f8c4 100644 --- a/compiler/luci/import/src/Nodes/CircleMul.cpp +++ b/compiler/luci/import/src/Nodes/CircleMul.cpp @@ -23,13 +23,7 @@ namespace luci bool CircleMulGraphBuilder::validate(const ValidateArgs &args) const { - if (args.op.inputs.size() != 2) - return false; - - if (args.op.outputs.size() != 1) - return false; - - return true; + return GraphBuilder::validate(args, 2); } CircleNode *CircleMulGraphBuilder::build_node(const circle::OperatorT &op, diff --git a/compiler/luci/import/src/Nodes/CircleNeg.cpp b/compiler/luci/import/src/Nodes/CircleNeg.cpp index a64a69560..9dd1458f4 100644 --- a/compiler/luci/import/src/Nodes/CircleNeg.cpp +++ b/compiler/luci/import/src/Nodes/CircleNeg.cpp @@ -24,11 +24,8 @@ namespace luci { bool CircleNegGraphBuilder::validate(const ValidateArgs &args) const { - if (args.op.inputs.size() != 1) - return false; - // TODO Support type check - return true; + return GraphBuilder::validate(args, 1); } CircleNode *CircleNegGraphBuilder::build_node(const circle::OperatorT &, diff --git a/compiler/luci/import/src/Nodes/CircleNonMaxSuppressionV4.cpp b/compiler/luci/import/src/Nodes/CircleNonMaxSuppressionV4.cpp index a4ad4a53d..d3d69506b 100644 --- a/compiler/luci/import/src/Nodes/CircleNonMaxSuppressionV4.cpp +++ b/compiler/luci/import/src/Nodes/CircleNonMaxSuppressionV4.cpp @@ -61,63 +61,27 @@ bool CircleNonMaxSuppressionV4GraphBuilder::validate(const ValidateArgs &args) c * We will create multiple NonMasSuppressionV4Oout nodes to emulate this */ -void CircleNonMaxSuppressionV4GraphBuilder::build(const circle::OperatorT &op, - GraphBuilderContext *context) const +CircleNode *CircleNonMaxSuppressionV4GraphBuilder::build_node(const BuildNodeArgs &bna) const { - assert(context != nullptr); - - auto graph = context->graph(); - - const std::vector<int32_t> &inputs = op.inputs; - const std::vector<int32_t> &outputs = op.outputs; - const auto &tensors = context->reader()->tensors(); - const auto &opcodes = context->reader()->opcodes(); - auto tensors_ptr = context->reader()->tensors_ptr(); - assert(tensors_ptr != nullptr); - - std::vector<CircleNode *> input_nodes; - for (const int32_t input_tensor_index : inputs) - { - input_nodes.push_back(context->nodefinder()->node(input_tensor_index)); - } - - // Create CircleNonMaxSuppressionV4 - auto node = graph->nodes()->create<CircleNonMaxSuppressionV4>(); - node->boxes(input_nodes[0]); - node->scores(input_nodes[1]); - node->max_output_size(input_nodes[2]); - node->iou_threshold(input_nodes[3]); - node->score_threshold(input_nodes[4]); - - assert(outputs.size() == 2); - { - // Let's use name of output 0 as NonMaxSuppressionV4 name - const circle::TensorT &output_tensor = *tensors[outputs[0]]; - node->name(tensor_name(output_tensor)); - node->op_version(opcodes[op.opcode_index].get()->version); - - // NOTE We don't set quantization for NonMaxSuppressionV4 itself but to virtual outputs - } - - // Create virtual outputs of NonMaxSuppressionV4 - for (size_t n = 0; n < outputs.size(); ++n) - { - const circle::TensorT &output_tensor = *tensors[outputs[n]]; - - auto *nodeout = graph->nodes()->create<CircleNonMaxSuppressionV4Out>(); - copy_tensor_attributes(output_tensor, nodeout); - - // mark shape_status - if (tensors_ptr->Get(outputs[n])->shape() == nullptr) - nodeout->shape_status(ShapeStatus::NOSHAPE); - else - nodeout->shape_status(ShapeStatus::VALID); - - nodeout->input(node); - nodeout->index(n); - - context->nodefinder()->enroll(outputs[n], nodeout); - } + auto node = bna.context->graph()->nodes()->create<CircleNonMaxSuppressionV4>(); + + node->boxes(bna.input_nodes[0]); + node->scores(bna.input_nodes[1]); + node->max_output_size(bna.input_nodes[2]); + node->iou_threshold(bna.input_nodes[3]); + node->score_threshold(bna.input_nodes[4]); + + return node; +} + +CircleNode *CircleNonMaxSuppressionV4GraphBuilder::build_out(const BuildOutArgs &boa) const +{ + auto *nodeout = boa.node->graph()->nodes()->create<CircleNonMaxSuppressionV4Out>(); + + nodeout->input(boa.node); + nodeout->index(boa.index); + + return nodeout; } } // namespace luci diff --git a/compiler/luci/import/src/Nodes/CircleNonMaxSuppressionV5.cpp b/compiler/luci/import/src/Nodes/CircleNonMaxSuppressionV5.cpp index 241dbf5ff..d797d4cb7 100644 --- a/compiler/luci/import/src/Nodes/CircleNonMaxSuppressionV5.cpp +++ b/compiler/luci/import/src/Nodes/CircleNonMaxSuppressionV5.cpp @@ -63,64 +63,28 @@ bool CircleNonMaxSuppressionV5GraphBuilder::validate(const ValidateArgs &args) c * We will create multiple NonMasSuppressionV5Oout nodes to emulate this */ -void CircleNonMaxSuppressionV5GraphBuilder::build(const circle::OperatorT &op, - GraphBuilderContext *context) const +CircleNode *CircleNonMaxSuppressionV5GraphBuilder::build_node(const BuildNodeArgs &bna) const { - assert(context != nullptr); - - auto graph = context->graph(); - - const std::vector<int32_t> &inputs = op.inputs; - const std::vector<int32_t> &outputs = op.outputs; - const auto &tensors = context->reader()->tensors(); - const auto &opcodes = context->reader()->opcodes(); - auto tensors_ptr = context->reader()->tensors_ptr(); - assert(tensors_ptr != nullptr); - - std::vector<CircleNode *> input_nodes; - for (const int32_t input_tensor_index : inputs) - { - input_nodes.push_back(context->nodefinder()->node(input_tensor_index)); - } - - // Create CircleNonMaxSuppressionV5 - auto node = graph->nodes()->create<CircleNonMaxSuppressionV5>(); - node->boxes(input_nodes[0]); - node->scores(input_nodes[1]); - node->max_output_size(input_nodes[2]); - node->iou_threshold(input_nodes[3]); - node->score_threshold(input_nodes[4]); - node->soft_nms_sigma(input_nodes[5]); - - assert(outputs.size() == 3); - { - // Let's use name of output 0 as NonMaxSuppressionV5 name - const circle::TensorT &output_tensor = *tensors[outputs[0]]; - node->name(tensor_name(output_tensor)); - node->op_version(opcodes[op.opcode_index].get()->version); - - // NOTE We don't set quantization for NonMaxSuppressionV5 itself but to virtual outputs - } - - // Create virtual outputs of NonMaxSuppressionV5 - for (size_t n = 0; n < outputs.size(); ++n) - { - const circle::TensorT &output_tensor = *tensors[outputs[n]]; - - auto *nodeout = graph->nodes()->create<CircleNonMaxSuppressionV5Out>(); - copy_tensor_attributes(output_tensor, nodeout); - - // mark shape_status - if (tensors_ptr->Get(outputs[n])->shape() == nullptr) - nodeout->shape_status(ShapeStatus::NOSHAPE); - else - nodeout->shape_status(ShapeStatus::VALID); - - nodeout->input(node); - nodeout->index(n); - - context->nodefinder()->enroll(outputs[n], nodeout); - } + auto node = bna.context->graph()->nodes()->create<CircleNonMaxSuppressionV5>(); + + node->boxes(bna.input_nodes[0]); + node->scores(bna.input_nodes[1]); + node->max_output_size(bna.input_nodes[2]); + node->iou_threshold(bna.input_nodes[3]); + node->score_threshold(bna.input_nodes[4]); + node->soft_nms_sigma(bna.input_nodes[5]); + + return node; +} + +CircleNode *CircleNonMaxSuppressionV5GraphBuilder::build_out(const BuildOutArgs &boa) const +{ + auto *nodeout = boa.node->graph()->nodes()->create<CircleNonMaxSuppressionV5Out>(); + + nodeout->input(boa.node); + nodeout->index(boa.index); + + return nodeout; } } // namespace luci diff --git a/compiler/luci/import/src/Nodes/CircleNotEqual.cpp b/compiler/luci/import/src/Nodes/CircleNotEqual.cpp index 77e986de1..a0b8f9e4f 100644 --- a/compiler/luci/import/src/Nodes/CircleNotEqual.cpp +++ b/compiler/luci/import/src/Nodes/CircleNotEqual.cpp @@ -25,19 +25,11 @@ namespace luci bool CircleNotEqualGraphBuilder::validate(const ValidateArgs &args) const { - const auto &inputs = args.op.inputs; - const auto &outputs = args.op.outputs; - - if (inputs.size() != 2) - { + if (!GraphBuilder::validate(args, 2)) return false; - } - - if (outputs.size() != 1) - { - return false; - } + const auto &inputs = args.op.inputs; + const auto &outputs = args.op.outputs; const auto &tensors = args.reader.tensors(); if (tensors[inputs.at(0)]->type != tensors[inputs.at(1)]->type) diff --git a/compiler/luci/import/src/Nodes/CircleOneHot.cpp b/compiler/luci/import/src/Nodes/CircleOneHot.cpp index 69294e1ed..3952cc21a 100644 --- a/compiler/luci/import/src/Nodes/CircleOneHot.cpp +++ b/compiler/luci/import/src/Nodes/CircleOneHot.cpp @@ -26,17 +26,12 @@ namespace luci bool CircleOneHotGraphBuilder::validate(const ValidateArgs &args) const { - const auto &inputs = args.op.inputs; - const auto &outputs = args.op.outputs; - const auto *options = args.op.builtin_options.AsOneHotOptions(); - // Only 4 Input come refered from - if (inputs.size() != 4) - return false; - - if (outputs.size() != 1) + if (!GraphBuilder::validate(args, 4)) return false; + const auto &inputs = args.op.inputs; + const auto *options = args.op.builtin_options.AsOneHotOptions(); const auto &tensors = args.reader.tensors(); const auto &indices = tensors.at(inputs.at(0)); const auto &depth = tensors.at(inputs.at(1)); diff --git a/compiler/luci/import/src/Nodes/CirclePRelu.cpp b/compiler/luci/import/src/Nodes/CirclePRelu.cpp index c07920f7c..7c81f04bb 100644 --- a/compiler/luci/import/src/Nodes/CirclePRelu.cpp +++ b/compiler/luci/import/src/Nodes/CirclePRelu.cpp @@ -25,13 +25,7 @@ namespace luci bool CirclePReluGraphBuilder::validate(const ValidateArgs &args) const { - if (args.op.inputs.size() != 2) - return false; - - if (args.op.outputs.size() != 1) - return false; - - return true; + return GraphBuilder::validate(args, 2); } CircleNode *CirclePReluGraphBuilder::build_node(const circle::OperatorT &, diff --git a/compiler/luci/import/src/Nodes/CirclePad.cpp b/compiler/luci/import/src/Nodes/CirclePad.cpp index 999173b90..67dce6dee 100644 --- a/compiler/luci/import/src/Nodes/CirclePad.cpp +++ b/compiler/luci/import/src/Nodes/CirclePad.cpp @@ -25,12 +25,8 @@ namespace luci bool CirclePadGraphBuilder::validate(const ValidateArgs &args) const { - if (args.op.inputs.size() != 2) - return false; - // TODO do attribute checks - - return true; + return GraphBuilder::validate(args, 2); } CircleNode *CirclePadGraphBuilder::build_node(const circle::OperatorT &op, diff --git a/compiler/luci/import/src/Nodes/CirclePadV2.cpp b/compiler/luci/import/src/Nodes/CirclePadV2.cpp index 493876e68..84a45722a 100644 --- a/compiler/luci/import/src/Nodes/CirclePadV2.cpp +++ b/compiler/luci/import/src/Nodes/CirclePadV2.cpp @@ -25,13 +25,7 @@ namespace luci bool CirclePadV2GraphBuilder::validate(const ValidateArgs &args) const { - if (args.op.inputs.size() != 3) - return false; - - if (args.op.outputs.size() != 1) - return false; - - return true; + return GraphBuilder::validate(args, 3); } CircleNode *CirclePadV2GraphBuilder::build_node(const circle::OperatorT &op, diff --git a/compiler/luci/import/src/Nodes/CirclePow.cpp b/compiler/luci/import/src/Nodes/CirclePow.cpp index def012614..1d2d41607 100644 --- a/compiler/luci/import/src/Nodes/CirclePow.cpp +++ b/compiler/luci/import/src/Nodes/CirclePow.cpp @@ -25,13 +25,7 @@ namespace luci bool CirclePowGraphBuilder::validate(const ValidateArgs &args) const { - if (args.op.inputs.size() != 2) - return false; - - if (args.op.outputs.size() != 1) - return false; - - return true; + return GraphBuilder::validate(args, 2); } CircleNode *CirclePowGraphBuilder::build_node(const circle::OperatorT &, diff --git a/compiler/luci/import/src/Nodes/CircleRange.cpp b/compiler/luci/import/src/Nodes/CircleRange.cpp index 38dc44ed6..d3b5afc95 100644 --- a/compiler/luci/import/src/Nodes/CircleRange.cpp +++ b/compiler/luci/import/src/Nodes/CircleRange.cpp @@ -24,11 +24,8 @@ namespace luci { bool CircleRangeGraphBuilder::validate(const ValidateArgs &args) const { - if (args.op.inputs.size() != 3) - return false; - // TODO Support type check - return true; + return GraphBuilder::validate(args, 3); } CircleNode *CircleRangeGraphBuilder::build_node(const circle::OperatorT &, diff --git a/compiler/luci/import/src/Nodes/CircleRank.cpp b/compiler/luci/import/src/Nodes/CircleRank.cpp index 12658b192..afebb9509 100644 --- a/compiler/luci/import/src/Nodes/CircleRank.cpp +++ b/compiler/luci/import/src/Nodes/CircleRank.cpp @@ -24,13 +24,7 @@ namespace luci { bool CircleRankGraphBuilder::validate(const ValidateArgs &args) const { - if (args.op.inputs.size() != 1) - return false; - - if (args.op.outputs.size() != 1) - return false; - - return true; + return GraphBuilder::validate(args, 1); } CircleNode *CircleRankGraphBuilder::build_node(const circle::OperatorT &, diff --git a/compiler/luci/import/src/Nodes/CircleReduceAny.cpp b/compiler/luci/import/src/Nodes/CircleReduceAny.cpp index 21a821951..13205dd7a 100644 --- a/compiler/luci/import/src/Nodes/CircleReduceAny.cpp +++ b/compiler/luci/import/src/Nodes/CircleReduceAny.cpp @@ -23,13 +23,11 @@ namespace luci bool CircleReduceAnyGraphBuilder::validate(const ValidateArgs &args) const { - const auto &inputs = args.op.inputs; - const auto &outputs = args.op.outputs; - if (inputs.size() != 2) - return false; - if (outputs.size() != 1) + if (!GraphBuilder::validate(args, 2)) return false; + const auto &inputs = args.op.inputs; + const auto &outputs = args.op.outputs; const auto &tensors = args.reader.tensors(); const auto &tensor_0 = tensors.at(inputs.at(0)); const auto &tensor_1 = tensors.at(inputs.at(1)); diff --git a/compiler/luci/import/src/Nodes/CircleReduceProd.cpp b/compiler/luci/import/src/Nodes/CircleReduceProd.cpp index 5f054586e..3549c1a18 100644 --- a/compiler/luci/import/src/Nodes/CircleReduceProd.cpp +++ b/compiler/luci/import/src/Nodes/CircleReduceProd.cpp @@ -23,12 +23,10 @@ namespace luci bool CircleReduceProdGraphBuilder::validate(const ValidateArgs &args) const { - const auto &inputs = args.op.inputs; - if (inputs.size() != 2) - return false; - if (args.op.outputs.size() != 1) + if (!GraphBuilder::validate(args, 2)) return false; + const auto &inputs = args.op.inputs; const auto &tensors = args.reader.tensors(); const auto &tensor_1 = tensors.at(inputs.at(1)); diff --git a/compiler/luci/import/src/Nodes/CircleRelu.cpp b/compiler/luci/import/src/Nodes/CircleRelu.cpp index 8e1c32a3a..73b8ffee8 100644 --- a/compiler/luci/import/src/Nodes/CircleRelu.cpp +++ b/compiler/luci/import/src/Nodes/CircleRelu.cpp @@ -25,13 +25,7 @@ namespace luci bool CircleReluGraphBuilder::validate(const ValidateArgs &args) const { - if (args.op.inputs.size() != 1) - return false; - - if (args.op.outputs.size() != 1) - return false; - - return true; + return GraphBuilder::validate(args, 1); } CircleNode *CircleReluGraphBuilder::build_node(const circle::OperatorT &, diff --git a/compiler/luci/import/src/Nodes/CircleRelu6.cpp b/compiler/luci/import/src/Nodes/CircleRelu6.cpp index 0283d7350..ab957eda8 100644 --- a/compiler/luci/import/src/Nodes/CircleRelu6.cpp +++ b/compiler/luci/import/src/Nodes/CircleRelu6.cpp @@ -25,13 +25,7 @@ namespace luci bool CircleRelu6GraphBuilder::validate(const ValidateArgs &args) const { - if (args.op.inputs.size() != 1) - return false; - - if (args.op.outputs.size() != 1) - return false; - - return true; + return GraphBuilder::validate(args, 1); } CircleNode *CircleRelu6GraphBuilder::build_node(const circle::OperatorT &, diff --git a/compiler/luci/import/src/Nodes/CircleReluN1To1.cpp b/compiler/luci/import/src/Nodes/CircleReluN1To1.cpp index 7f517bc0d..4987f3be2 100644 --- a/compiler/luci/import/src/Nodes/CircleReluN1To1.cpp +++ b/compiler/luci/import/src/Nodes/CircleReluN1To1.cpp @@ -25,15 +25,8 @@ namespace luci bool CircleReluN1To1GraphBuilder::validate(const ValidateArgs &args) const { - if (args.op.inputs.size() != 1) - return false; - - if (args.op.outputs.size() != 1) - return false; - // TODO check dtypes - - return true; + return GraphBuilder::validate(args, 1); } CircleNode *CircleReluN1To1GraphBuilder::build_node(const circle::OperatorT &, diff --git a/compiler/luci/import/src/Nodes/CircleReshape.cpp b/compiler/luci/import/src/Nodes/CircleReshape.cpp index 996ae9d20..401dff0fc 100644 --- a/compiler/luci/import/src/Nodes/CircleReshape.cpp +++ b/compiler/luci/import/src/Nodes/CircleReshape.cpp @@ -30,6 +30,19 @@ bool CircleReshapeGraphBuilder::validate(const ValidateArgs &args) const if (args.op.outputs.size() != 1) return false; + // for two inputs, check if type is S32 + if (args.op.inputs.size() == 2) + { + const auto &inputs = args.op.inputs; + const auto &tensors = args.reader.tensors(); + const auto &tensor_in = tensors.at(inputs.at(1)); + + // NOTE fix this if there is any other case + // TensorFlow lite and circle only supports S32 + if (tensor_in->type != circle::TensorType::TensorType_INT32) + return false; + } + return true; } @@ -53,6 +66,7 @@ static CircleNode *create_shape_node(const std::vector<int32_t> &shape, loco::Gr { shape_node->at<loco::DataType::S32>(i) = shape[i]; } + shape_node->name("Reshape/shape"); return shape_node; } @@ -73,6 +87,7 @@ CircleNode *CircleReshapeGraphBuilder::build_node(const circle::OperatorT &op, shape_node = graph->nodes()->create<CircleOutputDummy>(); shape_node->dtype(loco::DataType::S32); shape_node->rank(0); + shape_node->name("Reshape/dummy"); } } diff --git a/compiler/luci/import/src/Nodes/CircleResizeBilinear.cpp b/compiler/luci/import/src/Nodes/CircleResizeBilinear.cpp index 0fccb7b44..c751b245c 100644 --- a/compiler/luci/import/src/Nodes/CircleResizeBilinear.cpp +++ b/compiler/luci/import/src/Nodes/CircleResizeBilinear.cpp @@ -16,7 +16,6 @@ #include "luci/Import/Nodes/CircleResizeBilinear.h" -#include <luci/IR/Nodes/CircleConst.h> #include <luci/IR/Nodes/CircleResizeBilinear.h> namespace luci @@ -24,13 +23,7 @@ namespace luci bool CircleResizeBilinearGraphBuilder::validate(const ValidateArgs &args) const { - if (args.op.inputs.size() != 2) - return false; - - if (args.op.outputs.size() != 1) - return false; - - return true; + return GraphBuilder::validate(args, 2); } CircleNode *CircleResizeBilinearGraphBuilder::build_node(const circle::OperatorT &op, diff --git a/compiler/luci/import/src/Nodes/CircleResizeNearestNeighbor.cpp b/compiler/luci/import/src/Nodes/CircleResizeNearestNeighbor.cpp index 324323f59..df7517fe9 100644 --- a/compiler/luci/import/src/Nodes/CircleResizeNearestNeighbor.cpp +++ b/compiler/luci/import/src/Nodes/CircleResizeNearestNeighbor.cpp @@ -16,7 +16,6 @@ #include "luci/Import/Nodes/CircleResizeNearestNeighbor.h" -#include <luci/IR/Nodes/CircleConst.h> #include <luci/IR/Nodes/CircleResizeNearestNeighbor.h> namespace luci @@ -24,17 +23,11 @@ namespace luci bool CircleResizeNearestNeighborGraphBuilder::validate(const ValidateArgs &args) const { - if (args.op.inputs.size() != 2) - return false; - - if (args.op.outputs.size() != 1) - return false; - - return true; + return GraphBuilder::validate(args, 2); } CircleNode *CircleResizeNearestNeighborGraphBuilder::build_node( - const circle::OperatorT &op, const std::vector<CircleNode *> &inputs, loco::Graph *graph) const + const circle::OperatorT &op, const std::vector<CircleNode *> &inputs, loco::Graph *graph) const { auto *node = graph->nodes()->create<CircleResizeNearestNeighbor>(); node->input(inputs.at(0)); diff --git a/compiler/luci/import/src/Nodes/CircleReverseSequence.cpp b/compiler/luci/import/src/Nodes/CircleReverseSequence.cpp index ad11d4c63..2fbb7a87c 100644 --- a/compiler/luci/import/src/Nodes/CircleReverseSequence.cpp +++ b/compiler/luci/import/src/Nodes/CircleReverseSequence.cpp @@ -25,14 +25,11 @@ namespace luci bool CircleReverseSequenceGraphBuilder::validate(const ValidateArgs &args) const { - const auto &inputs = args.op.inputs; - const auto &outputs = args.op.outputs; - - if (inputs.size() != 2) - return false; - if (outputs.size() != 1) + if (!GraphBuilder::validate(args, 2)) return false; + const auto &inputs = args.op.inputs; + const auto &outputs = args.op.outputs; const auto &tensors = args.reader.tensors(); const auto &tensor_in = tensors.at(inputs.at(0)); const auto &tensor_lengths = tensors.at(inputs.at(1)); diff --git a/compiler/luci/import/src/Nodes/CircleReverseV2.cpp b/compiler/luci/import/src/Nodes/CircleReverseV2.cpp index e2e53bb4b..ca7653201 100644 --- a/compiler/luci/import/src/Nodes/CircleReverseV2.cpp +++ b/compiler/luci/import/src/Nodes/CircleReverseV2.cpp @@ -25,14 +25,11 @@ namespace luci bool CircleReverseV2GraphBuilder::validate(const ValidateArgs &args) const { - const auto &inputs = args.op.inputs; - const auto &outputs = args.op.outputs; - - if (inputs.size() != 2) - return false; - if (outputs.size() != 1) + if (!GraphBuilder::validate(args, 2)) return false; + const auto &inputs = args.op.inputs; + const auto &outputs = args.op.outputs; const auto &tensors = args.reader.tensors(); const auto &tensor_in = tensors.at(inputs.at(0)); const auto &tensor_axis = tensors.at(inputs.at(1)); diff --git a/compiler/luci/import/src/Nodes/CircleRound.cpp b/compiler/luci/import/src/Nodes/CircleRound.cpp index ad77f9f03..d13e0fafe 100644 --- a/compiler/luci/import/src/Nodes/CircleRound.cpp +++ b/compiler/luci/import/src/Nodes/CircleRound.cpp @@ -25,14 +25,11 @@ namespace luci bool CircleRoundGraphBuilder::validate(const ValidateArgs &args) const { - const auto &inputs = args.op.inputs; - const auto &outputs = args.op.outputs; - - if (inputs.size() != 1) - return false; - if (outputs.size() != 1) + if (!GraphBuilder::validate(args, 1)) return false; + const auto &inputs = args.op.inputs; + const auto &outputs = args.op.outputs; // Must be one of the following types // bfloat16, half (float16), float32, float64, complex64, complex128 // Currently, circle supports float16, float32, complex64 diff --git a/compiler/luci/import/src/Nodes/CircleRsqrt.cpp b/compiler/luci/import/src/Nodes/CircleRsqrt.cpp index ae05fbbf9..a9ca90832 100644 --- a/compiler/luci/import/src/Nodes/CircleRsqrt.cpp +++ b/compiler/luci/import/src/Nodes/CircleRsqrt.cpp @@ -25,10 +25,10 @@ namespace luci bool CircleRsqrtGraphBuilder::validate(const ValidateArgs &args) const { - const auto &inputs = args.op.inputs; - if (inputs.size() != 1) + if (!GraphBuilder::validate(args, 1)) return false; + const auto &inputs = args.op.inputs; // Must be one of the following types // bfloat16, half (float16), float32, float64, complex64, complex128 // Currently, circle supports float16, float32, complex64 @@ -36,6 +36,8 @@ bool CircleRsqrtGraphBuilder::validate(const ValidateArgs &args) const const auto &tensor = tensors.at(inputs.at(0)); switch (tensor->type) { + case circle::TensorType_UINT8: + case circle::TensorType_INT16: case circle::TensorType_FLOAT16: case circle::TensorType_FLOAT32: case circle::TensorType_COMPLEX64: diff --git a/compiler/luci/import/src/Nodes/CircleScatterNd.cpp b/compiler/luci/import/src/Nodes/CircleScatterNd.cpp index 7f86aeb74..f8c175110 100644 --- a/compiler/luci/import/src/Nodes/CircleScatterNd.cpp +++ b/compiler/luci/import/src/Nodes/CircleScatterNd.cpp @@ -25,10 +25,10 @@ namespace luci bool CircleScatterNdGraphBuilder::validate(const ValidateArgs &args) const { - const auto &inputs = args.op.inputs; - if (inputs.size() != 3) + if (!GraphBuilder::validate(args, 3)) return false; + const auto &inputs = args.op.inputs; // indices must have the same type as shape const auto &tensors = args.reader.tensors(); diff --git a/compiler/luci/import/src/Nodes/CircleSegmentSum.cpp b/compiler/luci/import/src/Nodes/CircleSegmentSum.cpp index fb84e5d52..bfa333e8d 100644 --- a/compiler/luci/import/src/Nodes/CircleSegmentSum.cpp +++ b/compiler/luci/import/src/Nodes/CircleSegmentSum.cpp @@ -25,13 +25,11 @@ namespace luci bool CircleSegmentSumGraphBuilder::validate(const ValidateArgs &args) const { - const auto &inputs = args.op.inputs; - const auto &outputs = args.op.outputs; - if (inputs.size() != 2) - return false; - if (outputs.size() != 1) + if (!GraphBuilder::validate(args, 2)) return false; + const auto &inputs = args.op.inputs; + const auto &outputs = args.op.outputs; const auto &tensors = args.reader.tensors(); const auto &tensor_in = tensors.at(inputs.at(0)); const auto &tensor_out = tensors.at(outputs[0]); diff --git a/compiler/luci/import/src/Nodes/CircleSelect.cpp b/compiler/luci/import/src/Nodes/CircleSelect.cpp index 1e649f1e0..36a5fa8a8 100644 --- a/compiler/luci/import/src/Nodes/CircleSelect.cpp +++ b/compiler/luci/import/src/Nodes/CircleSelect.cpp @@ -25,13 +25,10 @@ namespace luci bool CircleSelectGraphBuilder::validate(const ValidateArgs &args) const { - const auto &inputs = args.op.inputs; - const auto &outputs = args.op.outputs; - if (inputs.size() != 3) - return false; - if (outputs.size() != 1) + if (!GraphBuilder::validate(args, 3)) return false; + const auto &inputs = args.op.inputs; const auto &tensors = args.reader.tensors(); const auto &tensor = tensors.at(inputs.at(0)); if (tensor->type != circle::TensorType_BOOL) diff --git a/compiler/luci/import/src/Nodes/CircleSelectV2.cpp b/compiler/luci/import/src/Nodes/CircleSelectV2.cpp index e6dd04de0..556c8fa33 100644 --- a/compiler/luci/import/src/Nodes/CircleSelectV2.cpp +++ b/compiler/luci/import/src/Nodes/CircleSelectV2.cpp @@ -25,13 +25,10 @@ namespace luci bool CircleSelectV2GraphBuilder::validate(const ValidateArgs &args) const { - const auto &inputs = args.op.inputs; - const auto &outputs = args.op.outputs; - if (inputs.size() != 3) - return false; - if (outputs.size() != 1) + if (!GraphBuilder::validate(args, 3)) return false; + const auto &inputs = args.op.inputs; const auto &tensors = args.reader.tensors(); const auto &condition = tensors.at(inputs.at(0)); if (condition->type != circle::TensorType_BOOL) diff --git a/compiler/luci/import/src/Nodes/CircleShape.cpp b/compiler/luci/import/src/Nodes/CircleShape.cpp index bd7dfc9d9..86c0bf59b 100644 --- a/compiler/luci/import/src/Nodes/CircleShape.cpp +++ b/compiler/luci/import/src/Nodes/CircleShape.cpp @@ -25,16 +25,8 @@ namespace luci bool CircleShapeGraphBuilder::validate(const ValidateArgs &args) const { - const auto &inputs = args.op.inputs; - const auto &outputs = args.op.outputs; - if (inputs.size() != 1) - return false; - if (outputs.size() != 1) - return false; - // TODO check shape, dtype - - return true; + return GraphBuilder::validate(args, 1); } CircleNode *CircleShapeGraphBuilder::build_node(const circle::OperatorT &op, diff --git a/compiler/luci/import/src/Nodes/CircleSin.cpp b/compiler/luci/import/src/Nodes/CircleSin.cpp index 4b245ef6b..22f461123 100644 --- a/compiler/luci/import/src/Nodes/CircleSin.cpp +++ b/compiler/luci/import/src/Nodes/CircleSin.cpp @@ -25,12 +25,10 @@ namespace luci bool CircleSinGraphBuilder::validate(const ValidateArgs &args) const { - const auto &inputs = args.op.inputs; - if (inputs.size() != 1) - return false; - if (args.op.outputs.size() != 1) + if (!GraphBuilder::validate(args, 1)) return false; + const auto &inputs = args.op.inputs; // input type check const auto &tensors = args.reader.tensors(); const auto &tensor = tensors.at(inputs.at(0)); diff --git a/compiler/luci/import/src/Nodes/CircleSlice.cpp b/compiler/luci/import/src/Nodes/CircleSlice.cpp index 8601fbf21..4166040b3 100644 --- a/compiler/luci/import/src/Nodes/CircleSlice.cpp +++ b/compiler/luci/import/src/Nodes/CircleSlice.cpp @@ -27,14 +27,8 @@ namespace luci bool CircleSliceGraphBuilder::validate(const ValidateArgs &args) const { - if (args.op.inputs.size() != 3) - return false; - if (args.op.outputs.size() != 1) - return false; - // TODO check shapes and types - - return true; + return GraphBuilder::validate(args, 3); } CircleNode *CircleSliceGraphBuilder::build_node(const circle::OperatorT &, diff --git a/compiler/luci/import/src/Nodes/CircleSoftmax.cpp b/compiler/luci/import/src/Nodes/CircleSoftmax.cpp index 0ef0b5418..e79914455 100644 --- a/compiler/luci/import/src/Nodes/CircleSoftmax.cpp +++ b/compiler/luci/import/src/Nodes/CircleSoftmax.cpp @@ -25,12 +25,8 @@ namespace luci bool CircleSoftmaxGraphBuilder::validate(const ValidateArgs &args) const { - if (args.op.inputs.size() != 1) - return false; - // TODO do attribute checks - - return true; + return GraphBuilder::validate(args, 1); } CircleNode *CircleSoftmaxGraphBuilder::build_node(const circle::OperatorT &op, diff --git a/compiler/luci/import/src/Nodes/CircleSpaceToDepth.cpp b/compiler/luci/import/src/Nodes/CircleSpaceToDepth.cpp index 8ccd55dc6..2152b65c9 100644 --- a/compiler/luci/import/src/Nodes/CircleSpaceToDepth.cpp +++ b/compiler/luci/import/src/Nodes/CircleSpaceToDepth.cpp @@ -27,13 +27,8 @@ namespace luci bool CircleSpaceToDepthGraphBuilder::validate(const ValidateArgs &args) const { - const auto &inputs = args.op.inputs; - if (inputs.size() != 1) - return false; - // TODO do attribute checks - - return true; + return GraphBuilder::validate(args, 1); } CircleNode *CircleSpaceToDepthGraphBuilder::build_node(const circle::OperatorT &op, diff --git a/compiler/luci/import/src/Nodes/CircleSparseToDense.cpp b/compiler/luci/import/src/Nodes/CircleSparseToDense.cpp index ac756b1f3..ce0688bb9 100644 --- a/compiler/luci/import/src/Nodes/CircleSparseToDense.cpp +++ b/compiler/luci/import/src/Nodes/CircleSparseToDense.cpp @@ -25,10 +25,7 @@ namespace luci bool CircleSparseToDenseGraphBuilder::validate(const ValidateArgs &args) const { - if (args.op.inputs.size() != 4) - return false; - - return true; + return GraphBuilder::validate(args, 4); } CircleNode *CircleSparseToDenseGraphBuilder::build_node(const circle::OperatorT &op, diff --git a/compiler/luci/import/src/Nodes/CircleSplit.cpp b/compiler/luci/import/src/Nodes/CircleSplit.cpp index 07b6cc939..d0a24aae3 100644 --- a/compiler/luci/import/src/Nodes/CircleSplit.cpp +++ b/compiler/luci/import/src/Nodes/CircleSplit.cpp @@ -58,62 +58,27 @@ bool CircleSplitGraphBuilder::validate(const ValidateArgs &args) const * \- CircleSplitOut --- FullyConnected --- */ -void CircleSplitGraphBuilder::build(const circle::OperatorT &op, GraphBuilderContext *context) const +CircleNode *CircleSplitGraphBuilder::build_node(const BuildNodeArgs &bna) const { - assert(context != nullptr); + auto node = bna.context->graph()->nodes()->create<CircleSplit>(); - auto graph = context->graph(); + node->split_dim(bna.input_nodes[0]); + node->input(bna.input_nodes[1]); - const std::vector<int32_t> &inputs = op.inputs; - const std::vector<int32_t> &outputs = op.outputs; - const auto &tensors = context->reader()->tensors(); - const auto &opcodes = context->reader()->opcodes(); - auto tensors_ptr = context->reader()->tensors_ptr(); - assert(tensors_ptr != nullptr); + const auto *options = bna.op.builtin_options.AsSplitOptions(); + node->num_split(options->num_splits); - std::vector<CircleNode *> input_nodes; - for (const int32_t input_tensor_index : inputs) - { - input_nodes.push_back(context->nodefinder()->node(input_tensor_index)); - } + return node; +} - // Create CircleSplit - auto node = graph->nodes()->create<CircleSplit>(); - node->split_dim(input_nodes[0]); - node->input(input_nodes[1]); +CircleNode *CircleSplitGraphBuilder::build_out(const BuildOutArgs &boa) const +{ + auto *nodeout = boa.node->graph()->nodes()->create<CircleSplitOut>(); - const auto *options = op.builtin_options.AsSplitOptions(); - node->num_split(options->num_splits); + nodeout->input(boa.node); + nodeout->index(boa.index); - assert(outputs.size() > 0); - assert(int32_t(outputs.size()) == options->num_splits); - { - // Let's use name of output 0 as Split name - const circle::TensorT &output_tensor = *tensors[outputs[0]]; - node->name(tensor_name(output_tensor)); - node->op_version(opcodes[op.opcode_index].get()->version); - - // NOTE We don't set quantization for Split itself but to virtual outputs - } - - // Create virtual outputs of Split - for (int32_t n = 0; n < options->num_splits; ++n) - { - const circle::TensorT &output_tensor = *tensors[outputs[n]]; - - auto *nodeout = graph->nodes()->create<CircleSplitOut>(); - copy_tensor_attributes(output_tensor, nodeout); - // mark shape_status - if (tensors_ptr->Get(outputs[n])->shape() == nullptr) - nodeout->shape_status(ShapeStatus::NOSHAPE); - else - nodeout->shape_status(ShapeStatus::VALID); - - nodeout->input(node); - nodeout->index(n); - - context->nodefinder()->enroll(outputs[n], nodeout); - } + return nodeout; } } // namespace luci diff --git a/compiler/luci/import/src/Nodes/CircleSplitV.cpp b/compiler/luci/import/src/Nodes/CircleSplitV.cpp index 7c6e83e17..76cbf7046 100644 --- a/compiler/luci/import/src/Nodes/CircleSplitV.cpp +++ b/compiler/luci/import/src/Nodes/CircleSplitV.cpp @@ -58,64 +58,30 @@ bool CircleSplitVGraphBuilder::validate(const ValidateArgs &args) const * \- CircleSplitVOut --- FullyConnected --- */ -void CircleSplitVGraphBuilder::build(const circle::OperatorT &op, - GraphBuilderContext *context) const +CircleNode *CircleSplitVGraphBuilder::build_node(const BuildNodeArgs &bna) const { - assert(context != nullptr); - - auto graph = context->graph(); - - const std::vector<int32_t> &inputs = op.inputs; - const std::vector<int32_t> &outputs = op.outputs; - const auto &tensors = context->reader()->tensors(); - const auto &opcodes = context->reader()->opcodes(); - auto tensors_ptr = context->reader()->tensors_ptr(); - assert(tensors_ptr != nullptr); - - std::vector<CircleNode *> input_nodes; - for (const int32_t input_tensor_index : inputs) - { - input_nodes.push_back(context->nodefinder()->node(input_tensor_index)); - } - - // Create CircleSplitV - auto node = graph->nodes()->create<CircleSplitV>(); - node->input(input_nodes[0]); - node->size_splits(input_nodes[1]); - node->split_dim(input_nodes[2]); - - const auto *options = op.builtin_options.AsSplitVOptions(); + auto node = bna.context->graph()->nodes()->create<CircleSplitV>(); + + node->input(bna.input_nodes[0]); + node->size_splits(bna.input_nodes[1]); + node->split_dim(bna.input_nodes[2]); + + const auto *options = bna.op.builtin_options.AsSplitVOptions(); node->num_split(options->num_splits); - assert(outputs.size() > 0); - assert(int32_t(outputs.size()) == options->num_splits); - { - // Let's use name of output 0 as Split name - const circle::TensorT &output_tensor = *tensors[outputs[0]]; - node->name(tensor_name(output_tensor)); - node->op_version(opcodes[op.opcode_index].get()->version); - - // NOTE We don't set quantization for Split itself but to virtual outputs - } - - // Create virtual outputs of Split - for (int32_t n = 0; n < options->num_splits; ++n) - { - const circle::TensorT &output_tensor = *tensors[outputs[n]]; - - auto *nodeout = graph->nodes()->create<CircleSplitVOut>(); - copy_tensor_attributes(output_tensor, nodeout); - // mark shape_status - if (tensors_ptr->Get(outputs[n])->shape() == nullptr) - nodeout->shape_status(ShapeStatus::NOSHAPE); - else - nodeout->shape_status(ShapeStatus::VALID); - - nodeout->input(node); - nodeout->index(n); - - context->nodefinder()->enroll(outputs[n], nodeout); - } + assert(int32_t(bna.op.outputs.size()) == options->num_splits); + + return node; +} + +CircleNode *CircleSplitVGraphBuilder::build_out(const BuildOutArgs &boa) const +{ + auto *nodeout = boa.node->graph()->nodes()->create<CircleSplitVOut>(); + + nodeout->input(boa.node); + nodeout->index(boa.index); + + return nodeout; } } // namespace luci diff --git a/compiler/luci/import/src/Nodes/CircleSqrt.cpp b/compiler/luci/import/src/Nodes/CircleSqrt.cpp index c8beaee0d..b1fdf7996 100644 --- a/compiler/luci/import/src/Nodes/CircleSqrt.cpp +++ b/compiler/luci/import/src/Nodes/CircleSqrt.cpp @@ -25,10 +25,7 @@ namespace luci bool CircleSqrtGraphBuilder::validate(const ValidateArgs &args) const { - if (args.op.inputs.size() != 1) - return false; - - return true; + return GraphBuilder::validate(args, 1); } CircleNode *CircleSqrtGraphBuilder::build_node(const circle::OperatorT &, diff --git a/compiler/luci/import/src/Nodes/CircleSquare.cpp b/compiler/luci/import/src/Nodes/CircleSquare.cpp index b5ba048d7..7ff2b84e6 100644 --- a/compiler/luci/import/src/Nodes/CircleSquare.cpp +++ b/compiler/luci/import/src/Nodes/CircleSquare.cpp @@ -25,10 +25,10 @@ namespace luci bool CircleSquareGraphBuilder::validate(const ValidateArgs &args) const { - const auto &inputs = args.op.inputs; - if (inputs.size() != 1) + if (!GraphBuilder::validate(args, 1)) return false; + const auto &inputs = args.op.inputs; // Must be one of the following types // bfloat16, half (float16), float32, float64, complex64, complex128 // Currently, circle supports float16, float32, complex64 diff --git a/compiler/luci/import/src/Nodes/CircleSquaredDifference.cpp b/compiler/luci/import/src/Nodes/CircleSquaredDifference.cpp index 6deae94c5..f4e193713 100644 --- a/compiler/luci/import/src/Nodes/CircleSquaredDifference.cpp +++ b/compiler/luci/import/src/Nodes/CircleSquaredDifference.cpp @@ -25,15 +25,11 @@ namespace luci bool CircleSquaredDifferenceGraphBuilder::validate(const ValidateArgs &args) const { - const auto &inputs = args.op.inputs; - const auto &outputs = args.op.outputs; - - if (inputs.size() != 2) - return false; - - if (outputs.size() != 1) + if (!GraphBuilder::validate(args, 2)) return false; + const auto &inputs = args.op.inputs; + const auto &outputs = args.op.outputs; // Inputs must be one of the following types // bfloat16, half(float16), float32, float64, int32, int64, complex64, complex128 const auto &tensors = args.reader.tensors(); diff --git a/compiler/luci/import/src/Nodes/CircleSqueeze.cpp b/compiler/luci/import/src/Nodes/CircleSqueeze.cpp index 32792c266..d24d8166c 100644 --- a/compiler/luci/import/src/Nodes/CircleSqueeze.cpp +++ b/compiler/luci/import/src/Nodes/CircleSqueeze.cpp @@ -16,7 +16,6 @@ #include "luci/Import/Nodes/CircleSqueeze.h" -#include <luci/IR/Nodes/CircleConst.h> #include <luci/IR/Nodes/CircleSqueeze.h> namespace luci @@ -24,13 +23,7 @@ namespace luci bool CircleSqueezeGraphBuilder::validate(const ValidateArgs &args) const { - if (args.op.inputs.size() != 1) - return false; - - if (args.op.outputs.size() != 1) - return false; - - return true; + return GraphBuilder::validate(args, 1); } CircleNode *CircleSqueezeGraphBuilder::build_node(const circle::OperatorT &op, diff --git a/compiler/luci/import/src/Nodes/CircleStridedSlice.cpp b/compiler/luci/import/src/Nodes/CircleStridedSlice.cpp index 8f943a682..ca8259cac 100644 --- a/compiler/luci/import/src/Nodes/CircleStridedSlice.cpp +++ b/compiler/luci/import/src/Nodes/CircleStridedSlice.cpp @@ -27,14 +27,8 @@ namespace luci bool CircleStridedSliceGraphBuilder::validate(const ValidateArgs &args) const { - if (args.op.inputs.size() != 4) - return false; - if (args.op.outputs.size() != 1) - return false; - // TODO check shapes and types - - return true; + return GraphBuilder::validate(args, 4); } CircleNode *CircleStridedSliceGraphBuilder::build_node(const circle::OperatorT &op, diff --git a/compiler/luci/import/src/Nodes/CircleSub.cpp b/compiler/luci/import/src/Nodes/CircleSub.cpp index 9acf83d40..c3978f218 100644 --- a/compiler/luci/import/src/Nodes/CircleSub.cpp +++ b/compiler/luci/import/src/Nodes/CircleSub.cpp @@ -25,13 +25,7 @@ namespace luci bool CircleSubGraphBuilder::validate(const ValidateArgs &args) const { - if (args.op.inputs.size() != 2) - return false; - - if (args.op.outputs.size() != 1) - return false; - - return true; + return GraphBuilder::validate(args, 2); } CircleNode *CircleSubGraphBuilder::build_node(const circle::OperatorT &op, diff --git a/compiler/luci/import/src/Nodes/CircleSum.cpp b/compiler/luci/import/src/Nodes/CircleSum.cpp index bd3cb6239..e348a62d9 100644 --- a/compiler/luci/import/src/Nodes/CircleSum.cpp +++ b/compiler/luci/import/src/Nodes/CircleSum.cpp @@ -23,10 +23,7 @@ namespace luci bool CircleSumGraphBuilder::validate(const ValidateArgs &args) const { - if (args.op.inputs.size() != 2) - return false; - - return true; + return GraphBuilder::validate(args, 2); } CircleNode *CircleSumGraphBuilder::build_node(const circle::OperatorT &op, diff --git a/compiler/luci/import/src/Nodes/CircleTanh.cpp b/compiler/luci/import/src/Nodes/CircleTanh.cpp index 018f5701b..95625a0e4 100644 --- a/compiler/luci/import/src/Nodes/CircleTanh.cpp +++ b/compiler/luci/import/src/Nodes/CircleTanh.cpp @@ -25,13 +25,11 @@ namespace luci bool CircleTanhGraphBuilder::validate(const ValidateArgs &args) const { - const auto &inputs = args.op.inputs; - if (inputs.size() != 1) - return false; - const auto &outputs = args.op.outputs; - if (outputs.size() != 1) + if (!GraphBuilder::validate(args, 1)) return false; + const auto &inputs = args.op.inputs; + const auto &outputs = args.op.outputs; const auto &tensors = args.reader.tensors(); if (tensors.at(inputs.at(0))->type != tensors.at(outputs[0])->type) return false; diff --git a/compiler/luci/import/src/Nodes/CircleTile.cpp b/compiler/luci/import/src/Nodes/CircleTile.cpp index bc6f320ba..6da44130c 100644 --- a/compiler/luci/import/src/Nodes/CircleTile.cpp +++ b/compiler/luci/import/src/Nodes/CircleTile.cpp @@ -25,15 +25,11 @@ namespace luci bool CircleTileGraphBuilder::validate(const ValidateArgs &args) const { - auto inputs = args.op.inputs; - auto outputs = args.op.outputs; - - if (inputs.size() != 2) - return false; - - if (outputs.size() != 1) + if (!GraphBuilder::validate(args, 2)) return false; + auto inputs = args.op.inputs; + auto outputs = args.op.outputs; // Multiples (inputs.at(1)) must be one of the following types // int32, int64 const auto &tensors = args.reader.tensors(); diff --git a/compiler/luci/import/src/Nodes/CircleTopKV2.cpp b/compiler/luci/import/src/Nodes/CircleTopKV2.cpp index f0677de86..49f858798 100644 --- a/compiler/luci/import/src/Nodes/CircleTopKV2.cpp +++ b/compiler/luci/import/src/Nodes/CircleTopKV2.cpp @@ -59,59 +59,24 @@ bool CircleTopKV2GraphBuilder::validate(const ValidateArgs &args) const * \- CircleTopKV2Out --- FullyConnected --- */ -void CircleTopKV2GraphBuilder::build(const circle::OperatorT &op, - GraphBuilderContext *context) const +CircleNode *CircleTopKV2GraphBuilder::build_node(const BuildNodeArgs &bna) const { - assert(context != nullptr); - - auto graph = context->graph(); - - const std::vector<int32_t> &inputs = op.inputs; - const std::vector<int32_t> &outputs = op.outputs; - const auto &tensors = context->reader()->tensors(); - const auto &opcodes = context->reader()->opcodes(); - auto tensors_ptr = context->reader()->tensors_ptr(); - assert(tensors_ptr != nullptr); - - std::vector<CircleNode *> input_nodes; - for (const int32_t input_tensor_index : inputs) - { - input_nodes.push_back(context->nodefinder()->node(input_tensor_index)); - } - - // Create CircleTopKV2 - auto node = graph->nodes()->create<CircleTopKV2>(); - node->input(input_nodes[0]); - node->k(input_nodes[1]); - - assert(outputs.size() == 2); - { - // Let's use name of output 0 as TopKV2 name - const circle::TensorT &output_tensor = *tensors[outputs[0]]; - node->name(tensor_name(output_tensor)); - node->op_version(opcodes[op.opcode_index].get()->version); - - // NOTE We don't set quantization for TopKV2 itself but to virtual outputs - } - - // Create virtual outputs of TopKV2 - for (size_t n = 0; n < outputs.size(); ++n) - { - const circle::TensorT &output_tensor = *tensors[outputs[n]]; - - auto *nodeout = graph->nodes()->create<CircleTopKV2Out>(); - copy_tensor_attributes(output_tensor, nodeout); - // mark shape_status - if (tensors_ptr->Get(outputs[n])->shape() == nullptr) - nodeout->shape_status(ShapeStatus::NOSHAPE); - else - nodeout->shape_status(ShapeStatus::VALID); - - nodeout->input(node); - nodeout->index(n); - - context->nodefinder()->enroll(outputs[n], nodeout); - } + auto node = bna.context->graph()->nodes()->create<CircleTopKV2>(); + + node->input(bna.input_nodes[0]); + node->k(bna.input_nodes[1]); + + return node; +} + +CircleNode *CircleTopKV2GraphBuilder::build_out(const BuildOutArgs &boa) const +{ + auto *nodeout = boa.node->graph()->nodes()->create<CircleTopKV2Out>(); + + nodeout->input(boa.node); + nodeout->index(boa.index); + + return nodeout; } } // namespace luci diff --git a/compiler/luci/import/src/Nodes/CircleTranspose.cpp b/compiler/luci/import/src/Nodes/CircleTranspose.cpp index cc3153085..01095239e 100644 --- a/compiler/luci/import/src/Nodes/CircleTranspose.cpp +++ b/compiler/luci/import/src/Nodes/CircleTranspose.cpp @@ -25,13 +25,7 @@ namespace luci bool CircleTransposeGraphBuilder::validate(const ValidateArgs &args) const { - if (args.op.inputs.size() != 2) - return false; - - if (args.op.outputs.size() != 1) - return false; - - return true; + return GraphBuilder::validate(args, 2); } CircleNode *CircleTransposeGraphBuilder::build_node(const circle::OperatorT &op, diff --git a/compiler/luci/import/src/Nodes/CircleTransposeConv.cpp b/compiler/luci/import/src/Nodes/CircleTransposeConv.cpp index c280faaf5..5a60e2f54 100644 --- a/compiler/luci/import/src/Nodes/CircleTransposeConv.cpp +++ b/compiler/luci/import/src/Nodes/CircleTransposeConv.cpp @@ -61,16 +61,15 @@ CircleNode *CircleTransposeConvGraphBuilder::build_node(const circle::OperatorT node->filter(inputs.at(1)); node->outBackprop(inputs.at(2)); if (inputs.size() == 3) - node->bias(graph->nodes()->create<CircleOutputExclude>()); - else - node->bias(inputs.at(3)); - - if (auto bias = dynamic_cast<luci::CircleOutputExclude *>(node->bias())) { - // CircleOutputExclude doesn't need a type, but since all nodes must have a type, a dummy type - // is inserted. + auto *bias = graph->nodes()->create<CircleOutputExclude>(); + // CircleOutputExclude doesn't need a type, but since all nodes must have a type, + // a dummy type is inserted. bias->dtype(loco::DataType::FLOAT32); + node->bias(bias); } + else + node->bias(inputs.at(3)); const auto *options = op.builtin_options.AsTransposeConvOptions(); node->padding(luci_padding(options->padding)); diff --git a/compiler/luci/import/src/Nodes/CircleUnidirectionalSequenceLSTM.cpp b/compiler/luci/import/src/Nodes/CircleUnidirectionalSequenceLSTM.cpp index c41cf4def..d9cc3f8d0 100644 --- a/compiler/luci/import/src/Nodes/CircleUnidirectionalSequenceLSTM.cpp +++ b/compiler/luci/import/src/Nodes/CircleUnidirectionalSequenceLSTM.cpp @@ -25,14 +25,11 @@ namespace luci bool CircleUnidirectionalSequenceLSTMGraphBuilder::validate(const ValidateArgs &args) const { - if (args.op.inputs.size() != 24) - return false; - - return true; + return GraphBuilder::validate(args, 24); } CircleNode *CircleUnidirectionalSequenceLSTMGraphBuilder::build_node( - const circle::OperatorT &op, const std::vector<CircleNode *> &inputs, loco::Graph *graph) const + const circle::OperatorT &op, const std::vector<CircleNode *> &inputs, loco::Graph *graph) const { auto *node = graph->nodes()->create<CircleUnidirectionalSequenceLSTM>(); node->input(inputs.at(0)); @@ -59,16 +56,6 @@ CircleNode *CircleUnidirectionalSequenceLSTMGraphBuilder::build_node( node->forget_layer_norm_coefficients(inputs.at(21)); // Optional node->cell_layer_norm_coefficients(inputs.at(22)); // Optional node->output_layer_norm_coefficients(inputs.at(23)); // Optional - const std::vector<int32_t> optionals = {1, 5, 9, 10, 11, 12, 16, 17, 20, 21, 22, 23}; - for (auto optional : optionals) - { - if (auto inp = dynamic_cast<luci::CircleOutputExclude *>(node->arg(optional))) - { - // CircleOutputExclude doesn't need a type, but since all nodes must have a type, a dummy type - // is inserted. - inp->dtype(loco::DataType::FLOAT32); - } - } const auto *options = op.builtin_options.AsUnidirectionalSequenceLSTMOptions(); node->fusedActivationFunction(luci_actfunc(options->fused_activation_function)); diff --git a/compiler/luci/import/src/Nodes/CircleUnique.cpp b/compiler/luci/import/src/Nodes/CircleUnique.cpp index 5e79a2920..f6914c24a 100644 --- a/compiler/luci/import/src/Nodes/CircleUnique.cpp +++ b/compiler/luci/import/src/Nodes/CircleUnique.cpp @@ -35,55 +35,26 @@ bool CircleUniqueGraphBuilder::validate(const ValidateArgs &args) const return true; } -void CircleUniqueGraphBuilder::build(const circle::OperatorT &op, - GraphBuilderContext *context) const +CircleNode *CircleUniqueGraphBuilder::build_node(const BuildNodeArgs &bna) const { - assert(context != nullptr); + auto node = bna.context->graph()->nodes()->create<CircleUnique>(); - auto graph = context->graph(); + node->input(bna.input_nodes[0]); - const std::vector<int32_t> &inputs = op.inputs; - const std::vector<int32_t> &outputs = op.outputs; - const auto &tensors = context->reader()->tensors(); - auto tensors_ptr = context->reader()->tensors_ptr(); - assert(tensors_ptr != nullptr); + const auto *options = bna.op.builtin_options.AsUniqueOptions(); + node->idx_out_type(luci_datatype(options->idx_out_type)); - std::vector<CircleNode *> input_nodes; - for (const int32_t input_tensor_index : inputs) - { - input_nodes.push_back(context->nodefinder()->node(input_tensor_index)); - } - - // Create CircleUnique - auto node = graph->nodes()->create<CircleUnique>(); - node->input(input_nodes[0]); - - const auto *options = op.builtin_options.AsUniqueOptions(); - node->output_type(luci_datatype(options->idx_out_type)); - - assert(int32_t(outputs.size()) == 2); - // Let's use name of output 0 as Unique name - const circle::TensorT &output_tensor = *tensors[outputs[0]]; - node->name(tensor_name(output_tensor)); - - // Create virtual outputs of Unique - for (int32_t n = 0; n < 2; ++n) - { - const circle::TensorT &output_tensor = *tensors[outputs[n]]; + return node; +} - auto *nodeout = graph->nodes()->create<CircleUniqueOut>(); - copy_tensor_attributes(output_tensor, nodeout); - // mark shape_status - if (tensors_ptr->Get(outputs[n])->shape() == nullptr) - nodeout->shape_status(ShapeStatus::NOSHAPE); - else - nodeout->shape_status(ShapeStatus::VALID); +CircleNode *CircleUniqueGraphBuilder::build_out(const BuildOutArgs &boa) const +{ + auto *nodeout = boa.node->graph()->nodes()->create<CircleUniqueOut>(); - nodeout->input(node); - nodeout->index(n); + nodeout->input(boa.node); + nodeout->index(boa.index); - context->nodefinder()->enroll(outputs[n], nodeout); - } + return nodeout; } } // namespace luci diff --git a/compiler/luci/import/src/Nodes/CircleUnpack.cpp b/compiler/luci/import/src/Nodes/CircleUnpack.cpp index 9e7f3d3e1..9bfc76b57 100644 --- a/compiler/luci/import/src/Nodes/CircleUnpack.cpp +++ b/compiler/luci/import/src/Nodes/CircleUnpack.cpp @@ -88,64 +88,27 @@ bool CircleUnpackGraphBuilder::validate(const ValidateArgs &args) const * \- CircleUnpackOut --- FullyConnected --- */ -void CircleUnpackGraphBuilder::build(const circle::OperatorT &op, - GraphBuilderContext *context) const +CircleNode *CircleUnpackGraphBuilder::build_node(const BuildNodeArgs &bna) const { - assert(context != nullptr); + auto node = bna.context->graph()->nodes()->create<CircleUnpack>(); - auto graph = context->graph(); + node->value(bna.input_nodes[0]); - const std::vector<int32_t> &inputs = op.inputs; - const std::vector<int32_t> &outputs = op.outputs; - const auto &tensors = context->reader()->tensors(); - const auto &opcodes = context->reader()->opcodes(); - auto tensors_ptr = context->reader()->tensors_ptr(); - assert(tensors_ptr != nullptr); - - // NOTE Unpack has only one input so running a loop is not necessary - // This is provided as a reference for other Ops as a reference - std::vector<CircleNode *> input_nodes; - for (const int32_t input_tensor_index : inputs) - { - input_nodes.push_back(context->nodefinder()->node(input_tensor_index)); - } - - // Create CircleUnpack - CircleUnpack *node = graph->nodes()->create<CircleUnpack>(); - node->value(input_nodes[0]); - - const auto *options = op.builtin_options.AsUnpackOptions(); + const auto *options = bna.op.builtin_options.AsUnpackOptions(); node->num(options->num); node->axis(options->axis); - assert(outputs.size() > 0); - { - // Let's use name of output 0 as Unpack name - const circle::TensorT &output_tensor = *tensors[outputs[0]]; - node->name(tensor_name(output_tensor)); - node->op_version(opcodes[op.opcode_index].get()->version); - - // NOTE We don't set quantization for Unpack itself but to virtual outputs - } - - // Create virtual outputs of Unpack - for (int32_t n = 0; n < options->num; ++n) - { - const circle::TensorT &output_tensor = *tensors[outputs[n]]; + return node; +} - auto *nodeout = graph->nodes()->create<CircleUnpackOut>(); - copy_tensor_attributes(output_tensor, nodeout); - // mark shape_status - if (tensors_ptr->Get(outputs[n])->shape() == nullptr) - nodeout->shape_status(ShapeStatus::NOSHAPE); - else - nodeout->shape_status(ShapeStatus::VALID); +CircleNode *CircleUnpackGraphBuilder::build_out(const BuildOutArgs &boa) const +{ + auto *nodeout = boa.node->graph()->nodes()->create<CircleUnpackOut>(); - nodeout->input(node); - nodeout->index(n); + nodeout->input(boa.node); + nodeout->index(boa.index); - context->nodefinder()->enroll(outputs[n], nodeout); - } + return nodeout; } } // namespace luci diff --git a/compiler/luci/import/src/Nodes/CircleWhere.cpp b/compiler/luci/import/src/Nodes/CircleWhere.cpp index f4c5f0c66..8e4f1a0c4 100644 --- a/compiler/luci/import/src/Nodes/CircleWhere.cpp +++ b/compiler/luci/import/src/Nodes/CircleWhere.cpp @@ -25,15 +25,11 @@ namespace luci bool CircleWhereGraphBuilder::validate(const ValidateArgs &args) const { - const auto &inputs = args.op.inputs; - const auto &outputs = args.op.outputs; - - if (inputs.size() != 1) - return false; - - if (outputs.size() != 1) + if (!GraphBuilder::validate(args, 1)) return false; + const auto &inputs = args.op.inputs; + const auto &outputs = args.op.outputs; const auto &tensors = args.reader.tensors(); const auto &tensor_condition = tensors.at(inputs.at(0)); const auto &tensor_out = tensors.at(outputs[0]); diff --git a/compiler/luci/import/src/Nodes/CircleWhile.cpp b/compiler/luci/import/src/Nodes/CircleWhile.cpp index aead25071..26147562f 100644 --- a/compiler/luci/import/src/Nodes/CircleWhile.cpp +++ b/compiler/luci/import/src/Nodes/CircleWhile.cpp @@ -58,7 +58,8 @@ bool CircleWhileGraphBuilder::validate(const ValidateArgs &args) const * \- CircleWhileOut --- Node --- */ -void CircleWhileGraphBuilder::build(const circle::OperatorT &op, GraphBuilderContext *context) const +CircleNode *CircleWhileGraphBuilder::build(const circle::OperatorT &op, + GraphBuilderContext *context) const { assert(context != nullptr); @@ -118,6 +119,8 @@ void CircleWhileGraphBuilder::build(const circle::OperatorT &op, GraphBuilderCon context->nodefinder()->enroll(outputs[n], nodeout); } + + return node; } } // namespace luci diff --git a/compiler/luci/import/src/Nodes/CircleZerosLike.cpp b/compiler/luci/import/src/Nodes/CircleZerosLike.cpp index e60424def..ddb05e8a4 100644 --- a/compiler/luci/import/src/Nodes/CircleZerosLike.cpp +++ b/compiler/luci/import/src/Nodes/CircleZerosLike.cpp @@ -25,13 +25,7 @@ namespace luci bool CircleZerosLikeGraphBuilder::validate(const ValidateArgs &args) const { - if (args.op.inputs.size() != 1) - return false; - - if (args.op.outputs.size() != 1) - return false; - - return true; + return GraphBuilder::validate(args, 1); } CircleNode *CircleZerosLikeGraphBuilder::build_node(const circle::OperatorT &, diff --git a/compiler/luci/import/src/PostImport.cpp b/compiler/luci/import/src/PostImport.cpp index f436b48e8..63b16bb95 100644 --- a/compiler/luci/import/src/PostImport.cpp +++ b/compiler/luci/import/src/PostImport.cpp @@ -130,7 +130,10 @@ private: namespace { /** - * @brief ValidateNodeProp will validate inter graph connections for each Nodes + * @brief ValidateNodeProp will validate inter graph connections for each Nodes. + * @note In here, only loco::GraphInput and loco::GraphOutput are validated, + * since this class is for checking inter graph connections. + * CircleNodes such as CircleInput and CircleOutput will be validated at later steps. */ class ValidateNodeProp final : public luci::CircleNodeMutableVisitor<void> { @@ -172,9 +175,19 @@ public: auto then_graph_output = then_graph_outputs->at(then_out->index()); auto else_graph_output = else_graph_outputs->at(else_out->index()); - if (!(*then_graph_output->shape() == *else_graph_output->shape())) + if (then_graph_output->shape()->rank() != else_graph_output->shape()->rank()) { - INTERNAL_EXN_V("CircleIf THEN and ELSE Graph Output shape mismatch ", idx); + INTERNAL_EXN_V("CircleIf THEN and ELSE Graph Output rank mismatch ", idx); + } + for (uint32_t i = 0; i < then_graph_output->shape()->rank(); ++i) + { + if (then_graph_output->shape()->dim(i).known() && + else_graph_output->shape()->dim(i).known() && + then_graph_output->shape()->dim(i).value() != + else_graph_output->shape()->dim(i).value()) + { + INTERNAL_EXN_V("CircleIf THEN and ELSE Graph Output dimension mismatch ", idx); + } } if (then_graph_output->dtype() != else_graph_output->dtype()) { @@ -231,18 +244,20 @@ public: auto cond_graph_input = cond_graph_inputs->at(cond_in->index()); auto body_graph_input = body_graph_inputs->at(body_in->index()); - if ((cond_in->rank() != body_in->rank())) + if (cond_graph_input->shape()->rank() != body_graph_input->shape()->rank()) { - INTERNAL_EXN_V("CircleWhile COND input and BODY input shape mismatch ", idx); + INTERNAL_EXN_V("CircleWhile COND input and BODY input rank mismatch ", idx); } - if (cond_in->rank() > 0 && body_in->rank() > 0) + for (uint32_t i = 0; i < cond_graph_input->shape()->rank(); ++i) { - if (!(*cond_graph_input->shape() == *body_graph_input->shape())) + if (cond_graph_input->shape()->dim(i).known() && + body_graph_input->shape()->dim(i).known() && + cond_graph_input->shape()->dim(i).value() != body_graph_input->shape()->dim(i).value()) { - INTERNAL_EXN_V("CircleWhile COND input and BODY input shape mismatch ", idx); + INTERNAL_EXN_V("CircleWhile COND input and BODY input dimension mismatch ", idx); } } - if (cond_in->dtype() != body_in->dtype()) + if (cond_graph_input->dtype() != body_graph_input->dtype()) { INTERNAL_EXN_V("CircleWhile COND input and BODY input type mismatch ", idx); } @@ -257,18 +272,20 @@ public: auto cond_graph_input = cond_graph_inputs->at(cond_in->index()); auto body_graph_output = body_graph_outputs->at(body_out->index()); - if ((cond_in->rank() != body_out->rank())) + if (cond_graph_input->shape()->rank() != body_graph_output->shape()->rank()) { - INTERNAL_EXN_V("CircleWhile COND input and BODY output shape mismatch ", idx); + INTERNAL_EXN_V("CircleWhile COND input and BODY output rank mismatch ", idx); } - if (cond_in->rank() > 0 && body_out->rank() > 0) + for (uint32_t i = 0; i < cond_graph_input->shape()->rank(); ++i) { - if (!(*cond_graph_input->shape() == *body_graph_output->shape())) + if (cond_graph_input->shape()->dim(i).known() && + body_graph_output->shape()->dim(i).known() && + cond_graph_input->shape()->dim(i).value() != body_graph_output->shape()->dim(i).value()) { - INTERNAL_EXN_V("CircleWhile COND input and BODY output shape mismatch ", idx); + INTERNAL_EXN_V("CircleWhile COND input and BODY output dimension mismatch ", idx); } } - if (cond_in->dtype() != body_out->dtype()) + if (cond_graph_input->dtype() != body_graph_output->dtype()) { INTERNAL_EXN_V("CircleWhile COND input and BODY output type mismatch ", idx); } |