author     Chunseok Lee <chunseok.lee@samsung.com>   2020-04-23 14:45:49 +0900
committer  Chunseok Lee <chunseok.lee@samsung.com>   2020-04-23 14:45:49 +0900
commit     e2ef8438a24f7c56a0744eb579a6e293ee2fbf8e (patch)
tree       44a1a7951d168dd4370e13593ed03f4bc6d920c5 /compiler/exo/src
parent     302e6564a7a76109e1178207e44e45a58631c477 (diff)
Imported Upstream version 1.4.0 (refs: upstream/1.4.0, submit/tizen/20200423.054851)
Diffstat (limited to 'compiler/exo/src')
162 files changed, 16666 insertions, 0 deletions
diff --git a/compiler/exo/src/Check.h b/compiler/exo/src/Check.h new file mode 100644 index 000000000..79dac50dd --- /dev/null +++ b/compiler/exo/src/Check.h @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __CHECK_H__ +#define __CHECK_H__ + +#include <pepper/str.h> + +#include <stdexcept> +#include <cassert> +#include <iostream> + +// TODO Add macro for Release version + +#define EXO_ASSERT(condition, msg) \ + { \ + if (!(condition)) \ + { \ + std::cerr << "[assert failed] " << (msg) << ". " << std::endl; \ + assert((condition)); \ + } \ + } + +#endif // __CHECK_H__ diff --git a/compiler/exo/src/Circle/CircleExporter.cpp b/compiler/exo/src/Circle/CircleExporter.cpp new file mode 100644 index 000000000..797749090 --- /dev/null +++ b/compiler/exo/src/Circle/CircleExporter.cpp @@ -0,0 +1,49 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "exo/CircleExporter.h" + +#include "CircleExporterImpl.h" + +#include <stdex/Memory.h> + +#include <oops/InternalExn.h> + +#include <fstream> + +namespace exo +{ + +CircleExporter::CircleExporter(loco::Graph *graph) : _impl(stdex::make_unique<Impl>(graph)) +{ + // NOTHING TO DO +} + +CircleExporter::~CircleExporter() = default; + +void CircleExporter::dumpToFile(const char *path) const +{ + const char *ptr = _impl->getBufferPointer(); + const size_t size = _impl->getBufferSize(); + + if (!ptr) + INTERNAL_EXN("Graph was not serialized by FlatBuffer for some reason"); + + std::ofstream file(path, std::ofstream::binary); + file.write(ptr, size); +} + +} // namespace exo diff --git a/compiler/exo/src/Circle/CircleExporterImpl.cpp b/compiler/exo/src/Circle/CircleExporterImpl.cpp new file mode 100644 index 000000000..4cba33da1 --- /dev/null +++ b/compiler/exo/src/Circle/CircleExporterImpl.cpp @@ -0,0 +1,181 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleExporterImpl.h" + +#include "Convert.h" +#include "ExoOptimize.h" + +#include "CircleTensorExporter.h" +#include "CircleOperationExporter.h" +#include "CircleExporterUtils.h" + +#include "Log.h" +#include "Knob.h" + +#include <oops/InternalExn.h> + +#include <cassert> +#include <unordered_map> +#include <string> +#include <stdexcept> + +namespace +{ + +using namespace exo::circle_detail; + +void registerGraphInputTensors(loco::Graph *graph, SubGraphContext &ctx) +{ + for (uint32_t n = 0; n < graph->inputs()->size(); ++n) + { + auto node = loco::pull_node(graph, n); + assert(node != nullptr); + ctx._inputs.push_back(get_tensor_index(node)); + } +} + +void registerGraphOutputTensors(loco::Graph *graph, SubGraphContext &ctx) +{ + for (uint32_t n = 0; n < graph->outputs()->size(); ++n) + { + auto push = loco::push_node(graph, n); + assert(push != nullptr); + auto node = push->from(); + assert(node != nullptr); + ctx._outputs.push_back(get_tensor_index(node)); + } +} + +} // namespace + +namespace +{ + +using namespace circle; +using namespace flatbuffers; + +Offset<Vector<Offset<OperatorCode>>> +encodeOperatorCodes(FlatBufferBuilder &builder, std::unordered_map<OpCode, uint32_t> &opcodes, + std::unordered_map<OpCode, std::string> &custom_opcodes) +{ + std::vector<Offset<OperatorCode>> operator_codes_vec(opcodes.size()); + for (auto it : opcodes) + { + uint32_t idx = it.second; + if (it.first.opcode != BuiltinOperator_CUSTOM) + { + operator_codes_vec[idx] = CreateOperatorCode(builder, it.first.opcode); + } + else // custom op + { + auto opCode = it.first; + auto custom_code = custom_opcodes.find(opCode); + if (custom_code == custom_opcodes.end()) + INTERNAL_EXN("Cannot find code for customop even though opcode is BuiltinOperator_CUSTOM"); + + operator_codes_vec[idx] = + CreateOperatorCode(builder, it.first.opcode, builder.CreateString(custom_code->second)); + } + } + return builder.CreateVector(operator_codes_vec); +} + +} // namespace + +namespace exo +{ + +using namespace exo::circle_detail; +using namespace circle; +using namespace flatbuffers; + +CircleExporter::Impl::Impl(loco::Graph *graph) { exportGraph(graph); } + +::flatbuffers::Offset<::circle::SubGraph> +CircleExporter::Impl::exportSubgraph(SerializedModelData &gd) +{ + auto tensors = _builder.CreateVector(gd._tensors); + auto inputs = _builder.CreateVector(gd._inputs); + auto outputs = _builder.CreateVector(gd._outputs); + auto operators = _builder.CreateVector(gd._operators); + auto df = gd._data_format; + auto subgraph = CreateSubGraph(_builder, tensors, inputs, outputs, operators, df); + return subgraph; +} + +void CircleExporter::Impl::exportGraph(loco::Graph *graph) +{ + LOGGER(l); + + // IR-level conversion and optimization + { + convert_to_TFLNodes(graph); + set(Dialect::CIRCLE); + optimize(graph); + } + + _builder.Clear(); + + SerializedModelData gd; + + // This version is taken from comment in fbs + constexpr uint32_t version = 0; + + registerGraphIOName(graph, gd); + + // parse graph into SerializedModelData structure + exportOpDefinedTensors(graph, _builder, gd); + + // NOTE Invoke these register functions only after each node is annotated with its tensor_index + registerGraphInputTensors(graph, gd); + registerGraphOutputTensors(graph, gd); + + exportNodes(graph, _builder, gd); + + // encode operator codes + auto operator_codes = + encodeOperatorCodes(_builder, gd._operator_codes, 
gd._custom_operator_codes); + + // Subgraphs + Offset<SubGraph> subgraph = exportSubgraph(gd); + auto subgraphs = _builder.CreateVector(std::vector<Offset<SubGraph>>{subgraph}); + + // Description + std::string description_str = "nnpackage"; + auto description = _builder.CreateString(description_str); + + // create array of buffers + auto buffers = _builder.CreateVector(gd._buffers); + + // empty metadata + std::vector<int> metadata_buffer_vec; + auto metadata_buffer = _builder.CreateVector(metadata_buffer_vec); + + // Model + auto model_offset = CreateModel(_builder, version, operator_codes, subgraphs, description, + buffers, metadata_buffer); + FinishModelBuffer(_builder, model_offset); +} + +const char *CircleExporter::Impl::getBufferPointer() const +{ + return reinterpret_cast<const char *>(_builder.GetBufferPointer()); +} + +size_t CircleExporter::Impl::getBufferSize() const { return _builder.GetSize(); } + +} // namespace exo diff --git a/compiler/exo/src/Circle/CircleExporterImpl.h b/compiler/exo/src/Circle/CircleExporterImpl.h new file mode 100644 index 000000000..b1138fbad --- /dev/null +++ b/compiler/exo/src/Circle/CircleExporterImpl.h @@ -0,0 +1,78 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __CIRCLE_EXPORTER_IMPL_H__ +#define __CIRCLE_EXPORTER_IMPL_H__ + +#include "exo/CircleExporter.h" +#include "circle_schema_generated.h" + +#include <loco.h> + +namespace exo +{ + +namespace circle_detail +{ + +struct SerializedModelData; + +} // namespace circle_detail + +using namespace circle_detail; + +/** + * internal implementation of interface exporter class + */ +class CircleExporter::Impl +{ +public: + Impl() = delete; + ~Impl() = default; + + explicit Impl(loco::Graph *graph); + + /** + * @return pointer to buffer with serialized graph + */ + const char *getBufferPointer() const; + + /** + * @return size of buffer with serialized graph + */ + size_t getBufferSize() const; + +private: + /** + * @brief create Subgraph using data stored in SerializedModelData + * @param gd information about serializer parts of model + * @return offset in buffer corresponding to serialized subgraph + */ + flatbuffers::Offset<circle::SubGraph> exportSubgraph(SerializedModelData &gd); + + /** + * @brief root function that writes graph into internal buffer + * @param graph + */ + void exportGraph(loco::Graph *graph); + +private: + flatbuffers::FlatBufferBuilder _builder; +}; + +} // namespace exo + +#endif // __CIRCLE_EXPORTER_IMPL_H__ diff --git a/compiler/exo/src/Circle/CircleExporterUtils.cpp b/compiler/exo/src/Circle/CircleExporterUtils.cpp new file mode 100644 index 000000000..12b204ce7 --- /dev/null +++ b/compiler/exo/src/Circle/CircleExporterUtils.cpp @@ -0,0 +1,163 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleExporterUtils.h" + +#include <oops/InternalExn.h> + +namespace exo +{ + +circle::ActivationFunctionType to_circle_actfunc(locoex::FusedActFunc func) +{ + switch (func) + { + case locoex::FusedActFunc::NONE: + return circle::ActivationFunctionType_NONE; + case locoex::FusedActFunc::RELU: + return circle::ActivationFunctionType_RELU; + case locoex::FusedActFunc::RELU6: + return circle::ActivationFunctionType_RELU6; + default: + INTERNAL_EXN_V("trying to convert unsupported locoex::FusedActFunc", oops::to_uint32(func)); + } +} + +} // namespace exo + +namespace exo +{ +namespace circle_detail +{ + +uint32_t SerializedModelData::registerBuiltinOpcode(circle::BuiltinOperator builtin_code) +{ + auto it = _operator_codes.find(OpCode{builtin_code}); + if (it != _operator_codes.end()) + { + return it->second; + } + auto idx = static_cast<uint32_t>(_operator_codes.size()); + _operator_codes.emplace(OpCode{builtin_code}, idx); + return idx; +} + +uint32_t SerializedModelData::registerCustomOpcode(const std::string &custom_op) +{ + circle::BuiltinOperator custom_code = circle::BuiltinOperator_CUSTOM; + auto idx = registerBuiltinOpcode(custom_code); + _custom_operator_codes.emplace(OpCode{custom_code}, custom_op); + return idx; +} + +circle::Padding getOpPadding(const loco::Padding2D *pad, const loco::Stride<2> *stride, + const ShapeDescription &ifm, const ShapeDescription &ofm) +{ + // VALID padding + if (pad->top() == 0 && pad->bottom() == 0 && pad->left() == 0 && pad->right() == 0) + return circle::Padding_VALID; + + // SAME padding + // + // For same padding, by definition, following equation should hold: + // O = floor((I - 1) / S) + 1 + // where input size I, output size O, stride S + // + // NOTE input and output 'feature' map are shape of NHWC + bool same_padding_criterion_1 = + (static_cast<uint32_t>(ofm._dims[1]) == (ifm._dims[1] - 1) / stride->vertical() + 1) && + (static_cast<uint32_t>(ofm._dims[2]) == (ifm._dims[2] - 1) / stride->horizontal() + 1); + + // For same padding, rear padding is same or bigger than front padding by at most 1 + bool same_padding_criterion_2 = + (pad->top() <= pad->bottom()) && (pad->bottom() <= pad->top() + 1) && + (pad->left() <= pad->right()) && (pad->right() <= pad->left() + 1); + + if (same_padding_criterion_1 && same_padding_criterion_2) + return circle::Padding_SAME; + + INTERNAL_EXN("Unsupported padding criteria"); +} + +circle::Padding getOpPadding(const locoex::Padding pad) +{ + if (pad == locoex::Padding::VALID) + return circle::Padding_VALID; + if (pad == locoex::Padding::SAME) + return circle::Padding_SAME; + + INTERNAL_EXN_V("Unsupported locoex::Padding", oops::to_uint32(pad)); +} + +void registerGraphIOName(loco::Graph *graph, SerializedModelData &gd) +{ + for (uint32_t in = 0; in < graph->inputs()->size(); ++in) + { + auto pull = loco::pull_node(graph, in); + auto name = graph->inputs()->at(in)->name(); + + gd._pull_to_name[pull] = name; + } + for (uint32_t out = 0; out < graph->outputs()->size(); ++out) + { + auto push = loco::push_node(graph, out); + auto name = graph->outputs()->at(out)->name(); + + 
gd._push_to_name[push] = name; + } + + // TODO set this value properly + gd._data_format = circle::DataFormat::DataFormat_CHANNELS_LAST; +} + +#include <stdex/Memory.h> + +#include <cassert> + +namespace +{ + +class TFLTensorIndexAnnotation final : public loco::NodeAnnotation +{ +public: + TFLTensorIndexAnnotation(const TFLTensorIndex &index) : _index{index} + { + // DO NOTHING + } + +public: + const TFLTensorIndex &index(void) const { return _index; } + +private: + TFLTensorIndex _index; +}; + +} // namespace + +void set_tensor_index(loco::Node *node, const TFLTensorIndex &tensor_id) +{ + assert(node->annot<TFLTensorIndexAnnotation>() == nullptr); + node->annot(stdex::make_unique<TFLTensorIndexAnnotation>(tensor_id)); +} + +TFLTensorIndex get_tensor_index(loco::Node *node) +{ + assert(node->annot<TFLTensorIndexAnnotation>() != nullptr); + return node->annot<TFLTensorIndexAnnotation>()->index(); +} + +} // namespace circle_detail +} // namespace exo diff --git a/compiler/exo/src/Circle/CircleExporterUtils.h b/compiler/exo/src/Circle/CircleExporterUtils.h new file mode 100644 index 000000000..fdd162bae --- /dev/null +++ b/compiler/exo/src/Circle/CircleExporterUtils.h @@ -0,0 +1,120 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef __CIRCLE_EXPORTER_UTILS_H__ +#define __CIRCLE_EXPORTER_UTILS_H__ + +#include "ExporterUtils.h" + +#include "circle_schema_generated.h" + +#include "Dialect/IR/TFLNodes.h" + +#include <loco.h> + +#include <unordered_map> + +namespace exo +{ +namespace circle_detail +{ + +struct OpCode +{ + circle::BuiltinOperator opcode; + + bool operator==(const OpCode &rhs) const { return opcode == rhs.opcode; } +}; + +} // namespace circle_detail +} // namespace exo + +namespace exo +{ + +circle::ActivationFunctionType to_circle_actfunc(locoex::FusedActFunc func); + +} // namespace exo + +namespace std +{ + +template <> struct hash<exo::circle_detail::OpCode> +{ + size_t operator()(const exo::circle_detail::OpCode &x) const { return hash<int>()(x.opcode); } +}; + +} // namespace std + +namespace exo +{ +namespace circle_detail +{ + +/** + * @brief Record the information of T/F Lite SubGraph and its mapping to loco + */ +struct SubGraphContext +{ + /// @brief SubGraph input tensor id + std::vector<int32_t> _inputs; + /// @brief SubGraph output tensor id + std::vector<int32_t> _outputs; + /// @brief DataFormat for SubGraph + circle::DataFormat _data_format{circle::DataFormat::DataFormat_CHANNELS_LAST}; +}; + +// Prerequisites for circle::Model object creation +struct SerializedModelData final : public SubGraphContext +{ + SerializedModelData() = default; + SerializedModelData(const SerializedModelData &) = delete; + + std::unordered_map<OpCode, uint32_t> _operator_codes; + std::unordered_map<OpCode, std::string> _custom_operator_codes; + std::vector<flatbuffers::Offset<circle::Operator>> _operators; + std::vector<flatbuffers::Offset<circle::Tensor>> _tensors; + std::vector<flatbuffers::Offset<circle::Buffer>> _buffers; + + // Graph input and output names + std::unordered_map<loco::Pull *, std::string> _pull_to_name; + std::unordered_map<loco::Push *, std::string> _push_to_name; + + /** + * @brief If opcode is not registered in the table of opcodes, add it + * @param builtin_code + * @return idx of opcode in table of opcodes (see schema) + */ + uint32_t registerBuiltinOpcode(circle::BuiltinOperator builtin_code); + uint32_t registerCustomOpcode(const std::string &custom_op); +}; + +circle::Padding getOpPadding(const loco::Padding2D *pad, const loco::Stride<2> *stride, + const ShapeDescription &ifm, const ShapeDescription &ofm); +circle::Padding getOpPadding(const locoex::Padding pad); + +/// @brief Register graph input and output names to SerializedModelData +void registerGraphIOName(loco::Graph *graph, SerializedModelData &gd); + +using TFLTensorIndex = int32_t; + +void set_tensor_index(loco::Node *node, const TFLTensorIndex &tensor_id); +TFLTensorIndex get_tensor_index(loco::Node *node); + +} // namespace circle_detail +} // namespace exo + +#endif // __CIRCLE_EXPORTER_UTILS_H__ diff --git a/compiler/exo/src/Circle/CircleOperationExporter.cpp b/compiler/exo/src/Circle/CircleOperationExporter.cpp new file mode 100644 index 000000000..390e2ec99 --- /dev/null +++ b/compiler/exo/src/Circle/CircleOperationExporter.cpp @@ -0,0 +1,1228 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleOperationExporter.h" +#include "CircleExporterUtils.h" +#include "ShapeInference.h" + +#include "Dialect/IR/TFLNode.h" +#include "Dialect/IR/TFLNodes.h" +#include "Dialect/IR/TFLNodeVisitor.h" + +#include "Dialect/IR/CircleNode.h" +#include "Dialect/IR/CircleNodes.h" +#include "Dialect/IR/CircleNodeVisitor.h" + +#include "Check.h" + +#include <loco/IR/CanonicalNode.h> +#include <loco/IR/CanonicalNodeVisitor.h> +#include <loco/Service/ShapeInference.h> +#include <locoex/COpCall.h> + +#include <oops/InternalExn.h> + +#include <flatbuffers/flexbuffers.h> + +using namespace flatbuffers; +using namespace circle; + +namespace +{ + +using namespace exo; +using namespace exo::circle_detail; + +class OperationExporter final : public locoex::TFLNodeMutableVisitor<void>, + public locoex::CircleNodeMutableVisitor<void>, + public loco::CanonicalNodeMutableVisitor<void> +{ +public: + OperationExporter(FlatBufferBuilder &fbb, SerializedModelData &ctx) : builder{fbb}, gd{ctx} + { + // DO NOTHING + } + +public: + // FOR TFLNodes + void visit(locoex::TFLAdd *) final; + void visit(locoex::TFLAveragePool2D *) final; + void visit(locoex::TFLConcatenation *) final; + void visit(locoex::TFLConst *) final{/* skip, everything is done in exportOpDefinedTensors */}; + void visit(locoex::TFLConv2D *) final; + void visit(locoex::TFLDepthwiseConv2D *) final; + void visit(locoex::TFLDiv *) final; + void visit(locoex::TFLFullyConnected *) final; + void visit(locoex::TFLMaximum *) final; + void visit(locoex::TFLMaxPool2D *) final; + void visit(locoex::TFLMean *) final; + void visit(locoex::TFLMul *) final; + void visit(locoex::TFLRelu *) final; + void visit(locoex::TFLRelu6 *) final; + // TODO TFLReshape + void visit(locoex::TFLRsqrt *) final; + // TODO TFLSoftmax + void visit(locoex::TFLSqrt *) final; + void visit(locoex::TFLSquaredDifference *) final; + void visit(locoex::TFLSub *) final; + // TODO TFLTanh + void visit(locoex::TFLTranspose *) final; + void visit(locoex::TFLTransposeConv *) final; + + // FOR CircleNodes + void visit(locoex::CircleInstanceNorm *) final; + + // FOR canonical nodes. 
These will be removed later + void visit(loco::ReLU *) final; + void visit(loco::ReLU6 *) final; + void visit(loco::Tanh *) final; + void visit(loco::Push *) final { /* DO NOTHING */} + void visit(loco::Pull *) final { /* DO NOTHING */} + void visit(loco::FeatureEncode *) final; + void visit(loco::FeatureDecode *) final; + void visit(loco::FilterEncode *) final; + void visit(loco::DepthwiseFilterEncode *) final; + void visit(loco::ConstGen *) final { /* skip, everything is done in exportOpDefinedTensors */} + void visit(loco::MaxPool2D *) final; + void visit(loco::AvgPool2D *) final; + void visit(loco::Conv2D *) final; + void visit(loco::TransposedConv2D *) final; + void visit(loco::DepthwiseConv2D *) final; + void visit(loco::TensorConcat *) final; + void visit(loco::TensorReduce *) final; + void visit(loco::TensorSoftmax *) final; + void visit(loco::BiasEncode *) final; + void visit(loco::TensorBiasAdd *) final; + void visit(loco::FeatureBiasAdd *) final; + void visit(loco::EltwiseAdd *) final; + void visit(loco::EltwiseMax *) final; + void visit(loco::EltwiseMul *) final; + void visit(loco::EltwiseSub *) final; + void visit(loco::EltwiseDiv *) final; + void visit(loco::EltwiseSqrt *) final; + void visit(loco::FixedReshape *) final; + void visit(loco::TensorBroadcast *) final; + void visit(loco::TensorConstantPad *) final; + + void visit(locoex::COpCall *); + +private: + /** + * @brief Exports TFLMaxPool2D or TFLAveragePool2D + * + * @note TFLPool2D should be one of TFLMaxPool2D or TFLAveragePool2D + */ + template <class TFLPool2D> + void export_pool_2d(TFLPool2D *node, circle::BuiltinOperator builtin_op); + +private: + FlatBufferBuilder &builder; + SerializedModelData &gd; +}; + +void OperationExporter::visit(locoex::TFLAdd *node) +{ + uint32_t op_idx = gd.registerBuiltinOpcode(circle::BuiltinOperator_ADD); + std::vector<int32_t> inputs_vec{get_tensor_index(node->x()), get_tensor_index(node->y())}; + std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))}; + auto inputs = builder.CreateVector(inputs_vec); + auto outputs = builder.CreateVector(outputs_vec); + auto options = CreateAddOptions(builder, to_circle_actfunc(node->fusedActivationFunction())); + auto op_offset = CreateOperator(builder, op_idx, inputs, outputs, + circle::BuiltinOptions_AddOptions, options.Union()); + gd._operators.push_back(op_offset); +} + +void OperationExporter::visit(locoex::TFLAveragePool2D *node) +{ + export_pool_2d<locoex::TFLAveragePool2D>(node, circle::BuiltinOperator_AVERAGE_POOL_2D); +} + +void OperationExporter::visit(locoex::TFLConcatenation *node) +{ + uint32_t op_idx = gd.registerBuiltinOpcode(circle::BuiltinOperator_CONCATENATION); + std::vector<int32_t> inputs_vec; + std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))}; + + for (uint32_t i = 0; i < node->numValues(); ++i) + inputs_vec.push_back(get_tensor_index(node->values(i))); + + auto inputs = builder.CreateVector(inputs_vec); + auto outputs = builder.CreateVector(outputs_vec); + auto options = CreateConcatenationOptions(builder, node->axis(), + to_circle_actfunc(node->fusedActivationFunction())); + auto op_offset = CreateOperator(builder, op_idx, inputs, outputs, + circle::BuiltinOptions_ConcatenationOptions, options.Union()); + gd._operators.push_back(op_offset); +} + +void OperationExporter::visit(locoex::TFLConv2D *node) +{ + uint32_t op_idx = gd.registerBuiltinOpcode(circle::BuiltinOperator_CONV_2D); + + // Make input, output and options for operator + std::vector<int32_t> 
inputs_vec{get_tensor_index(node->input()), get_tensor_index(node->filter()), + get_tensor_index(node->bias())}; + std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))}; + auto inputs = builder.CreateVector(inputs_vec); + auto outputs = builder.CreateVector(outputs_vec); + circle::Padding padding = getOpPadding(node->padding()); + auto options = CreateConv2DOptions(builder, padding, node->stride()->w(), node->stride()->h(), + to_circle_actfunc(node->fusedActivationFunction())); + + // Make CONV_2D operator + auto op_offset = CreateOperator(builder, op_idx, inputs, outputs, + circle::BuiltinOptions_Conv2DOptions, options.Union()); + gd._operators.push_back(op_offset); +} + +void OperationExporter::visit(locoex::TFLDepthwiseConv2D *node) +{ + uint32_t op_idx = gd.registerBuiltinOpcode(circle::BuiltinOperator_DEPTHWISE_CONV_2D); + + // Make input, output and options for operator + std::vector<int32_t> inputs_vec{get_tensor_index(node->input()), get_tensor_index(node->filter()), + get_tensor_index(node->bias())}; + std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))}; + auto inputs = builder.CreateVector(inputs_vec); + auto outputs = builder.CreateVector(outputs_vec); + circle::Padding padding = getOpPadding(node->padding()); + auto options = CreateDepthwiseConv2DOptions(builder, padding, node->stride()->w(), + node->stride()->h(), node->depthMultiplier(), + to_circle_actfunc(node->fusedActivationFunction())); + + // Make DEPTHWISE_CONV_2D operator + auto op_offset = CreateOperator(builder, op_idx, inputs, outputs, + circle::BuiltinOptions_DepthwiseConv2DOptions, options.Union()); + gd._operators.push_back(op_offset); +} + +void OperationExporter::visit(locoex::TFLDiv *node) +{ + uint32_t op_idx = gd.registerBuiltinOpcode(circle::BuiltinOperator_DIV); + std::vector<int32_t> inputs_vec{get_tensor_index(node->x()), get_tensor_index(node->y())}; + std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))}; + auto inputs = builder.CreateVector(inputs_vec); + auto outputs = builder.CreateVector(outputs_vec); + auto options = CreateDivOptions(builder, to_circle_actfunc(node->fusedActivationFunction())); + auto op_offset = CreateOperator(builder, op_idx, inputs, outputs, + circle::BuiltinOptions_DivOptions, options.Union()); + gd._operators.push_back(op_offset); +} + +void OperationExporter::visit(locoex::TFLFullyConnected *node) +{ + uint32_t op_idx = gd.registerBuiltinOpcode(circle::BuiltinOperator_FULLY_CONNECTED); + + // Make input, output and options for operator + std::vector<int32_t> inputs_vec{get_tensor_index(node->input()), + get_tensor_index(node->weights()), + get_tensor_index(node->bias())}; + std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))}; + auto inputs = builder.CreateVector(inputs_vec); + auto outputs = builder.CreateVector(outputs_vec); + auto options = + CreateFullyConnectedOptions(builder, to_circle_actfunc(node->fusedActivationFunction())); + + // Make FULLY_CONNECTED operator + auto op_offset = CreateOperator(builder, op_idx, inputs, outputs, + circle::BuiltinOptions_FullyConnectedOptions, options.Union()); + gd._operators.push_back(op_offset); +} + +void OperationExporter::visit(locoex::TFLMaximum *node) +{ + uint32_t op_idx = gd.registerBuiltinOpcode(circle::BuiltinOperator_MAXIMUM); + std::vector<int32_t> inputs_vec{get_tensor_index(node->x()), get_tensor_index(node->y())}; + std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node 
*>(node))}; + auto inputs = builder.CreateVector(inputs_vec); + auto outputs = builder.CreateVector(outputs_vec); + auto options = CreateMaximumMinimumOptions(builder); + auto op_offset = CreateOperator(builder, op_idx, inputs, outputs, + circle::BuiltinOptions_MaximumMinimumOptions, options.Union()); + gd._operators.push_back(op_offset); +} + +void OperationExporter::visit(locoex::TFLMaxPool2D *node) +{ + export_pool_2d<locoex::TFLMaxPool2D>(node, circle::BuiltinOperator_MAX_POOL_2D); +} + +void OperationExporter::visit(locoex::TFLMean *node) +{ + uint32_t op_idx = gd.registerBuiltinOpcode(circle::BuiltinOperator_MEAN); + std::vector<int32_t> inputs_vec{get_tensor_index(node->input()), + get_tensor_index(node->reduction_indices())}; + std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))}; + auto inputs = builder.CreateVector(inputs_vec); + auto outputs = builder.CreateVector(outputs_vec); + auto options = CreateReducerOptions(builder, node->keep_dims()); + auto op_offset = CreateOperator(builder, op_idx, inputs, outputs, + circle::BuiltinOptions_ReducerOptions, options.Union()); + gd._operators.push_back(op_offset); +} + +void OperationExporter::visit(locoex::TFLMul *node) +{ + uint32_t op_idx = gd.registerBuiltinOpcode(circle::BuiltinOperator_MUL); + std::vector<int32_t> inputs_vec{get_tensor_index(node->x()), get_tensor_index(node->y())}; + std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))}; + auto inputs = builder.CreateVector(inputs_vec); + auto outputs = builder.CreateVector(outputs_vec); + auto options = CreateMulOptions(builder, to_circle_actfunc(node->fusedActivationFunction())); + auto op_offset = CreateOperator(builder, op_idx, inputs, outputs, + circle::BuiltinOptions_MulOptions, options.Union()); + gd._operators.push_back(op_offset); +} + +void OperationExporter::visit(locoex::TFLRelu *node) +{ + uint32_t op_idx = gd.registerBuiltinOpcode(circle::BuiltinOperator_RELU); + std::vector<int32_t> inputs_vec{get_tensor_index(node->features())}; + std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))}; + auto inputs = builder.CreateVector(inputs_vec); + auto outputs = builder.CreateVector(outputs_vec); + auto op_offset = CreateOperator(builder, op_idx, inputs, outputs); + gd._operators.push_back(op_offset); +} + +void OperationExporter::visit(locoex::TFLRelu6 *node) +{ + uint32_t op_idx = gd.registerBuiltinOpcode(circle::BuiltinOperator_RELU6); + std::vector<int32_t> inputs_vec{get_tensor_index(node->features())}; + std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))}; + auto inputs = builder.CreateVector(inputs_vec); + auto outputs = builder.CreateVector(outputs_vec); + auto op_offset = CreateOperator(builder, op_idx, inputs, outputs); + gd._operators.push_back(op_offset); +} + +// TODO TFLReshape + +void OperationExporter::visit(locoex::TFLRsqrt *node) +{ + uint32_t op_idx = gd.registerBuiltinOpcode(circle::BuiltinOperator_RSQRT); + std::vector<int32_t> inputs_vec{get_tensor_index(node->x())}; + std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))}; + auto inputs = builder.CreateVector(inputs_vec); + auto outputs = builder.CreateVector(outputs_vec); + auto op_offset = CreateOperator(builder, op_idx, inputs, outputs); + gd._operators.push_back(op_offset); +} + +// TODO TFLSoftmax + +void OperationExporter::visit(locoex::TFLSqrt *node) +{ + uint32_t op_idx = gd.registerBuiltinOpcode(circle::BuiltinOperator_SQRT); + 
std::vector<int32_t> inputs_vec{get_tensor_index(node->x())}; + std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))}; + auto inputs = builder.CreateVector(inputs_vec); + auto outputs = builder.CreateVector(outputs_vec); + auto op_offset = CreateOperator(builder, op_idx, inputs, outputs); + gd._operators.push_back(op_offset); +} + +void OperationExporter::visit(locoex::TFLSquaredDifference *node) +{ + uint32_t op_idx = gd.registerBuiltinOpcode(circle::BuiltinOperator_SQUARED_DIFFERENCE); + std::vector<int32_t> inputs_vec{get_tensor_index(node->x()), get_tensor_index(node->y())}; + std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))}; + auto inputs = builder.CreateVector(inputs_vec); + auto outputs = builder.CreateVector(outputs_vec); + auto options = CreateSquaredDifferenceOptions(builder); + auto op_offset = CreateOperator(builder, op_idx, inputs, outputs, + circle::BuiltinOptions_SquaredDifferenceOptions, options.Union()); + gd._operators.push_back(op_offset); +} + +void OperationExporter::visit(locoex::TFLSub *node) +{ + uint32_t op_idx = gd.registerBuiltinOpcode(circle::BuiltinOperator_SUB); + std::vector<int32_t> inputs_vec{get_tensor_index(node->x()), get_tensor_index(node->y())}; + std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))}; + auto inputs = builder.CreateVector(inputs_vec); + auto outputs = builder.CreateVector(outputs_vec); + auto options = CreateSubOptions(builder, to_circle_actfunc(node->fusedActivationFunction())); + auto op_offset = CreateOperator(builder, op_idx, inputs, outputs, + circle::BuiltinOptions_SubOptions, options.Union()); + gd._operators.push_back(op_offset); +} + +// TODO TFLTanh + +void OperationExporter::visit(locoex::TFLTranspose *node) +{ + uint32_t op_idx = gd.registerBuiltinOpcode(circle::BuiltinOperator_TRANSPOSE); + std::vector<int32_t> inputs_vec{get_tensor_index(node->arg(0)), get_tensor_index(node->arg(1))}; + std::vector<int32_t> outputs_vec{get_tensor_index(node)}; + + auto inputs = builder.CreateVector(inputs_vec); + auto outputs = builder.CreateVector(outputs_vec); + auto options = CreateTransposeOptions(builder); + + auto op_offset = + CreateOperator(builder, op_idx, inputs, outputs, + circle::BuiltinOptions::BuiltinOptions_TransposeOptions, options.Union()); + gd._operators.push_back(op_offset); +} + +void OperationExporter::visit(locoex::TFLTransposeConv *node) +{ + uint32_t op_idx = gd.registerBuiltinOpcode(circle::BuiltinOperator_TRANSPOSE_CONV); + + // Make input, output and options for operator + std::vector<int32_t> inputs_vec{get_tensor_index(node->inputSizes()), + get_tensor_index(node->filter()), + get_tensor_index(node->outBackprop())}; + std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))}; + auto inputs = builder.CreateVector(inputs_vec); + auto outputs = builder.CreateVector(outputs_vec); + circle::Padding padding = getOpPadding(node->padding()); + auto options = + CreateTransposeConvOptions(builder, padding, node->stride()->w(), node->stride()->h()); + + // Make TRANSPOSE_CONV operator + auto op_offset = CreateOperator(builder, op_idx, inputs, outputs, + circle::BuiltinOptions_TransposeConvOptions, options.Union()); + gd._operators.push_back(op_offset); +} + +template <class TFLPool2D> +void OperationExporter::export_pool_2d(TFLPool2D *node, circle::BuiltinOperator builtin_op) +{ + EXO_ASSERT(builtin_op == circle::BuiltinOperator_MAX_POOL_2D || + builtin_op == circle::BuiltinOperator_AVERAGE_POOL_2D, + 
"should be maxpool or avgpool"); + EXO_ASSERT(node->padding() != locoex::Padding::UNDEFINED, "Padding is not set"); + + uint32_t op_idx = gd.registerBuiltinOpcode(builtin_op); + std::vector<int32_t> inputs_vec{get_tensor_index(node->value())}; + std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))}; + auto inputs = builder.CreateVector(inputs_vec); + auto outputs = builder.CreateVector(outputs_vec); + + circle::Padding padding = getOpPadding(node->padding()); + + auto options = CreatePool2DOptions(builder, padding, node->stride()->w(), node->stride()->h(), + node->filter()->w(), node->filter()->h(), + to_circle_actfunc(node->fusedActivationFunction())); + auto op_offset = CreateOperator(builder, op_idx, inputs, outputs, + circle::BuiltinOptions_Pool2DOptions, options.Union()); + gd._operators.push_back(op_offset); +} + +void OperationExporter::visit(locoex::CircleInstanceNorm *node) +{ + uint32_t op_idx = gd.registerBuiltinOpcode(circle::BuiltinOperator_INSTANCE_NORM); + std::vector<int32_t> inputs_vec{get_tensor_index(node->input()), get_tensor_index(node->gamma()), + get_tensor_index(node->beta())}; + std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))}; + auto inputs = builder.CreateVector(inputs_vec); + auto outputs = builder.CreateVector(outputs_vec); + auto options = CreateInstanceNormOptions(builder, node->epsilon(), + to_circle_actfunc(node->fusedActivationFunction())); + auto op_offset = CreateOperator(builder, op_idx, inputs, outputs, + circle::BuiltinOptions_InstanceNormOptions, options.Union()); + gd._operators.push_back(op_offset); +} + +void OperationExporter::visit(loco::ReLU *node) +{ + uint32_t op_idx = gd.registerBuiltinOpcode(circle::BuiltinOperator_RELU); + std::vector<int32_t> inputs_vec{get_tensor_index(node->input())}; + std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))}; + auto inputs = builder.CreateVector(inputs_vec); + auto outputs = builder.CreateVector(outputs_vec); + auto op_offset = CreateOperator(builder, op_idx, inputs, outputs); + gd._operators.push_back(op_offset); +} + +void OperationExporter::visit(loco::ReLU6 *node) +{ + uint32_t op_idx = gd.registerBuiltinOpcode(circle::BuiltinOperator_RELU6); + std::vector<int32_t> inputs_vec{get_tensor_index(node->input())}; + std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))}; + auto inputs = builder.CreateVector(inputs_vec); + auto outputs = builder.CreateVector(outputs_vec); + auto op_offset = CreateOperator(builder, op_idx, inputs, outputs); + gd._operators.push_back(op_offset); +} + +void OperationExporter::visit(loco::Tanh *node) +{ + uint32_t op_idx = gd.registerBuiltinOpcode(circle::BuiltinOperator_TANH); + std::vector<int32_t> inputs_vec{get_tensor_index(node->input())}; + std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))}; + auto inputs = builder.CreateVector(inputs_vec); + auto outputs = builder.CreateVector(outputs_vec); + auto op_offset = CreateOperator(builder, op_idx, inputs, outputs); + gd._operators.push_back(op_offset); +} + +void OperationExporter::visit(loco::MaxPool2D *node) +{ + uint32_t op_idx = gd.registerBuiltinOpcode(circle::BuiltinOperator_MAX_POOL_2D); + std::vector<int32_t> inputs_vec{get_tensor_index(node->ifm())}; + std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))}; + auto inputs = builder.CreateVector(inputs_vec); + auto outputs = builder.CreateVector(outputs_vec); + circle::Padding padding = 
getOpPadding( + node->pad(), node->stride(), ShapeInference::get(node->ifm()), ShapeInference::get(node)); + auto options = CreatePool2DOptions(builder, padding, node->stride()->horizontal(), + node->stride()->vertical(), node->window()->horizontal(), + node->window()->vertical()); + auto op_offset = CreateOperator(builder, op_idx, inputs, outputs, + circle::BuiltinOptions_Pool2DOptions, options.Union()); + gd._operators.push_back(op_offset); +} + +void OperationExporter::visit(loco::AvgPool2D *node) +{ + // Circle only support Valid convention of average pooling + assert(node->convention() == loco::AvgPool2D::Convention::Valid); + + uint32_t op_idx = gd.registerBuiltinOpcode(circle::BuiltinOperator_AVERAGE_POOL_2D); + std::vector<int32_t> inputs_vec{get_tensor_index(node->ifm())}; + std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))}; + auto inputs = builder.CreateVector(inputs_vec); + auto outputs = builder.CreateVector(outputs_vec); + circle::Padding padding = getOpPadding( + node->pad(), node->stride(), ShapeInference::get(node->ifm()), ShapeInference::get(node)); + auto options = CreatePool2DOptions(builder, padding, node->stride()->horizontal(), + node->stride()->vertical(), node->window()->horizontal(), + node->window()->vertical()); + auto op_offset = CreateOperator(builder, op_idx, inputs, outputs, + circle::BuiltinOptions_Pool2DOptions, options.Union()); + gd._operators.push_back(op_offset); +} + +void OperationExporter::visit(loco::Conv2D *node) +{ + uint32_t op_idx = gd.registerBuiltinOpcode(circle::BuiltinOperator_CONV_2D); + + // Third input of CONV_2D of Circle should be bias. We will make (and register to gd) dummy zero + // bias. Bias would be rank 1, have size of output kernel count, and have all zero values, i.e. + // zero bias. 
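+ // For example (illustrative numbers, not taken from a real model): a
+ // FilterEncode kernel of shape [64, 3, 3, 16] has 64 output kernels, so the
+ // code below builds a 64-element zero-initialized float buffer (256 bytes),
+ // registers a rank-1 FLOAT32 tensor of shape [64] over it, and wires that
+ // tensor's index as the third CONV_2D input. Note that raw_bias_vec_size is
+ // computed with sizeof(int32_t); this matches the float payload only because
+ // both types are 4 bytes wide.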
+ auto *ker = dynamic_cast<loco::FilterEncode *>(node->ker()); + assert(ker); + int32_t bias_vec_size = ShapeInference::get(ker)._dims[0]; // output kernel count + + auto bias_vec_shape_offset = builder.CreateVector(std::vector<int32_t>{bias_vec_size}); + size_t raw_bias_vec_size = bias_vec_size * sizeof(int32_t); + + std::vector<float> bias_vec_data(bias_vec_size); // initialized as zero vector + + auto bias_vec_offset = + builder.CreateVector(reinterpret_cast<uint8_t *>(bias_vec_data.data()), raw_bias_vec_size); + + auto bias_buffer_offset = CreateBuffer(builder, bias_vec_offset); + + const auto bias_buffer_id = static_cast<uint32_t>(gd._buffers.size()); + + gd._buffers.push_back(bias_buffer_offset); + + auto bias_tensor_id = static_cast<int32_t>(gd._tensors.size()); + auto name_offset = builder.CreateString("t_" + std::to_string(bias_tensor_id)); + + auto bias_tensor_offset = + CreateTensor(builder, bias_vec_shape_offset, TensorType_FLOAT32, bias_buffer_id, name_offset); + gd._tensors.push_back(bias_tensor_offset); + + // Make input, output and options for operator + std::vector<int32_t> inputs_vec{get_tensor_index(node->ifm()), get_tensor_index(node->ker()), + bias_tensor_id}; + std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))}; + auto inputs = builder.CreateVector(inputs_vec); + auto outputs = builder.CreateVector(outputs_vec); + circle::Padding padding = getOpPadding( + node->pad(), node->stride(), ShapeInference::get(node->ifm()), ShapeInference::get(node)); + auto options = CreateConv2DOptions(builder, padding, node->stride()->horizontal(), + node->stride()->vertical()); + + // Make CONV_2D operator + auto op_offset = CreateOperator(builder, op_idx, inputs, outputs, + circle::BuiltinOptions_Conv2DOptions, options.Union()); + gd._operators.push_back(op_offset); +} + +void OperationExporter::visit(loco::TransposedConv2D *node) +{ + uint32_t op_idx = gd.registerBuiltinOpcode(circle::BuiltinOperator_TRANSPOSE_CONV); + + // TRANSPOSE_CONV's first input is output shape array. 
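+ // For instance (illustrative values), an inferred output feature shape of
+ // count=1, height=8, width=8, depth=32 is serialized by the code below as a
+ // constant INT32 tensor holding {1, 8, 8, 32} (NHWC order); its tensor index
+ // then becomes the first input of the TRANSPOSE_CONV operator.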
+ const int32_t outshape_vec_size = 4; + auto outshape_vec_shape_offset = builder.CreateVector(std::vector<int32_t>{outshape_vec_size}); + size_t raw_outshape_vec_size = outshape_vec_size * sizeof(int32_t); + + std::vector<int32_t> outshape_vec_data(outshape_vec_size); + { + // Copy inferred output shape of node + auto out_feature_shape = loco::shape_get(node).as<loco::FeatureShape>(); + + // Feature tensor in Circle is NHWC + outshape_vec_data.at(0) = out_feature_shape.count().value(); + outshape_vec_data.at(1) = out_feature_shape.height().value(); + outshape_vec_data.at(2) = out_feature_shape.width().value(); + outshape_vec_data.at(3) = out_feature_shape.depth().value(); + } + + auto outshape_vec_offset = builder.CreateVector( + reinterpret_cast<uint8_t *>(outshape_vec_data.data()), raw_outshape_vec_size); + + auto outshape_buffer_offset = CreateBuffer(builder, outshape_vec_offset); + + const auto outshape_buffer_id = static_cast<uint32_t>(gd._buffers.size()); + + gd._buffers.push_back(outshape_buffer_offset); + + auto outshape_tensor_id = static_cast<int32_t>(gd._tensors.size()); + auto name_offset = builder.CreateString("t_" + std::to_string(outshape_tensor_id)); + + auto outshape_tensor_offset = CreateTensor(builder, outshape_vec_shape_offset, TensorType_INT32, + outshape_buffer_id, name_offset); + gd._tensors.push_back(outshape_tensor_offset); + + // Make input, output and options for operator + std::vector<int32_t> inputs_vec{outshape_tensor_id, get_tensor_index(node->ker()), + get_tensor_index(node->ifm())}; + std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))}; + auto inputs = builder.CreateVector(inputs_vec); + auto outputs = builder.CreateVector(outputs_vec); + // NOTE input and output is inversed to use this function + circle::Padding padding = getOpPadding(node->pad(), node->stride(), ShapeInference::get(node), + ShapeInference::get(node->ifm())); + auto options = CreateTransposeConvOptions(builder, padding, node->stride()->horizontal(), + node->stride()->vertical()); + + // Make TRANSPOSE_CONV operator + auto op_offset = CreateOperator(builder, op_idx, inputs, outputs, + circle::BuiltinOptions_TransposeConvOptions, options.Union()); + gd._operators.push_back(op_offset); +} + +void OperationExporter::visit(loco::DepthwiseConv2D *node) +{ + uint32_t op_idx = gd.registerBuiltinOpcode(circle::BuiltinOperator_DEPTHWISE_CONV_2D); + + // Third input of DEPTHWISE_CONV2D of Circle should be bias. We will make (and register to gd) + // dummy zero bias. Bias would be rank 1, have size of output kernel count, and have all zero + // values, i.e. zero bias. 
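+ // This repeats the dummy-bias pattern of CONV_2D above, except that the bias
+ // length is dims[3] of the encoded filter, i.e. C*M. For instance
+ // (illustrative values), C=16 input channels with depth multiplier M=2 give a
+ // zero bias of shape [32], and the multiplier is later recovered below as
+ // bias_vec_size / ifm_channel_size = 32 / 16 = 2.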
+ auto *ker = dynamic_cast<loco::DepthwiseFilterEncode *>(node->ker()); + assert(ker); + + int32_t bias_vec_size = ShapeInference::get(ker)._dims[3]; // output_size(C*M) + auto bias_vec_shape_offset = builder.CreateVector(std::vector<int32_t>{bias_vec_size}); + + size_t raw_bias_vec_size = bias_vec_size * sizeof(int32_t); + std::vector<float> bias_vec_data(bias_vec_size); + auto bias_vec_offset = + builder.CreateVector(reinterpret_cast<uint8_t *>(bias_vec_data.data()), raw_bias_vec_size); + + auto bias_buffer_offset = CreateBuffer(builder, bias_vec_offset); + + const auto bias_buffer_id = static_cast<uint32_t>(gd._buffers.size()); + + gd._buffers.push_back(bias_buffer_offset); + + auto bias_tensor_id = static_cast<int32_t>(gd._tensors.size()); + auto name_offset = builder.CreateString("t_" + std::to_string(bias_tensor_id)); + + auto bias_tensor_offset = + CreateTensor(builder, bias_vec_shape_offset, TensorType_FLOAT32, bias_buffer_id, name_offset); + gd._tensors.push_back(bias_tensor_offset); + + std::vector<int32_t> inputs_vec{get_tensor_index(node->ifm()), get_tensor_index(node->ker()), + bias_tensor_id}; + std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))}; + auto inputs = builder.CreateVector(inputs_vec); + auto outputs = builder.CreateVector(outputs_vec); + circle::Padding padding = getOpPadding( + node->pad(), node->stride(), ShapeInference::get(node->ifm()), ShapeInference::get(node)); + + int32_t ifm_channel_size = ShapeInference::get(node->ifm())._dims[3]; + // multiplier = bias_vec_size(output_size)/ifm_channel_size + auto options = + CreateDepthwiseConv2DOptions(builder, padding, node->stride()->horizontal(), + node->stride()->vertical(), bias_vec_size / ifm_channel_size); + + auto op_offset = CreateOperator(builder, op_idx, inputs, outputs, + circle::BuiltinOptions_DepthwiseConv2DOptions, options.Union()); + gd._operators.push_back(op_offset); +} + +void OperationExporter::visit(loco::TensorReduce *node) +{ + uint32_t op_idx; + + switch (node->func()) + { + case loco::ReduceFunc::Mean: + op_idx = gd.registerBuiltinOpcode(circle::BuiltinOperator_MEAN); + break; + + // TODO Support more reduce type operation + default: + INTERNAL_EXN_V("Unsupported reduce type", oops::to_uint32(node->func())); + } + + // Create a vector for axes data + std::vector<int32_t> axes_vec; + auto rank = ShapeInference::get(node->input())._dims.size(); + for (uint32_t i = 0; i < rank; ++i) + if (node->axes()->defined(i)) + axes_vec.push_back(i); + + int32_t axes_vec_size = axes_vec.size(); + auto axes_vec_shape_offset = builder.CreateVector(std::vector<int32_t>{axes_vec_size}); + + size_t raw_axes_vec_size = axes_vec_size * sizeof(int32_t); + auto axes_vec_offset = + builder.CreateVector(reinterpret_cast<uint8_t *>(axes_vec.data()), raw_axes_vec_size); + + auto axes_buffer_offset = CreateBuffer(builder, axes_vec_offset); + + const auto axes_buffer_id = static_cast<uint32_t>(gd._buffers.size()); + + gd._buffers.push_back(axes_buffer_offset); + + auto axes_tensor_id = static_cast<int32_t>(gd._tensors.size()); + auto name_offset = builder.CreateString("t_" + std::to_string(axes_tensor_id)); + + auto axes_tensor_offset = + CreateTensor(builder, axes_vec_shape_offset, TensorType_INT32, axes_buffer_id, name_offset); + gd._tensors.push_back(axes_tensor_offset); + + std::vector<int32_t> inputs_vec{get_tensor_index(node->input()), axes_tensor_id}; + std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))}; + auto inputs = 
builder.CreateVector(inputs_vec); + auto outputs = builder.CreateVector(outputs_vec); + auto options = CreateReducerOptions(builder, true); // true is for keep_dims option + auto op_offset = CreateOperator(builder, op_idx, inputs, outputs, + circle::BuiltinOptions_ReducerOptions, options.Union()); + gd._operators.push_back(op_offset); +} + +void OperationExporter::visit(loco::TensorSoftmax *node) +{ + // TODO Support when the input rank of TensorSoftmax is not 2 + assert(ShapeInference::get(node->input())._dims.size() == 2); + + // NOTE Circle only accepts axis when the value is last dimension + assert(node->axis() == ShapeInference::get(node->input())._dims.size() - 1); + + uint32_t op_idx = gd.registerBuiltinOpcode(circle::BuiltinOperator_SOFTMAX); + std::vector<int32_t> inputs_vec{get_tensor_index(node->input())}; + std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))}; + auto inputs = builder.CreateVector(inputs_vec); + auto outputs = builder.CreateVector(outputs_vec); + auto options = CreateSoftmaxOptions(builder, 1.0f); + auto op_offset = CreateOperator(builder, op_idx, inputs, outputs, + circle::BuiltinOptions_SoftmaxOptions, options.Union()); + gd._operators.push_back(op_offset); +} + +/// @brief Export given node into identity, i.e. CONCATENATION with one input +template <typename NodeT> +void exportIdentity(NodeT *node, FlatBufferBuilder &builder, SerializedModelData &gd) +{ + uint32_t op_idx = gd.registerBuiltinOpcode(circle::BuiltinOperator_CONCATENATION); + std::vector<int32_t> inputs_vec{get_tensor_index(node->arg(0))}; + std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))}; + auto inputs = builder.CreateVector(inputs_vec); + auto outputs = builder.CreateVector(outputs_vec); + auto options = CreateConcatenationOptions(builder); // use dummy 0 axis and NONE activation + auto op_offset = CreateOperator(builder, op_idx, inputs, outputs, + circle::BuiltinOptions_ConcatenationOptions, options.Union()); + + gd._operators.push_back(op_offset); +} + +/// @brief Export loco nodes as TRANSPOSE +void exportAsTranspose(loco::Node *node, FlatBufferBuilder &builder, + std::vector<int32_t> &perm_vec_data, SerializedModelData &gd) +{ + uint32_t op_idx = gd.registerBuiltinOpcode(circle::BuiltinOperator_TRANSPOSE); + + auto options = CreateTransposeOptions(builder); + + // Create constant tensor with perm vector + constexpr int perm_vec_size = 4; + assert(perm_vec_data.size() == perm_vec_size); + auto perm_vec_shape_offset = builder.CreateVector(std::vector<int32_t>{perm_vec_size}); + constexpr size_t raw_perm_vec_size = perm_vec_size * sizeof(int32_t); + + auto perm_vec_offset = + builder.CreateVector(reinterpret_cast<uint8_t *>(perm_vec_data.data()), raw_perm_vec_size); + + auto perm_buffer_offset = CreateBuffer(builder, perm_vec_offset); + + const auto perm_buffer_id = static_cast<uint32_t>(gd._buffers.size()); + + gd._buffers.push_back(perm_buffer_offset); + + auto perm_tensor_id = static_cast<int32_t>(gd._tensors.size()); + auto name_offset = builder.CreateString("t_" + std::to_string(perm_tensor_id)); + + auto perm_tensor_offset = + CreateTensor(builder, perm_vec_shape_offset, TensorType_INT32, perm_buffer_id, name_offset); + gd._tensors.push_back(perm_tensor_offset); + + // Create permutation node + + std::vector<int32_t> inputs_vec{get_tensor_index(node->arg(0)), perm_tensor_id}; + std::vector<int32_t> outputs_vec{get_tensor_index(node)}; + + auto inputs = builder.CreateVector(inputs_vec); + auto outputs = 
builder.CreateVector(outputs_vec); + + constexpr auto options_type = circle::BuiltinOptions::BuiltinOptions_TransposeOptions; + + auto transpose_offset = + CreateOperator(builder, op_idx, inputs, outputs, options_type, options.Union()); + gd._operators.push_back(transpose_offset); +} + +void OperationExporter::visit(loco::FeatureEncode *node) +{ + auto encoder = dynamic_cast<loco::PermutingEncoder<loco::Domain::Feature> *>(node->encoder()); + auto perm = encoder->perm(); + + if (isNHWC(perm)) + { + // Note that Circle represents feature as NHWC + exportIdentity(node, builder, gd); + } + else + { + std::vector<int32_t> perm_vec_data(4); + perm_vec_data[0] = perm->axis(loco::FeatureAxis::Count); + perm_vec_data[1] = perm->axis(loco::FeatureAxis::Height); + perm_vec_data[2] = perm->axis(loco::FeatureAxis::Width); + perm_vec_data[3] = perm->axis(loco::FeatureAxis::Depth); + + exportAsTranspose(node, builder, perm_vec_data, gd); + } +} + +void OperationExporter::visit(loco::FeatureDecode *node) +{ + auto decoder = dynamic_cast<loco::PermutingDecoder<loco::Domain::Feature> *>(node->decoder()); + auto perm = decoder->perm(); + + if (isNHWC(perm)) + { + // Note that Circle represents feature as NHWC + exportIdentity(node, builder, gd); + } + else + { + std::vector<int32_t> perm_vec_data(4); + perm_vec_data[perm->axis(loco::FeatureAxis::Count)] = 0; + perm_vec_data[perm->axis(loco::FeatureAxis::Height)] = 1; + perm_vec_data[perm->axis(loco::FeatureAxis::Width)] = 2; + perm_vec_data[perm->axis(loco::FeatureAxis::Depth)] = 3; + + exportAsTranspose(node, builder, perm_vec_data, gd); + } +} + +void OperationExporter::visit(loco::FilterEncode *node) +{ + auto encoder = dynamic_cast<loco::PermutingEncoder<loco::Domain::Filter> *>(node->encoder()); + auto perm = encoder->perm(); + + if (isNHWC(perm)) + { + // Note that Circle represents filter as NHWC + exportIdentity(node, builder, gd); + } + else + { + std::vector<int32_t> perm_vec_data(4); + // NOTE In Circle, all tensors are NHWC, so 0 = N, 1 = H, 2 = W, 3 = C + perm_vec_data[0] = perm->axis(loco::FilterAxis::Count); + perm_vec_data[1] = perm->axis(loco::FilterAxis::Height); + perm_vec_data[2] = perm->axis(loco::FilterAxis::Width); + perm_vec_data[3] = perm->axis(loco::FilterAxis::Depth); + + exportAsTranspose(node, builder, perm_vec_data, gd); + } +} + +void exportAsReshape(loco::Node *node, FlatBufferBuilder &builder, + std::vector<int32_t> &new_shape_vec, SerializedModelData &gd) +{ + // NOTE Circle currently follows TFLite for this. + // NOTE TFLite has two ways to get the new shape parameter, + // one is by attribute 'new_shape' and the other is by input 'shape'. + // Therefore the TFLite interpreter calculates the Reshape operation correctly + // if either of them is valid. + // However, since NN runtimes usually get the new shape parameter by input 'shape', + // passing the new shape only by attribute can cause some problems. + // Of course, the opposite situation can occur in the future. + // To prevent those problems, we pass the new shape parameter not only by attribute + // but also by input. 
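+ // Concretely (illustrative values): reshaping to {1, 3, 3, 32} emits both a
+ // constant INT32 tensor holding {1, 3, 3, 32}, wired as the second input of
+ // RESHAPE, and the same vector serialized into ReshapeOptions.new_shape, so
+ // a runtime reading either source sees a consistent shape.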
+ + auto input_shape_shape_vec_offset = + builder.CreateVector(std::vector<int32_t>{(int32_t)new_shape_vec.size()}); + + size_t input_shape_vec_size = new_shape_vec.size() * sizeof(int32_t); + auto input_shape_input_vec_offset = + builder.CreateVector(reinterpret_cast<uint8_t *>(new_shape_vec.data()), input_shape_vec_size); + auto input_shape_buffer_offset = CreateBuffer(builder, input_shape_input_vec_offset); + + const auto input_shape_buffer_id = static_cast<uint32_t>(gd._buffers.size()); + gd._buffers.push_back(input_shape_buffer_offset); + + auto input_shape_tensor_id = static_cast<int32_t>(gd._tensors.size()); + auto name_offset = builder.CreateString("t_" + std::to_string(input_shape_tensor_id)); + auto input_shape_tensor_offset = CreateTensor( + builder, input_shape_shape_vec_offset, TensorType_INT32, input_shape_buffer_id, name_offset); + gd._tensors.push_back(input_shape_tensor_offset); + + uint32_t op_idx = gd.registerBuiltinOpcode(circle::BuiltinOperator_RESHAPE); + + std::vector<int32_t> inputs_vec{get_tensor_index(node->arg(0)), input_shape_tensor_id}; + std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))}; + auto inputs = builder.CreateVector(inputs_vec); + auto outputs = builder.CreateVector(outputs_vec); + + auto new_shape_vec_offset = builder.CreateVector(new_shape_vec); + auto options = CreateReshapeOptions(builder, new_shape_vec_offset); + + auto op_offset = CreateOperator(builder, op_idx, inputs, outputs, + circle::BuiltinOptions_ReshapeOptions, options.Union()); + + gd._operators.push_back(op_offset); +} + +void OperationExporter::visit(loco::DepthwiseFilterEncode *node) +{ + auto ker = node->input(); // [H, W, C, M] + + // Circle represents filter as [1, H, W, C*M] where M is multiplier. + std::vector<int32_t> new_shape_vec(4); + new_shape_vec[0] = 1; + new_shape_vec[1] = ShapeInference::get(ker)._dims[0]; + new_shape_vec[2] = ShapeInference::get(ker)._dims[1]; + new_shape_vec[3] = ShapeInference::get(ker)._dims[2] * ShapeInference::get(ker)._dims[3]; + + exportAsReshape(node, builder, new_shape_vec, gd); +} + +void OperationExporter::visit(loco::BiasAdd<loco::Domain::Tensor> *node) +{ + uint32_t op_idx = gd.registerBuiltinOpcode(circle::BuiltinOperator_ADD); + std::vector<int32_t> inputs_vec{get_tensor_index(node->value()), get_tensor_index(node->bias())}; + std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))}; + auto inputs = builder.CreateVector(inputs_vec); + auto outputs = builder.CreateVector(outputs_vec); + auto options = CreateAddOptions(builder); // dummy option + auto op_offset = CreateOperator(builder, op_idx, inputs, outputs, + circle::BuiltinOptions_AddOptions, options.Union()); + gd._operators.push_back(op_offset); +} + +void OperationExporter::visit(loco::FeatureBiasAdd *node) +{ + uint32_t op_idx = gd.registerBuiltinOpcode(circle::BuiltinOperator_ADD); + std::vector<int32_t> inputs_vec{get_tensor_index(node->value()), get_tensor_index(node->bias())}; + std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))}; + auto inputs = builder.CreateVector(inputs_vec); + auto outputs = builder.CreateVector(outputs_vec); + auto options = CreateAddOptions(builder); // dummy option + auto op_offset = CreateOperator(builder, op_idx, inputs, outputs, + circle::BuiltinOptions_AddOptions, options.Union()); + gd._operators.push_back(op_offset); +} + +/// @brief Export CONCATENATION of **TWO** tensors only +void OperationExporter::visit(loco::TensorConcat *node) +{ + uint32_t op_idx = 
gd.registerBuiltinOpcode(circle::BuiltinOperator_CONCATENATION); + std::vector<int32_t> inputs_vec{get_tensor_index(node->lhs()), get_tensor_index(node->rhs())}; + std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))}; + auto inputs = builder.CreateVector(inputs_vec); + auto outputs = builder.CreateVector(outputs_vec); + auto options = CreateConcatenationOptions(builder, node->axis()); + auto op_offset = CreateOperator(builder, op_idx, inputs, outputs, + circle::BuiltinOptions_ConcatenationOptions, options.Union()); + + gd._operators.push_back(op_offset); +} + +void OperationExporter::visit(loco::BiasEncode *encode) { exportIdentity(encode, builder, gd); } + +void OperationExporter::visit(loco::EltwiseAdd *node) +{ + uint32_t op_idx = gd.registerBuiltinOpcode(circle::BuiltinOperator_ADD); + std::vector<int32_t> inputs_vec{get_tensor_index(node->lhs()), get_tensor_index(node->rhs())}; + std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))}; + auto inputs = builder.CreateVector(inputs_vec); + auto outputs = builder.CreateVector(outputs_vec); + auto options = CreateAddOptions(builder); // dummy option + auto op_offset = CreateOperator(builder, op_idx, inputs, outputs, + circle::BuiltinOptions_AddOptions, options.Union()); + gd._operators.push_back(op_offset); +} + +void OperationExporter::visit(loco::EltwiseMax *node) +{ + uint32_t op_idx = gd.registerBuiltinOpcode(circle::BuiltinOperator_MAXIMUM); + std::vector<int32_t> inputs_vec{get_tensor_index(node->lhs()), get_tensor_index(node->rhs())}; + std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))}; + auto inputs = builder.CreateVector(inputs_vec); + auto outputs = builder.CreateVector(outputs_vec); + auto options = CreateMaximumMinimumOptions(builder); // dummy option + auto op_offset = CreateOperator(builder, op_idx, inputs, outputs, + circle::BuiltinOptions_MaximumMinimumOptions, options.Union()); + gd._operators.push_back(op_offset); +} + +void OperationExporter::visit(loco::EltwiseMul *node) +{ + uint32_t op_idx = gd.registerBuiltinOpcode(circle::BuiltinOperator_MUL); + std::vector<int32_t> inputs_vec{get_tensor_index(node->lhs()), get_tensor_index(node->rhs())}; + std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))}; + auto inputs = builder.CreateVector(inputs_vec); + auto outputs = builder.CreateVector(outputs_vec); + auto options = CreateMulOptions(builder); // dummy option + auto op_offset = CreateOperator(builder, op_idx, inputs, outputs, + circle::BuiltinOptions_MulOptions, options.Union()); + gd._operators.push_back(op_offset); +} + +void OperationExporter::visit(loco::EltwiseSub *node) +{ + uint32_t op_idx = gd.registerBuiltinOpcode(circle::BuiltinOperator_SUB); + std::vector<int32_t> inputs_vec{get_tensor_index(node->lhs()), get_tensor_index(node->rhs())}; + std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))}; + auto inputs = builder.CreateVector(inputs_vec); + auto outputs = builder.CreateVector(outputs_vec); + auto options = CreateSubOptions(builder); // dummy option + auto op_offset = CreateOperator(builder, op_idx, inputs, outputs, + circle::BuiltinOptions_SubOptions, options.Union()); + gd._operators.push_back(op_offset); +} + +void OperationExporter::visit(loco::EltwiseDiv *node) +{ + uint32_t op_idx = gd.registerBuiltinOpcode(circle::BuiltinOperator_DIV); + std::vector<int32_t> inputs_vec{get_tensor_index(node->lhs()), get_tensor_index(node->rhs())}; + 
std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))}; + auto inputs = builder.CreateVector(inputs_vec); + auto outputs = builder.CreateVector(outputs_vec); + auto options = CreateDivOptions(builder); // dummy option + auto op_offset = CreateOperator(builder, op_idx, inputs, outputs, + circle::BuiltinOptions_DivOptions, options.Union()); + gd._operators.push_back(op_offset); +} + +void OperationExporter::visit(loco::EltwiseSqrt *node) +{ + uint32_t op_idx = gd.registerBuiltinOpcode(circle::BuiltinOperator_SQRT); + std::vector<int32_t> inputs_vec{get_tensor_index(node->input())}; + std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))}; + auto inputs = builder.CreateVector(inputs_vec); + auto outputs = builder.CreateVector(outputs_vec); + auto op_offset = CreateOperator(builder, op_idx, inputs, outputs); + gd._operators.push_back(op_offset); +} + +void OperationExporter::visit(loco::FixedReshape *node) +{ + std::vector<int32_t> new_shape_vec; + for (uint32_t axis = 0; axis < node->rank(); ++axis) + { + assert(node->dim(axis).known()); + new_shape_vec.push_back(node->dim(axis).value()); + } + + exportAsReshape(node, builder, new_shape_vec, gd); +} + +void OperationExporter::visit(loco::TensorBroadcast *) +{ + INTERNAL_EXN("loco graph has loco::TensorBroadcast, which should not exist in the graph"); +} + +void OperationExporter::visit(loco::TensorConstantPad *node) +{ + uint32_t op_idx = gd.registerBuiltinOpcode(circle::BuiltinOperator_PAD); + + // make padding attribute an input + auto padding = node->padding(); + // get padding vector size + int32_t padding_vec_size = padding->rank(); + // get byte size of vector + size_t padding_vec_byte_size = padding_vec_size * sizeof(int32_t) * 2; // [rank, 2] + // create vector for data + std::vector<int32_t> padding_vec_data(padding_vec_size * 2); + // set data + for (int32_t i = 0; i < padding_vec_size; i++) + { + padding_vec_data.at(i * 2) = padding->front(i); + padding_vec_data.at(i * 2 + 1) = padding->back(i); + } + // create FlatBuffer vector + auto padding_vec_ptr = builder.CreateVector(reinterpret_cast<uint8_t *>(padding_vec_data.data()), + padding_vec_byte_size); + + // create buffer + auto padding_buffer_ptr = CreateBuffer(builder, padding_vec_ptr); + // get buffer id + const auto padding_buffer_id = static_cast<uint32_t>(gd._buffers.size()); + + gd._buffers.push_back(padding_buffer_ptr); + + // create padding shape vector + auto padding_shape_vec_ptr = builder.CreateVector(std::vector<int32_t>{padding_vec_size, 2}); + // create tensor + auto padding_tensor_ptr = + CreateTensor(builder, padding_shape_vec_ptr, TensorType_INT32, padding_buffer_id); + // get tensor id + const auto padding_tensor_id = static_cast<int32_t>(gd._tensors.size()); + + gd._tensors.push_back(padding_tensor_ptr); + + std::vector<int32_t> inputs_vec{get_tensor_index(node->input()), padding_tensor_id}; + std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))}; + auto inputs = builder.CreateVector(inputs_vec); + auto outputs = builder.CreateVector(outputs_vec); + auto op_offset = CreateOperator(builder, op_idx, inputs, outputs); + gd._operators.push_back(op_offset); +} + +inline flatbuffers::Offset<flatbuffers::Vector<uint8_t>> +CreateCOpCallOptions(flatbuffers::FlatBufferBuilder &fbb, locoex::COpCall *copCall) +{ + // read attrs in FlexBuffer format and pass them to FlatBuffer builder + flexbuffers::Builder flexbuf; + { + size_t map_start = flexbuf.StartMap(); + + // Note: among attrs of 
COpCall, 'op' and 'name' won't be included in the tflite file
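+    // For example (hypothetical attribute names, not from a real model): a custom
+    // op with attributes scale = 2 (Int) and alpha = 0.5 (Float) serializes to the
+    // FlexBuffer map {scale: 2, alpha: 0.5}, which is attached below as the
+    // operator's custom_options.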
+    auto names = copCall->attr_names();
+    for (auto name : names)
+    {
+      if (auto int_val = copCall->attr<locoex::COpAttrType::Int>(name))
+        flexbuf.Int(name.c_str(), int_val->val());
+      else if (auto float_val = copCall->attr<locoex::COpAttrType::Float>(name))
+        flexbuf.Float(name.c_str(), float_val->val());
+      else
+        // TODO Support more attribute types
+        INTERNAL_EXN_V("Unsupported dtype while writing flexbuffer for customop attr", name);
+    }
+
+    flexbuf.EndMap(map_start);
+    flexbuf.Finish();
+  }
+
+  auto offset = fbb.CreateVector(flexbuf.GetBuffer());
+
+  return offset;
+}
+
+void OperationExporter::visit(locoex::COpCall *call)
+{
+  // Register this custom op name in the tflite Operator Codes table
+  uint32_t op_idx = gd.registerCustomOpcode(call->op());
+
+  std::vector<int32_t> inputs_vec;
+  {
+    inputs_vec.resize(call->arity());
+    for (uint32_t i = 0; i < call->arity(); i++)
+      inputs_vec[i] = get_tensor_index(call->arg(i));
+  }
+
+  std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(call))};
+
+  auto inputs = builder.CreateVector(inputs_vec);
+  auto outputs = builder.CreateVector(outputs_vec);
+
+  auto custom_options = CreateCOpCallOptions(builder, call);
+  auto op_offset = CreateOperator(builder, op_idx, inputs, outputs,
+                                  circle::BuiltinOptions_NONE, // builtin_options_type
+                                  0,                           // built-in option
+                                  custom_options,              // custom options
+                                  circle::CustomOptionsFormat_FLEXBUFFERS);
+
+  gd._operators.push_back(op_offset);
+}
+
+void exportNode(loco::Node *node, flatbuffers::FlatBufferBuilder &builder,
+                SerializedModelData &data)
+{
+  // TODO Use explicit tagging to prevent possible mistakes
+  auto isNoOp = [](loco::Node *node) {
+    if (node->arity() == 1)
+    {
+      assert(node->arg(0) != nullptr);
+      return get_tensor_index(node) == get_tensor_index(node->arg(0));
+    }
+    return false;
+  };
+
+  if (isNoOp(node))
+  {
+    // Skip a node that was marked as NoOp (an op with no effect) earlier
+    return;
+  }
+
+  if (auto canonical_node = dynamic_cast<loco::CanonicalNode *>(node))
+  { // TODO Consider removing this later
+    OperationExporter exporter{builder, data};
+    canonical_node->accept(&exporter);
+  }
+  else if (auto tfl_node = dynamic_cast<locoex::TFLNode *>(node))
+  {
+    OperationExporter exporter{builder, data};
+    tfl_node->accept(&exporter);
+  }
+  else if (auto circle_node = dynamic_cast<locoex::CircleNode *>(node))
+  {
+    OperationExporter exporter{builder, data};
+    circle_node->accept(&exporter);
+  }
+  else if (dynamic_cast<locoex::COpNode *>(node))
+  {
+    OperationExporter exporter{builder, data};
+    exporter.visit(dynamic_cast<locoex::COpCall *>(node));
+  }
+  else
+  {
+    INTERNAL_EXN("Node with unsupported dialect found");
+  }
+}
+
+} // namespace
+
+namespace exo
+{
+namespace circle_detail
+{
+
+void exportNodes(loco::Graph *g, FlatBufferBuilder &builder, SerializedModelData &gd)
+{
+  for (auto node : loco::postorder_traversal(loco::output_nodes(g)))
+  {
+    exportNode(node, builder, gd);
+  }
+}
+
+} // namespace circle_detail
+} // namespace exo
diff --git a/compiler/exo/src/Circle/CircleOperationExporter.h b/compiler/exo/src/Circle/CircleOperationExporter.h
new file mode 100644
index 000000000..19dadbfd1
--- /dev/null
+++ b/compiler/exo/src/Circle/CircleOperationExporter.h
@@ -0,0 +1,39 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __CIRCLE_OPERATION_EXPORTER_H__
+#define __CIRCLE_OPERATION_EXPORTER_H__
+
+#include "CircleExporterUtils.h"
+
+#include <loco/IR/Graph.h>
+
+namespace exo
+{
+namespace circle_detail
+{
+
+/**
+ * @brief create Operators corresponding to model nodes
+ * @param g graph whose nodes are exported
+ * @param builder FlatBufferBuilder used to serialize the operators
+ * @param gd information about serialized parts of the model
+ */
+void exportNodes(loco::Graph *g, flatbuffers::FlatBufferBuilder &builder, SerializedModelData &gd);
+
+} // namespace circle_detail
+} // namespace exo
+
+#endif // __CIRCLE_OPERATION_EXPORTER_H__
diff --git a/compiler/exo/src/Circle/CircleTensorExporter.cpp b/compiler/exo/src/Circle/CircleTensorExporter.cpp
new file mode 100644
index 000000000..efceae55d
--- /dev/null
+++ b/compiler/exo/src/Circle/CircleTensorExporter.cpp
@@ -0,0 +1,261 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#include "CircleTensorExporter.h" +#include "CircleTypeInference.h" +#include "ShapeInference.h" + +// TODO Fix include style +#include "loco/IR/Algorithm.h" +#include "loco/IR/CanonicalNode.h" +#include "loco/IR/CanonicalNodeVisitor.h" +#include "loco/IR/DataTypeTraits.h" + +#include "Dialect/IR/TFLNodes.h" + +#include <oops/InternalExn.h> + +using namespace circle; +using namespace flatbuffers; + +namespace +{ + +using namespace exo; +using namespace exo::circle_detail; + +class TFLTensorInfo +{ +public: + TFLTensorInfo() = default; + +public: + void name(const std::string &name) { _name = name; } + const std::string &name(void) const { return _name; } + +public: + const circle::TensorType &dtype(void) const { return _dtype; } + void dtype(const circle::TensorType &dtype) { _dtype = dtype; } + + const ShapeDescription &shape(void) const { return _shape; } + void shape(const ShapeDescription &shape) { _shape = shape; } + +public: + locoex::TFLConst *tfl_content(void) const { return _tfl_content; } + void tfl_content(locoex::TFLConst *c) { _tfl_content = c; } + +private: + std::string _name; + + circle::TensorType _dtype; + ShapeDescription _shape; + + // TODO Find a better design + loco::ConstGen *_content = nullptr; // TODO deprecate + locoex::TFLConst *_tfl_content = nullptr; +}; + +using TFLTensorContext = std::vector<TFLTensorInfo>; + +struct NoOpDetector final : public loco::CanonicalNodeMutableVisitor<bool> +{ + bool visit(loco::BiasEncode *) final + { + // BiasEncode is always noop + return true; + } + + bool visit(loco::FilterEncode *node) final + { + auto encoder = dynamic_cast<loco::PermutingEncoder<loco::Domain::Filter> *>(node->encoder()); + if (encoder != nullptr) + { + auto perm = encoder->perm(); + return isNHWC(perm); + } + return false; + } + + bool visit(loco::FeatureEncode *node) final + { + auto encoder = dynamic_cast<loco::PermutingEncoder<loco::Domain::Feature> *>(node->encoder()); + if (encoder != nullptr) + { + auto perm = encoder->perm(); + return isNHWC(perm); + } + return false; + } + + bool visit(loco::FeatureDecode *node) final + { + auto decoder = dynamic_cast<loco::PermutingDecoder<loco::Domain::Feature> *>(node->decoder()); + if (decoder != nullptr) + { + auto perm = decoder->perm(); + return isNHWC(perm); + } + return false; + } + + // Return false by default + bool visit(loco::Node *) final { return false; } +}; + +bool isNoOp(loco::Node *node) +{ + if (auto canonical_node = dynamic_cast<loco::CanonicalNode *>(node)) + { + NoOpDetector d; + return canonical_node->accept(&d); + } + return false; +} + +void allocateCircleTensor(loco::Node *node, TFLTensorContext &ctx) +{ + if (isNoOp(node)) + { + assert(node->arity() == 1 && node->arg(0) != nullptr); + set_tensor_index(node, get_tensor_index(node->arg(0))); + return; + } + + auto tensor_index = static_cast<TFLTensorIndex>(ctx.size()); + // TODO Use Graph-level metadata for Input & Output + auto tensor_name = "t_" + std::to_string(tensor_index); + + TFLTensorInfo tensor_info; + + tensor_info.name(tensor_name); + tensor_info.dtype(TypeInference::get(node)); + tensor_info.shape(ShapeInference::get(node)); + + tensor_info.tfl_content(dynamic_cast<locoex::TFLConst *>(node)); + + set_tensor_index(node, tensor_index); + + ctx.emplace_back(tensor_info); +} + +} // namespace + +namespace +{ + +flatbuffers::Offset<Vector<int32_t>> encodeShape(FlatBufferBuilder &builder, + const ShapeDescription &shape) +{ + assert(shape._rank_known && "unknown number of dimensions is not supported"); + return 
builder.CreateVector(shape._dims); +} + +flatbuffers::Offset<circle::Buffer> encodeOpBuffer(FlatBufferBuilder &builder) +{ + return CreateBuffer(builder); +} + +template <typename NodeT> +flatbuffers::Offset<circle::Buffer> encodeOpBuffer(FlatBufferBuilder &builder, NodeT *) +{ + return CreateBuffer(builder); +} + +template <loco::DataType DT> +flatbuffers::Offset<circle::Buffer> encodeOpBufferByDType(FlatBufferBuilder &builder, + locoex::TFLConst *c) +{ + using NativeType = typename loco::DataTypeImpl<DT>::Type; + + std::vector<NativeType> raw_data; + const uint32_t size = c->size<DT>(); + raw_data.reserve(size); + for (uint32_t i = 0; i < size; ++i) + { + raw_data.push_back(c->at<DT>(i)); + } + const size_t raw_size = size * sizeof(NativeType); + auto array_offset = builder.CreateVector(reinterpret_cast<uint8_t *>(raw_data.data()), raw_size); + return CreateBuffer(builder, array_offset); +} + +template <> +flatbuffers::Offset<circle::Buffer> encodeOpBuffer(FlatBufferBuilder &builder, locoex::TFLConst *c) +{ + if (c->dtype() == loco::DataType::FLOAT32) + { + return encodeOpBufferByDType<loco::DataType::FLOAT32>(builder, c); + } + else if (c->dtype() == loco::DataType::S32) + { + return encodeOpBufferByDType<loco::DataType::S32>(builder, c); + } + + INTERNAL_EXN_V("Unsupported datatype", oops::to_uint32(c->dtype())); +} + +} // namespace + +namespace exo +{ +namespace circle_detail +{ + +void exportOpDefinedTensor(const TFLTensorInfo &info, FlatBufferBuilder &builder, + SerializedModelData &gd) +{ + // Create and register output tensor shape + auto shape_offset = encodeShape(builder, info.shape()); + + // encode and register output tensor buffer + auto buffer = info.tfl_content() == nullptr ? encodeOpBuffer(builder) + : encodeOpBuffer(builder, info.tfl_content()); + + auto buffer_id = static_cast<uint32_t>(gd._buffers.size()); + gd._buffers.push_back(buffer); + + auto name_offset = builder.CreateString(info.name()); + auto tensor_offset = CreateTensor(builder, shape_offset, info.dtype(), buffer_id, name_offset, + /*quantization*/ 0, /*is_variable*/ false); + gd._tensors.push_back(tensor_offset); +} + +void exportOpDefinedTensors(loco::Graph *g, FlatBufferBuilder &builder, SerializedModelData &gd) +{ + TFLTensorContext tensor_ctx; + + for (auto node : loco::postorder_traversal(loco::output_nodes(g))) + { + allocateCircleTensor(node, tensor_ctx); + } + + // add one empty buffer + // note: this follows TFLite + // note: there's a comment in tflite fbs file + // - Note the 0th entry of this array must be an empty buffer (sentinel). + // - This is a convention so that tensors without a buffer can provide 0 as + // - their buffer. + auto buffer = encodeOpBuffer(builder); + gd._buffers.push_back(buffer); + + for (const auto &tensor_info : tensor_ctx) + { + exportOpDefinedTensor(tensor_info, builder, gd); + } +} + +} // namespace circle_detail +} // namespace exo diff --git a/compiler/exo/src/Circle/CircleTensorExporter.h b/compiler/exo/src/Circle/CircleTensorExporter.h new file mode 100644 index 000000000..39d8e1b86 --- /dev/null +++ b/compiler/exo/src/Circle/CircleTensorExporter.h @@ -0,0 +1,42 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __CIRCLE_TENSOR_EXPORTER_H__
+#define __CIRCLE_TENSOR_EXPORTER_H__
+
+#include "CircleExporterUtils.h"
+
+#include <loco/IR/Graph.h>
+
+#include <flatbuffers/flatbuffers.h>
+
+namespace exo
+{
+namespace circle_detail
+{
+
+/**
+ * @brief create Tensors corresponding to results of all nodes in graph
+ * @param g computational graph
+ * @param builder FlatBufferBuilder used to serialize the tensors
+ * @param gd information about serialized parts of the model
+ */
+void exportOpDefinedTensors(loco::Graph *g, flatbuffers::FlatBufferBuilder &builder,
+                            SerializedModelData &gd);
+
+} // namespace circle_detail
+} // namespace exo
+
+#endif // __CIRCLE_TENSOR_EXPORTER_H__
diff --git a/compiler/exo/src/Circle/CircleTypeInference.cpp b/compiler/exo/src/Circle/CircleTypeInference.cpp
new file mode 100644
index 000000000..a1e92b884
--- /dev/null
+++ b/compiler/exo/src/Circle/CircleTypeInference.cpp
@@ -0,0 +1,85 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#include "CircleTypeInference.h" + +#include "circle_schema_generated.h" + +#include "Dialect/Service/TFLTypeInferenceRule.h" +#include "Dialect/IR/TFLDialect.h" + +#include <loco/IR/CanonicalNode.h> +#include <loco/IR/CanonicalNodeVisitor.h> +#include <loco/IR/CanonicalDialect.h> +#include <loco/Service/TypeInference.h> + +#include <locoex/COpDialect.h> +#include <locoex/Service/COpTypeInference.h> + +#include <oops/InternalExn.h> + +#include <stdex/Memory.h> + +#include <stdexcept> +#include <type_traits> + +namespace +{ + +circle::TensorType translateLocoTypeToCircle(loco::DataType dtype) +{ + switch (dtype) + { + case loco::DataType::U8: + return circle::TensorType_UINT8; + // case loco::DataType::U16: unsupported + // case loco::DataType::U32: unsupported + // case loco::DataType::U64: unsupported + case loco::DataType::S8: + return circle::TensorType_INT8; + case loco::DataType::S16: + return circle::TensorType_INT16; + case loco::DataType::S32: + return circle::TensorType_INT32; + case loco::DataType::S64: + return circle::TensorType_INT64; + case loco::DataType::FLOAT16: + return circle::TensorType_FLOAT16; + case loco::DataType::FLOAT32: + return circle::TensorType_FLOAT32; + // case loco::DataType::FLOAT64: unsupported + default: + break; + } + + INTERNAL_EXN_V("Invalid loco dtype", oops::to_uint32(dtype)); +} + +} // namespace + +namespace exo +{ +namespace circle_detail +{ + +circle::TensorType TypeInference::get(loco::Node *node) +{ + assert(loco::dtype_known(node)); + return translateLocoTypeToCircle(loco::dtype_get(node)); +} + +} // namespace circle_detail +} // namespace exo diff --git a/compiler/exo/src/Circle/CircleTypeInference.h b/compiler/exo/src/Circle/CircleTypeInference.h new file mode 100644 index 000000000..9c1730233 --- /dev/null +++ b/compiler/exo/src/Circle/CircleTypeInference.h @@ -0,0 +1,45 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __CIRCLE_TYPE_INFERENCE_H__ +#define __CIRCLE_TYPE_INFERENCE_H__ + +#include "CircleExporterUtils.h" + +#include <loco/IR/Nodes.h> + +namespace exo +{ +namespace circle_detail +{ + +/** + * @brief Get the type of each node as NodeAnnotation + * + * HOW TO USE + * + * TypeInference::get(g->nodes()->at(0)); + * TypeInference::get(g->nodes()->at(...)); + */ +struct TypeInference +{ + static circle::TensorType get(loco::Node *node); +}; + +} // namespace circle_detail +} // namespace exo + +#endif // __CIRCLE_TYPE_INFERENCE_H__ diff --git a/compiler/exo/src/Conversion/AvgPool2DConverter.cpp b/compiler/exo/src/Conversion/AvgPool2DConverter.cpp new file mode 100644 index 000000000..a95518ac6 --- /dev/null +++ b/compiler/exo/src/Conversion/AvgPool2DConverter.cpp @@ -0,0 +1,79 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "AvgPool2DConverter.h" + +#include "Dialect/IR/TFLNodes.h" + +#include "GraphBlock.h" +#include "Check.h" + +#include <loco.h> + +namespace exo +{ +/** + * @brief Converts loco::AvgPool2D to locoex::TFLAveragePool2D + * + * How it works: (note: ten->fea means input: tensor, output: feature) + * + * Before: + * Foo ---- FeatureEncode ---- AvgPool2D ---- FeatureDecode ---- Bar + * ten->ten ten->fea fea->fea fea->ten ten->ten + * + * After: AvgPool2D + * / + * Foo -- FeatureEncode - FeatureDecode - TFLAvgPool2D - FeatureEncode - FeatureDecode -- Bar + * ten->ten ten->fea fea->ten ten->ten ten->fea fea->ten ten->ten + * + * @note This method replaces AvgPool2D with "FeatureDecode -- TFLAvgPool2D -- FeatureEncode". + * Redundant nodes will be removed during transforms. + */ +bool AvgPool2DConverter::convert(loco::AvgPool2D *origin) +{ + auto *graph = origin->graph(); + + auto dec = make_feature_decode<FeatureLayout::NHWC>(origin->ifm()); + auto tfl_average = graph->nodes()->create<locoex::TFLAveragePool2D>(); + { + tfl_average->value(dec); + + // set attributes + tfl_average->stride()->w(origin->stride()->horizontal()); + tfl_average->stride()->h(origin->stride()->vertical()); + + tfl_average->filter()->w(origin->window()->horizontal()); + tfl_average->filter()->h(origin->window()->vertical()); + + auto pad = origin->pad(); + if (pad->bottom() == 0 && pad->top() == 0 && pad->left() == 0 && pad->right() == 0) + tfl_average->padding(locoex::Padding::VALID); + else + // TODO This is necessary, but not sufficient condition. More rigorous check required + tfl_average->padding(locoex::Padding::SAME); + + tfl_average->fusedActivationFunction(locoex::FusedActFunc::NONE); + } + auto enc = make_feature_encode<FeatureLayout::NHWC>(tfl_average); + + // replace canonical node + loco::replace(origin).with(enc); + origin->ifm(nullptr); + + return true; +} + +} // namespace exo diff --git a/compiler/exo/src/Conversion/AvgPool2DConverter.h b/compiler/exo/src/Conversion/AvgPool2DConverter.h new file mode 100644 index 000000000..f66d02eb6 --- /dev/null +++ b/compiler/exo/src/Conversion/AvgPool2DConverter.h @@ -0,0 +1,41 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#ifndef __CONVERSION_AVGPOOL2D_CONVERTER__
+#define __CONVERSION_AVGPOOL2D_CONVERTER__
+
+#include "CanonicalNodeConverter.h"
+
+#include <loco.h>
+
+namespace exo
+{
+
+/**
+ * @brief Convert loco::AvgPool2D to locoex::TFLAveragePool2D
+ */
+class AvgPool2DConverter : public CanonicalNodeConverter<loco::AvgPool2D>
+{
+public:
+  const char *name(void) const final { return "exo::AvgPool2DConverter"; }
+
+public:
+  bool convert(loco::AvgPool2D *origin) final;
+};
+
+} // namespace exo
+
+#endif // __CONVERSION_AVGPOOL2D_CONVERTER__
diff --git a/compiler/exo/src/Conversion/CanonicalNodeConverter.cpp b/compiler/exo/src/Conversion/CanonicalNodeConverter.cpp
new file mode 100644
index 000000000..4daf905f8
--- /dev/null
+++ b/compiler/exo/src/Conversion/CanonicalNodeConverter.cpp
@@ -0,0 +1,19 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CanonicalNodeConverter.h"
+
+// This file exists only to make sure that "CanonicalNodeConverter.h" compiles on its own
diff --git a/compiler/exo/src/Conversion/CanonicalNodeConverter.h b/compiler/exo/src/Conversion/CanonicalNodeConverter.h
new file mode 100644
index 000000000..76f73d888
--- /dev/null
+++ b/compiler/exo/src/Conversion/CanonicalNodeConverter.h
@@ -0,0 +1,71 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#ifndef __CONVERSION_CANONICAL_NODE_CONVERTER_H__ +#define __CONVERSION_CANONICAL_NODE_CONVERTER_H__ + +#include "Convert.h" + +#include <loco.h> +#include <loco/IR/CanonicalDialect.h> +#include <logo/Pass.h> + +namespace exo +{ + +/** + * @brief Class to convert a canonical node to TFL node + * + * TODO Find a better name + */ +template <typename CanonicalType> class CanonicalNodeConverter : public logo::Pass +{ +public: + virtual const char *name(void) const { return nullptr; } + +public: + bool run(loco::Graph *graph); + +protected: + virtual bool convert(CanonicalType *node) = 0; +}; + +template <typename CanonicalType> +bool CanonicalNodeConverter<CanonicalType>::run(loco::Graph *graph) +{ + auto active_nodes = loco::active_nodes(loco::output_nodes(graph)); + bool changed = false; + + for (auto node : active_nodes) + { + // TODO Generalize this to all loco dialects + if (node->dialect() == loco::CanonicalDialect::get()) + { + auto the_node = dynamic_cast<CanonicalType *>(node); + if (the_node != nullptr) + { + if (convert(the_node)) + changed = true; + } + } + } + + return changed; +} + +} // namespace exo + +#endif //__CONVERSION_CANONICAL_NODE_CONVERTER_H__ diff --git a/compiler/exo/src/Conversion/ConstGenConverter.cpp b/compiler/exo/src/Conversion/ConstGenConverter.cpp new file mode 100644 index 000000000..b2e2b4bdb --- /dev/null +++ b/compiler/exo/src/Conversion/ConstGenConverter.cpp @@ -0,0 +1,60 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConstGenConverter.h" + +#include "Dialect/IR/TFLNodes.h" +#include "Check.h" + +#include <loco.h> + +#include <oops/InternalExn.h> + +namespace exo +{ + +bool ConstGenConverter::convert(loco::ConstGen *constgen) +{ + auto *graph = constgen->graph(); + + auto tfl_const = graph->nodes()->create<locoex::TFLConst>(); + { + if (constgen->dtype() == loco::DataType::FLOAT32) + { + tfl_const->dtype(loco::DataType::FLOAT32); + + tfl_const->rank(constgen->rank()); + for (uint32_t axis = 0; axis < constgen->rank(); axis++) + tfl_const->dim(axis) = constgen->dim(axis); + + auto size = constgen->size<loco::DataType::FLOAT32>(); + tfl_const->size<loco::DataType::FLOAT32>(size); + + for (uint32_t i = 0; i < size; ++i) + { + tfl_const->at<loco::DataType::FLOAT32>(i) = constgen->at<loco::DataType::FLOAT32>(i); + } + } + else + INTERNAL_EXN_V("Unsupported DataType", oops::to_uint32(constgen->dtype())); + } + + loco::replace(constgen).with(tfl_const); + + return true; +} + +} // namespace exo diff --git a/compiler/exo/src/Conversion/ConstGenConverter.h b/compiler/exo/src/Conversion/ConstGenConverter.h new file mode 100644 index 000000000..613ccd0e6 --- /dev/null +++ b/compiler/exo/src/Conversion/ConstGenConverter.h @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __CONVERSION_CONSTGEN_CONVERTER_H__ +#define __CONVERSION_CONSTGEN_CONVERTER_H__ + +#include "CanonicalNodeConverter.h" + +#include <loco.h> + +namespace exo +{ + +class ConstGenConverter : public CanonicalNodeConverter<loco::ConstGen> +{ +public: + const char *name(void) const final { return "exo::ConstGenConverter"; } + +public: + bool convert(loco::ConstGen *constgen) final; +}; + +} // namespace exo + +#endif // __CONVERSION_CONSTGEN_CONVERTER_H__ diff --git a/compiler/exo/src/Conversion/ConstGenConverter.test.cpp b/compiler/exo/src/Conversion/ConstGenConverter.test.cpp new file mode 100644 index 000000000..f7a577242 --- /dev/null +++ b/compiler/exo/src/Conversion/ConstGenConverter.test.cpp @@ -0,0 +1,65 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConstGenConverter.h" +#include "ReluConverter.h" + +#include "Dialect/IR/TFLNodes.h" +#include "TestGraph.h" +#include "TestHelper.h" + +#include <loco.h> + +#include <gtest/gtest.h> + +TEST(TFLConstGenConverterTest, ConstGen_Relu) +{ + exo::test::ExampleGraph<exo::test::ExampleGraphType::ConstGen_ReLU> g; + + // set constgen + { + g.constgen->dtype(loco::DataType::FLOAT32); + g.constgen->shape({2, 1}); + g.constgen->size<loco::DataType::FLOAT32>(2); + + g.constgen->at<loco::DataType::FLOAT32>(0) = 0.5; + g.constgen->at<loco::DataType::FLOAT32>(1) = -0.5; + } + + // let's convert + { + exo::test::TypeShapeReadyPhase test_phase; + + test_phase.add_pass<exo::ConstGenConverter>(); + test_phase.add_pass<exo::ReluConverter>(); + + test_phase.run(g.graph()); + } + + auto tfl_const = exo::test::find_first_node_bytype<locoex::TFLConst>(g.graph()); + auto tfl_relu = exo::test::find_first_node_bytype<locoex::TFLRelu>(g.graph()); + + ASSERT_TRUE(tfl_const != nullptr and tfl_relu != nullptr); + ASSERT_TRUE(tfl_relu->features() == tfl_const); + + ASSERT_TRUE(tfl_const->rank() == g.constgen->rank()); + ASSERT_TRUE(tfl_const->dim(0) == g.constgen->dim(0)); + ASSERT_TRUE(tfl_const->dim(1) == g.constgen->dim(1)); + ASSERT_TRUE(tfl_const->at<loco::DataType::FLOAT32>(0) == + g.constgen->at<loco::DataType::FLOAT32>(0)); + ASSERT_TRUE(tfl_const->at<loco::DataType::FLOAT32>(1) == + g.constgen->at<loco::DataType::FLOAT32>(1)); +} diff --git a/compiler/exo/src/Conversion/Conv2DConverter.cpp b/compiler/exo/src/Conversion/Conv2DConverter.cpp new file mode 100644 index 000000000..c8120171d --- /dev/null +++ b/compiler/exo/src/Conversion/Conv2DConverter.cpp @@ -0,0 +1,97 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. 
All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Conv2DConverter.h"
+
+#include "Dialect/IR/TFLNodes.h"
+
+#include "GraphBlock.h"
+#include "Check.h"
+
+#include <loco.h>
+#include <loco/Service/TypeInference.h>
+#include <loco/Service/ShapeInference.h>
+
+namespace exo
+{
+/**
+ * @brief Converts loco::Conv2D to locoex::TFLConv2D
+ * @note Because TFLConv2D accepts input and filter of loco::Domain::Tensor,
+ *       loco::FeatureDecode and loco::FilterDecode will be inserted as its inputs
+ *       to meet the domain invariant.
+ *       Please refer to the comment in AvgPool2DConverter.
+ */
+bool Conv2DConverter::convert(loco::Conv2D *origin)
+{
+  auto *graph = origin->graph();
+
+  assert(origin->ifm());
+  assert(origin->ker());
+
+  auto tfl_conv2d = graph->nodes()->create<locoex::TFLConv2D>();
+  {
+    tfl_conv2d->stride()->w(origin->stride()->horizontal());
+    tfl_conv2d->stride()->h(origin->stride()->vertical());
+
+    auto pad = origin->pad();
+    if (pad->bottom() == 0 && pad->top() == 0 && pad->left() == 0 && pad->right() == 0)
+      tfl_conv2d->padding(locoex::Padding::VALID);
+    else
+      // TODO This is necessary, but not sufficient condition. More rigorous check required
+      tfl_conv2d->padding(locoex::Padding::SAME);
+
+    tfl_conv2d->fusedActivationFunction(locoex::FusedActFunc::NONE);
+  }
+
+  // let's create a new graph connection with tfl_conv2d
+  {
+    // input
+    auto feature_dec = make_feature_decode<FeatureLayout::NHWC>(origin->ifm());
+    tfl_conv2d->input(feature_dec);
+
+    // filter
+    auto filter_dec = make_filter_decode<FilterLayout::OHWI>(origin->ker());
+    tfl_conv2d->filter(filter_dec);
+
+    // bias
+    auto zero_const = graph->nodes()->create<locoex::TFLConst>();
+    {
+      assert(loco::shape_known(origin));
+      assert(loco::dtype_known(origin) && loco::dtype_get(origin) == loco::DataType::FLOAT32);
+
+      auto output_depth = loco::shape_get(origin->ker()).as<loco::FilterShape>().count();
+
+      zero_const->dtype(loco::DataType::FLOAT32);
+      zero_const->rank(1);
+      zero_const->dim(0) = output_depth;
+      zero_const->size<loco::DataType::FLOAT32>(output_depth.value());
+      for (uint32_t x = 0; x < output_depth.value(); x++)
+        zero_const->at<loco::DataType::FLOAT32>(x) = 0.0;
+    }
+    tfl_conv2d->bias(zero_const);
+
+    // output
+    auto feature_enc = make_feature_encode<FeatureLayout::NHWC>(tfl_conv2d);
+
+    // replace canonical node
+    loco::replace(origin).with(feature_enc);
+    origin->ifm(nullptr);
+  }
+
+  return true;
+}
+
+} // namespace exo
diff --git a/compiler/exo/src/Conversion/Conv2DConverter.h b/compiler/exo/src/Conversion/Conv2DConverter.h
new file mode 100644
index 000000000..95b3fbfae
--- /dev/null
+++ b/compiler/exo/src/Conversion/Conv2DConverter.h
@@ -0,0 +1,41 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __CONVERSION_CONV2D_CONVERTER__
+#define __CONVERSION_CONV2D_CONVERTER__
+
+#include "CanonicalNodeConverter.h"
+
+#include <loco.h>
+
+namespace exo
+{
+
+/**
+ * @brief Convert loco::Conv2D to locoex::TFLConv2D
+ */
+class Conv2DConverter : public CanonicalNodeConverter<loco::Conv2D>
+{
+public:
+  const char *name(void) const final { return "exo::Conv2DConverter"; }
+
+public:
+  bool convert(loco::Conv2D *origin) final;
+};
+
+} // namespace exo
+
+#endif // __CONVERSION_CONV2D_CONVERTER__
diff --git a/compiler/exo/src/Conversion/DepthwiseConv2DConverter.cpp b/compiler/exo/src/Conversion/DepthwiseConv2DConverter.cpp
new file mode 100644
index 000000000..5959fcc45
--- /dev/null
+++ b/compiler/exo/src/Conversion/DepthwiseConv2DConverter.cpp
@@ -0,0 +1,114 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "DepthwiseConv2DConverter.h"
+
+#include "Dialect/IR/TFLNodes.h"
+
+#include "GraphBlock.h"
+#include "Check.h"
+
+#include <loco.h>
+#include <loco/Service/TypeInference.h>
+#include <loco/Service/ShapeInference.h>
+
+namespace exo
+{
+
+bool DepthwiseConv2DConverter::convert(loco::DepthwiseConv2D *origin)
+{
+  // Input and filter should be set before their shapes are queried
+  if ((origin->ifm() == nullptr) or (origin->ker() == nullptr))
+    return false;
+
+  // Filter shape is required
+  if (not loco::shape_known(origin->ker()))
+    return false;
+
+  auto filter_shape = loco::shape_get(origin->ker()).as<loco::DepthwiseFilterShape>();
+
+  auto *graph = origin->graph();
+
+  auto tfl_dw_conv2d = graph->nodes()->create<locoex::TFLDepthwiseConv2D>();
+  {
+    tfl_dw_conv2d->stride()->w(origin->stride()->horizontal());
+    tfl_dw_conv2d->stride()->h(origin->stride()->vertical());
+
+    auto pad = origin->pad();
+    if (pad->bottom() == 0 && pad->top() == 0 && pad->left() == 0 && pad->right() == 0)
+      tfl_dw_conv2d->padding(locoex::Padding::VALID);
+    else
+      // TODO This is necessary, but not sufficient condition. More rigorous check required
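+      // A sketch of such a check (not implemented here): with
+      // out = ceil(in / stride) per spatial axis, SAME requires
+      // pad_top + pad_bottom == (out - 1) * stride_h + filter_h - in_h
+      // (and likewise for width); any other padding fits neither VALID nor SAME.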
+      tfl_dw_conv2d->padding(locoex::Padding::SAME);
+
+    tfl_dw_conv2d->fusedActivationFunction(locoex::FusedActFunc::NONE);
+
+    uint32_t multiplier = filter_shape.multiplier().value();
+    EXO_ASSERT(multiplier < std::numeric_limits<int32_t>::max(),
+               "Multiplier is too big; casting to int32_t may cause unintended behavior")
+
+    tfl_dw_conv2d->depthMultiplier(static_cast<int32_t>(multiplier));
+  }
+
+  // let's create a new graph connection with tfl_dw_conv2d
+  {
+    // ifm --- feature_dec --- tfl_dw_conv2d
+    auto feature_dec = make_feature_decode<FeatureLayout::NHWC>(origin->ifm());
+    tfl_dw_conv2d->input(feature_dec);
+
+    // ker --- filter_dec(H x W x C x M) --- reshape(1 x H x W x CM) --- tfl_dw_conv2d
+    auto filter_dec = make_dw_filter_decode<DepthwiseFilterLayout::HWCM>(origin->ker());
+
+    auto reshape = graph->nodes()->create<locoex::TFLReshape>();
+    reshape->tensor(filter_dec);
+
+    int32_t new_shape[4] = {
+        1, static_cast<int32_t>(filter_shape.height().value()),
+        static_cast<int32_t>(filter_shape.width().value()),
+        static_cast<int32_t>(filter_shape.depth().value() * filter_shape.multiplier().value())};
+    locoex::set_new_shape(reshape, new_shape, 4);
+
+    tfl_dw_conv2d->filter(reshape);
+
+    // bias
+    auto zero_const = graph->nodes()->create<locoex::TFLConst>();
+    {
+      assert(loco::shape_known(origin));
+      assert(loco::dtype_known(origin) && loco::dtype_get(origin) == loco::DataType::FLOAT32);
+
+      // bias size is C * M
+      uint32_t bias_size = filter_shape.depth().value() * filter_shape.multiplier().value();
+
+      zero_const->dtype(loco::DataType::FLOAT32);
+      zero_const->rank(1);
+      zero_const->dim(0) = bias_size;
+      zero_const->size<loco::DataType::FLOAT32>(bias_size);
+      for (uint32_t x = 0; x < bias_size; x++)
+        zero_const->at<loco::DataType::FLOAT32>(x) = 0.0;
+    }
+    tfl_dw_conv2d->bias(zero_const);
+
+    // output
+    auto feature_enc = make_feature_encode<FeatureLayout::NHWC>(tfl_dw_conv2d);
+
+    // replace canonical node
+    loco::replace(origin).with(feature_enc);
+    origin->ifm(nullptr);
+  }
+
+  return true;
+}
+
+} // namespace exo
diff --git a/compiler/exo/src/Conversion/DepthwiseConv2DConverter.h b/compiler/exo/src/Conversion/DepthwiseConv2DConverter.h
new file mode 100644
index 000000000..57cc01e5e
--- /dev/null
+++ b/compiler/exo/src/Conversion/DepthwiseConv2DConverter.h
@@ -0,0 +1,61 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#ifndef __CONVERSION_DEPTHWISECONV2D_CONVERTER__ +#define __CONVERSION_DEPTHWISECONV2D_CONVERTER__ + +#include "CanonicalNodeConverter.h" + +#include <loco.h> + +namespace exo +{ + +/** + * @brief Convert loco::DepthwiseConv2D to locoex::TFLDepthwiseConv2D and auxiliary + * + * + * <BEFORE> + * + * IFM -------- DepthwiseConv2D --- Out + * [Feature] / [Feature] + * / + * KER ------- + * [DWFilter] + * + * + * <AFTER> + * TFLConst (bias) --------------------------- + * \ + * IFM ------ FeatureDecode ------------------ TFLDepthwiseConv2D --- FeatureEncode --- Out + * [Feature] [Tensor] / [Tensor] [Feature] + * / + * KER ------- DepthwiseFilterDecode --- TFLReshape + * [DWFilter] [Tensor / H W C M] [Tensor / 1 H W CM] + * + */ +class DepthwiseConv2DConverter : public CanonicalNodeConverter<loco::DepthwiseConv2D> +{ +public: + const char *name(void) const final { return "exo::DepthwiseConv2DConverter"; } + +public: + bool convert(loco::DepthwiseConv2D *origin) final; +}; + +} // namespace exo + +#endif // __CONVERSION_DEPTHWISECONV2D_CONVERTER__ diff --git a/compiler/exo/src/Conversion/EltwiseAddConverter.cpp b/compiler/exo/src/Conversion/EltwiseAddConverter.cpp new file mode 100644 index 000000000..557f47944 --- /dev/null +++ b/compiler/exo/src/Conversion/EltwiseAddConverter.cpp @@ -0,0 +1,29 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "EltwiseAddConverter.h" + +#include "EltwiseBinaryConverter.h" + +namespace exo +{ + +bool EltwiseAddConverter::convert(loco::EltwiseAdd *origin) +{ + return EltwiseBinaryConvert<loco::EltwiseAdd, locoex::TFLAdd>(origin); +} + +} // namespace exo diff --git a/compiler/exo/src/Conversion/EltwiseAddConverter.h b/compiler/exo/src/Conversion/EltwiseAddConverter.h new file mode 100644 index 000000000..97e1071b5 --- /dev/null +++ b/compiler/exo/src/Conversion/EltwiseAddConverter.h @@ -0,0 +1,41 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef __CONVERSION_ELTWISEADD_CONVERTER_H__ +#define __CONVERSION_ELTWISEADD_CONVERTER_H__ + +#include "CanonicalNodeConverter.h" + +#include <loco.h> + +namespace exo +{ + +/** + * @brief Convert loco::EltwiseAdd to TFLAdd + */ +class EltwiseAddConverter : public CanonicalNodeConverter<loco::EltwiseAdd> +{ +public: + const char *name(void) const final { return "exo::EltwiseAddConverter"; } + +public: + bool convert(loco::EltwiseAdd *origin) final; +}; + +} // namespace exo + +#endif // __CONVERSION_ELTWISEADD_CONVERTER_H__ diff --git a/compiler/exo/src/Conversion/EltwiseBinaryConverter.h b/compiler/exo/src/Conversion/EltwiseBinaryConverter.h new file mode 100644 index 000000000..095da9e5c --- /dev/null +++ b/compiler/exo/src/Conversion/EltwiseBinaryConverter.h @@ -0,0 +1,110 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __CONVERSION_ELTWISEBINARY_CONVERTER_H__ +#define __CONVERSION_ELTWISEBINARY_CONVERTER_H__ + +#include "GraphBlock.h" +#include "Check.h" + +#include "Dialect/IR/TFLNodes.h" + +#include <loco/IR/Nodes.h> + +#include <loco/Service/ShapeInference.h> + +namespace +{ + +template <class ELTWISEBIN, class TFLBIN> +class EltwiseBinInputHandler : public exo::InputHandler<ELTWISEBIN, TFLBIN> +{ +public: + void handover(ELTWISEBIN *origin, TFLBIN *replacer) override + { + assert(origin && replacer); + replacer->x(origin->lhs()); + replacer->y(origin->rhs()); + } + + std::vector<loco::Node *> getInputsToConvert(ELTWISEBIN *origin) override + { + assert(origin); + std::vector<loco::Node *> inputs({origin->lhs(), origin->rhs()}); + return inputs; + } + + void set(TFLBIN *replacer, std::vector<loco::Node *> &to) override + { + assert(to.size() == 2); + + replacer->x(to.at(0)); + replacer->y(to.at(1)); + } + + void nullify(ELTWISEBIN *origin) override + { + assert(origin); + origin->lhs(nullptr); + origin->rhs(nullptr); + } +}; + +template <class TFLBIN> void init_fused_act_func(TFLBIN *); + +template <> inline void init_fused_act_func(locoex::TFLAdd *node) +{ + node->fusedActivationFunction(locoex::FusedActFunc::NONE); +} + +template <> inline void init_fused_act_func(locoex::TFLMul *node) +{ + node->fusedActivationFunction(locoex::FusedActFunc::NONE); +} + +template <> inline void init_fused_act_func(locoex::TFLSub *node) +{ + node->fusedActivationFunction(locoex::FusedActFunc::NONE); +} + +template <> inline void init_fused_act_func(locoex::TFLDiv *node) +{ + node->fusedActivationFunction(locoex::FusedActFunc::NONE); +} + +} // namespace + +namespace exo +{ + +template <class ELTWISEBIN, class TFLBIN> bool EltwiseBinaryConvert(ELTWISEBIN *origin) +{ + EltwiseBinInputHandler<ELTWISEBIN, TFLBIN> input_handler; + exo::DomainConverter<ELTWISEBIN, TFLBIN> domain_converter; + + auto tfl_node = domain_converter.template convert<FeatureLayout::NHWC>(origin, input_handler); + + if (tfl_node == nullptr) + return false; + + init_fused_act_func(tfl_node); + + return true; +} + +} // 
diff --git a/compiler/exo/src/Conversion/EltwiseDivConverter.cpp b/compiler/exo/src/Conversion/EltwiseDivConverter.cpp
new file mode 100644
index 000000000..dc8eae461
--- /dev/null
+++ b/compiler/exo/src/Conversion/EltwiseDivConverter.cpp
@@ -0,0 +1,29 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "EltwiseDivConverter.h"
+
+#include "EltwiseBinaryConverter.h"
+
+namespace exo
+{
+
+bool EltwiseDivConverter::convert(loco::EltwiseDiv *origin)
+{
+  return EltwiseBinaryConvert<loco::EltwiseDiv, locoex::TFLDiv>(origin);
+}
+
+} // namespace exo
diff --git a/compiler/exo/src/Conversion/EltwiseDivConverter.h b/compiler/exo/src/Conversion/EltwiseDivConverter.h
new file mode 100644
index 000000000..06b2d685b
--- /dev/null
+++ b/compiler/exo/src/Conversion/EltwiseDivConverter.h
@@ -0,0 +1,41 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __CONVERSION_ELTWISEDIV_CONVERTER_H__
+#define __CONVERSION_ELTWISEDIV_CONVERTER_H__
+
+#include "CanonicalNodeConverter.h"
+
+#include <loco.h>
+
+namespace exo
+{
+
+/**
+ * @brief Convert loco::EltwiseDiv to TFLDiv
+ */
+class EltwiseDivConverter : public CanonicalNodeConverter<loco::EltwiseDiv>
+{
+public:
+  const char *name(void) const final { return "exo::EltwiseDivConverter"; }
+
+public:
+  bool convert(loco::EltwiseDiv *origin) final;
+};
+
+} // namespace exo
+
+#endif // __CONVERSION_ELTWISEDIV_CONVERTER_H__
diff --git a/compiler/exo/src/Conversion/EltwiseMaxConverter.cpp b/compiler/exo/src/Conversion/EltwiseMaxConverter.cpp
new file mode 100644
index 000000000..dd7d34440
--- /dev/null
+++ b/compiler/exo/src/Conversion/EltwiseMaxConverter.cpp
@@ -0,0 +1,75 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "EltwiseMaxConverter.h" + +#include "GraphBlock.h" +#include "Check.h" + +#include "Dialect/IR/TFLNodes.h" + +#include <loco/Service/ShapeInference.h> + +namespace +{ + +class EltwiseMaxInputHandler : public exo::InputHandler<loco::EltwiseMax, locoex::TFLMaximum> +{ +public: + void handover(loco::EltwiseMax *origin, locoex::TFLMaximum *replacer) override + { + replacer->x(origin->lhs()); + replacer->y(origin->rhs()); + } + + std::vector<loco::Node *> getInputsToConvert(loco::EltwiseMax *origin) override + { + std::vector<loco::Node *> inputs({origin->lhs(), origin->rhs()}); + return inputs; + } + + void set(locoex::TFLMaximum *replacer, std::vector<loco::Node *> &to) override + { + assert(to.size() == 2); + + replacer->x(to.at(0)); + replacer->y(to.at(1)); + } + + void nullify(loco::EltwiseMax *origin) override + { + assert(origin); + origin->lhs(nullptr); + origin->rhs(nullptr); + } +}; + +} // namespace + +namespace exo +{ + +bool EltwiseMaxConverter::convert(loco::EltwiseMax *origin) +{ + EltwiseMaxInputHandler input_handler; + exo::DomainConverter<loco::EltwiseMax, locoex::TFLMaximum> domain_converter; + + auto tfl_new = domain_converter.convert<FeatureLayout::NHWC>(origin, input_handler); + + return (tfl_new != nullptr); +} + +} // namespace exo diff --git a/compiler/exo/src/Conversion/EltwiseMaxConverter.h b/compiler/exo/src/Conversion/EltwiseMaxConverter.h new file mode 100644 index 000000000..708745419 --- /dev/null +++ b/compiler/exo/src/Conversion/EltwiseMaxConverter.h @@ -0,0 +1,41 @@ +/* + * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __CONVERSION_ELTWISEMAX_CONVERTER_H__ +#define __CONVERSION_ELTWISEMAX_CONVERTER_H__ + +#include "CanonicalNodeConverter.h" + +#include <loco.h> + +namespace exo +{ + +/** + * @brief Convert loco::EltwiseMax to TFLMaximum + */ +class EltwiseMaxConverter : public CanonicalNodeConverter<loco::EltwiseMax> +{ +public: + const char *name(void) const final { return "exo::EltwiseMaxConverter"; } + +public: + bool convert(loco::EltwiseMax *origin) final; +}; + +} // namespace exo + +#endif // __CONVERSION_ELTWISEMAX_CONVERTER_H__ diff --git a/compiler/exo/src/Conversion/EltwiseMulConverter.cpp b/compiler/exo/src/Conversion/EltwiseMulConverter.cpp new file mode 100644 index 000000000..f7a4b8298 --- /dev/null +++ b/compiler/exo/src/Conversion/EltwiseMulConverter.cpp @@ -0,0 +1,29 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "EltwiseMulConverter.h" + +#include "EltwiseBinaryConverter.h" + +namespace exo +{ + +bool EltwiseMulConverter::convert(loco::EltwiseMul *origin) +{ + return EltwiseBinaryConvert<loco::EltwiseMul, locoex::TFLMul>(origin); +} + +} // namespace exo diff --git a/compiler/exo/src/Conversion/EltwiseMulConverter.h b/compiler/exo/src/Conversion/EltwiseMulConverter.h new file mode 100644 index 000000000..4f73484c0 --- /dev/null +++ b/compiler/exo/src/Conversion/EltwiseMulConverter.h @@ -0,0 +1,41 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __CONVERSION_ELTWISEMUL_CONVERTER_H__ +#define __CONVERSION_ELTWISEMUL_CONVERTER_H__ + +#include "CanonicalNodeConverter.h" + +#include <loco.h> + +namespace exo +{ + +/** + * @brief Convert loco::EltwiseMul to TFLMul + */ +class EltwiseMulConverter : public CanonicalNodeConverter<loco::EltwiseMul> +{ +public: + const char *name(void) const final { return "exo::EltwiseMulConverter"; } + +public: + bool convert(loco::EltwiseMul *origin) final; +}; + +} // namespace exo + +#endif // __CONVERSION_ELTWISEMUL_CONVERTER_H__ diff --git a/compiler/exo/src/Conversion/EltwiseSqrtConverter.cpp b/compiler/exo/src/Conversion/EltwiseSqrtConverter.cpp new file mode 100644 index 000000000..6dead7dc6 --- /dev/null +++ b/compiler/exo/src/Conversion/EltwiseSqrtConverter.cpp @@ -0,0 +1,68 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "EltwiseSqrtConverter.h" + +#include "GraphBlock.h" +#include "Check.h" + +#include "Dialect/IR/TFLNodes.h" + +#include <loco/Service/ShapeInference.h> + +namespace +{ + +class EltwiseSqrtInputHandler : public exo::InputHandler<loco::EltwiseSqrt, locoex::TFLSqrt> +{ +public: + void handover(loco::EltwiseSqrt *origin, locoex::TFLSqrt *replacer) override + { + replacer->x(origin->input()); + } + + std::vector<loco::Node *> getInputsToConvert(loco::EltwiseSqrt *origin) override + { + std::vector<loco::Node *> inputs({origin->input()}); + return inputs; + } + + void set(locoex::TFLSqrt *replacer, std::vector<loco::Node *> &to) override + { + assert(to.size() == 1); + + replacer->x(to.at(0)); + } + + void nullify(loco::EltwiseSqrt *origin) override { origin->input(nullptr); } +}; + +} // namespace + +namespace exo +{ + +bool EltwiseSqrtConverter::convert(loco::EltwiseSqrt *origin) +{ + EltwiseSqrtInputHandler input_handler; + exo::DomainConverter<loco::EltwiseSqrt, locoex::TFLSqrt> domain_converter; + + auto tfl_new = domain_converter.convert<FeatureLayout::NHWC>(origin, input_handler); + + return (tfl_new != nullptr); +} + +} // namespace exo diff --git a/compiler/exo/src/Conversion/EltwiseSqrtConverter.h b/compiler/exo/src/Conversion/EltwiseSqrtConverter.h new file mode 100644 index 000000000..5ee3185ff --- /dev/null +++ b/compiler/exo/src/Conversion/EltwiseSqrtConverter.h @@ -0,0 +1,41 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __ELTWISE_SQRT_CONVERTER_H__ +#define __ELTWISE_SQRT_CONVERTER_H__ + +#include "CanonicalNodeConverter.h" + +#include <loco.h> + +namespace exo +{ + +/** + * @brief Convert loco::EltwiseSqrt to TFLSqrt + */ +class EltwiseSqrtConverter : public CanonicalNodeConverter<loco::EltwiseSqrt> +{ +public: + const char *name(void) const final { return "exo::EltwiseSqrtConverter"; } + +public: + bool convert(loco::EltwiseSqrt *origin) final; +}; + +} // namespace exo + +#endif // __ELTWISE_SQRT_CONVERTER_H__ diff --git a/compiler/exo/src/Conversion/EltwiseSubConverter.cpp b/compiler/exo/src/Conversion/EltwiseSubConverter.cpp new file mode 100644 index 000000000..5647c47a2 --- /dev/null +++ b/compiler/exo/src/Conversion/EltwiseSubConverter.cpp @@ -0,0 +1,29 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#include "EltwiseSubConverter.h"
+
+#include "EltwiseBinaryConverter.h"
+
+namespace exo
+{
+
+bool EltwiseSubConverter::convert(loco::EltwiseSub *origin)
+{
+  return EltwiseBinaryConvert<loco::EltwiseSub, locoex::TFLSub>(origin);
+}
+
+} // namespace exo
diff --git a/compiler/exo/src/Conversion/EltwiseSubConverter.h b/compiler/exo/src/Conversion/EltwiseSubConverter.h
new file mode 100644
index 000000000..d61b76ec0
--- /dev/null
+++ b/compiler/exo/src/Conversion/EltwiseSubConverter.h
@@ -0,0 +1,41 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __CONVERSION_ELTWISESUB_CONVERTER_H__
+#define __CONVERSION_ELTWISESUB_CONVERTER_H__
+
+#include "CanonicalNodeConverter.h"
+
+#include <loco.h>
+
+namespace exo
+{
+
+/**
+ * @brief Convert loco::EltwiseSub to TFLSub
+ */
+class EltwiseSubConverter : public CanonicalNodeConverter<loco::EltwiseSub>
+{
+public:
+  const char *name(void) const final { return "exo::EltwiseSubConverter"; }
+
+public:
+  bool convert(loco::EltwiseSub *origin) final;
+};
+
+} // namespace exo
+
+#endif // __CONVERSION_ELTWISESUB_CONVERTER_H__
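These converters are logo passes rather than free functions, so a driver applies them through a phase runner until the graph stops changing. A minimal sketch under that assumption — only the logo Phase/PhaseRunner API is assumed here, and the pass selection and function name are illustrative:

#include "Conversion/EltwiseAddConverter.h"
#include "Conversion/EltwiseSubConverter.h"

#include <logo/Phase.h>

#include <memory>

// Illustrative only: keep re-running the converter passes until none of
// them reports a change, so chained eltwise ops are all lowered.
void lower_eltwise_ops(loco::Graph *g)
{
  logo::Phase phase;
  phase.emplace_back(std::make_unique<exo::EltwiseAddConverter>());
  phase.emplace_back(std::make_unique<exo::EltwiseSubConverter>());

  logo::PhaseRunner<logo::PhaseStrategy::Saturate> runner{g};
  runner.run(phase);
}

The Saturate strategy matters because one conversion can expose another, e.g. an EltwiseAdd whose result feeds an EltwiseSub.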
diff --git a/compiler/exo/src/Conversion/FeatureBiasAddConverter.cpp b/compiler/exo/src/Conversion/FeatureBiasAddConverter.cpp
new file mode 100644
index 000000000..b9aaf140b
--- /dev/null
+++ b/compiler/exo/src/Conversion/FeatureBiasAddConverter.cpp
@@ -0,0 +1,91 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "FeatureBiasAddConverter.h"
+
+#include "Dialect/IR/TFLNodes.h"
+
+#include "GraphBlock.h"
+
+#include <loco.h>
+#include <loco/Service/ShapeInference.h>
+
+#include <cassert>
+
+namespace
+{
+
+inline void init_fused_act_func(locoex::TFLAdd *node)
+{
+  node->fusedActivationFunction(locoex::FusedActFunc::NONE);
+}
+
+} // namespace
+
+namespace exo
+{
+
+/**
+ * @brief Converts loco::FeatureBiasAdd to locoex::TFLAdd
+ *
+ * Before:
+ *   Foo ---+
+ *          |
+ *          loco::FeatureBiasAdd - FeatureDecode - ...
+ *          |
+ *   Bar - BiasEncode --+
+ *
+ * After:
+ *
+ *   Foo - loco::FeatureDecode --+                 loco::FeatureBiasAdd
+ *                               |(x)
+ *                               TFLAdd -- loco::FeatureEncode - FeatureDecode - ...
+ *                               |(y)
+ *   Bar - BiasEncode - loco::BiasDecode --+
+ */
+bool FeatureBiasAddConverter::convert(loco::FeatureBiasAdd *origin)
+{
+  auto *graph = origin->graph();
+
+  auto tfl_add = graph->nodes()->create<locoex::TFLAdd>();
+
+  // handling input x
+  assert(loco::shape_get(origin->value()).domain() == loco::Domain::Feature);
+
+  auto fea_dec = make_feature_decode<FeatureLayout::NHWC>(origin->value());
+  tfl_add->x(fea_dec);
+
+  // handling input y
+  auto bias_dec = graph->nodes()->create<loco::BiasDecode>();
+  assert(bias_dec != nullptr);
+
+  bias_dec->input(origin->bias());
+
+  tfl_add->y(bias_dec);
+
+  // fused activation function
+  init_fused_act_func(tfl_add);
+
+  // handling output
+  auto fea_enc = make_feature_encode<FeatureLayout::NHWC>(tfl_add);
+
+  loco::replace(origin).with(fea_enc);
+  origin->value(nullptr);
+
+  return true;
+}
+
+} // namespace exo
diff --git a/compiler/exo/src/Conversion/FeatureBiasAddConverter.h b/compiler/exo/src/Conversion/FeatureBiasAddConverter.h
new file mode 100644
index 000000000..5c4f10213
--- /dev/null
+++ b/compiler/exo/src/Conversion/FeatureBiasAddConverter.h
@@ -0,0 +1,38 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __CONVERSION_FEATUREBIASADD_CONVERTER__
+#define __CONVERSION_FEATUREBIASADD_CONVERTER__
+
+#include "CanonicalNodeConverter.h"
+
+#include <loco.h>
+
+namespace exo
+{
+
+/**
+ * @brief Convert loco::FeatureBiasAdd to locoex::TFLAdd
+ */
+class FeatureBiasAddConverter : public CanonicalNodeConverter<loco::FeatureBiasAdd>
+{
+public:
+  const char *name(void) const final { return "exo::FeatureBiasAddConverter"; }
+
+public:
+  bool convert(loco::FeatureBiasAdd *origin) final;
+};
+
+} // namespace exo
+
+#endif // __CONVERSION_FEATUREBIASADD_CONVERTER__
diff --git a/compiler/exo/src/Conversion/FeatureBiasAddConverter.test.cpp b/compiler/exo/src/Conversion/FeatureBiasAddConverter.test.cpp
new file mode 100644
index 000000000..f3c4a5f81
--- /dev/null
+++ b/compiler/exo/src/Conversion/FeatureBiasAddConverter.test.cpp
@@ -0,0 +1,102 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#include "FeatureBiasAddConverter.h" + +#include "GraphBlock.h" +#include "Dialect/IR/TFLNodes.h" + +#include "TestGraph.h" +#include "TestHelper.h" + +#include <loco.h> + +#include <gtest/gtest.h> + +TEST(FeatureBiasAddConverterTest, basic_test) +{ + exo::test::ExampleGraph<exo::test::ExampleGraphType::FeatureBiasAdd> g; + + { // attrib setting + // pull + g.pull->dtype(loco::DataType::FLOAT32); + g.pull->shape({1, 2, 2, 3}); + + // bias value + g.constgen->dtype(loco::DataType::FLOAT32); + g.constgen->shape({3}); + g.constgen->size<loco::DataType::FLOAT32>(3); + + g.constgen->at<loco::DataType::FLOAT32>(0) = 0.5; + g.constgen->at<loco::DataType::FLOAT32>(1) = 1; + g.constgen->at<loco::DataType::FLOAT32>(2) = 1.5; + } + + EXO_TEST_ASSERT_NODE_COUNT({g.push}, 7); // sanity check + + // let's convert!! + { + exo::test::TypeShapeReadyPhase test_phase; + + test_phase.add_pass<exo::FeatureBiasAddConverter>(); + + test_phase.run(g.graph()); + + /* + Expected: + + Pull - FeatureEncoder - FeatureDecode - TFLAdd - FeatureEncode - FeatureDecode - Push + | + ConstGen - BiasEncode - BiasDecode ---+ + */ + } + + // check surroundings + auto tfl_add = exo::test::find_first_node_bytype<locoex::TFLAdd>(g.graph()); + { + ASSERT_TRUE(tfl_add != nullptr); + + // input x and its pred + { + auto actual_fea_dec = dynamic_cast<loco::FeatureDecode *>(tfl_add->x()); + ASSERT_TRUE(actual_fea_dec != nullptr); + + auto actual_fea_enc = dynamic_cast<loco::FeatureEncode *>(actual_fea_dec->input()); + ASSERT_TRUE(actual_fea_enc != nullptr); + ASSERT_TRUE(actual_fea_enc == g.fea_enc); + } + + // input y and its pred + { + auto actual_bias_dec = dynamic_cast<loco::BiasDecode *>(tfl_add->y()); + ASSERT_TRUE(actual_bias_dec != nullptr); + + auto actual_bias_enc = dynamic_cast<loco::BiasEncode *>(actual_bias_dec->input()); + ASSERT_TRUE(actual_bias_enc != nullptr); + ASSERT_TRUE(actual_bias_enc == g.bias_enc); + } + + // output check + { + auto actual_fea_enc = exo::test::get_only_succ<loco::FeatureEncode>(tfl_add); + ASSERT_TRUE(actual_fea_enc != nullptr); + + auto actual_fea_dec = exo::test::get_only_succ<loco::FeatureDecode>(actual_fea_enc); + ASSERT_TRUE(actual_fea_dec != nullptr); + ASSERT_TRUE(actual_fea_dec == g.fea_dec); + } + } +} diff --git a/compiler/exo/src/Conversion/MatMulConverter.cpp b/compiler/exo/src/Conversion/MatMulConverter.cpp new file mode 100644 index 000000000..b1158b73d --- /dev/null +++ b/compiler/exo/src/Conversion/MatMulConverter.cpp @@ -0,0 +1,103 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "MatMulConverter.h" + +#include "Dialect/IR/TFLNodes.h" + +#include "GraphBlock.h" +#include "Check.h" + +#include <loco.h> +#include <loco/Service/TypeInference.h> +#include <loco/Service/ShapeInference.h> + +namespace exo +{ +/** + * @brief Converts loco::MatMul to locoex::TFLFullyConnected + * @note Because TFLFullyConnected accepts input and weights of loco::Domain::Matrix, + * loco::MatrixDecode will be inserted as an input and weights + * to meet domain invariant. + * + * How it works: + * + * Before: + * Foo1 ---- MatrixEncode ---- MatMul ---- MatrixDecode ---- Bar + * Foo2 ---- MatrixEncode ----/ + * + * After: + * + * Foo1 - MatrixEncode - MatrixDecode - TFLFullyConnected - MatrixEncode - MatrixDecode - Bar + * Foo2 - MatrixEncode - MatrixDecode -/ + * + * @note This method replaces MatMul with "- MatrixDecode - TFLFullyConnected - MatrixEncode -". + * - MatrixDecode -/ + * Redundant nodes will be removed during transforms. + * + * @ref + * https://github.com/tensorflow/tensorflow/blob/v1.13.1/tensorflow/lite/kernels/internal/reference/fully_connected.h + */ +bool MatMulConverter::convert(loco::MatMul *origin) +{ + auto *graph = origin->graph(); + + assert(origin->lhs()); + assert(origin->rhs()); + + auto tfl_fc = graph->nodes()->create<locoex::TFLFullyConnected>(); + tfl_fc->fusedActivationFunction(locoex::FusedActFunc::NONE); + + // let's create a new graph connection with tfl_fc + { + // input + auto lhs_matrix_dec = make_matrix_decode<MatrixLayout::HW>(origin->lhs()); + tfl_fc->input(lhs_matrix_dec); + + // weights (WH format on TFLite) + auto rhs_matrix_dec = make_matrix_decode<MatrixLayout::WH>(origin->rhs()); + tfl_fc->weights(rhs_matrix_dec); + + // bias + auto zero_const = graph->nodes()->create<locoex::TFLConst>(); + { // TODO Create optimization pass which fuse additional Add into bias of Conv or FC + assert(loco::shape_known(origin)); + assert(loco::dtype_known(origin) && loco::dtype_get(origin) == loco::DataType::FLOAT32); + + auto output_depth = loco::shape_get(origin->rhs()).as<loco::MatrixShape>().width(); + // TODO Fix it with type inference + zero_const->dtype(loco::DataType::FLOAT32); + zero_const->rank(1); + zero_const->dim(0) = output_depth; + zero_const->size<loco::DataType::FLOAT32>(output_depth.value()); + for (uint32_t x = 0; x < output_depth.value(); x++) + zero_const->at<loco::DataType::FLOAT32>(x) = 0.0; + } + tfl_fc->bias(zero_const); + + // output + auto matrix_enc = make_matrix_encode<MatrixLayout::HW>(tfl_fc); + + // replace canonical node + loco::replace(origin).with(matrix_enc); + origin->lhs(nullptr); + origin->rhs(nullptr); + } + + return true; +} + +} // namespace exo diff --git a/compiler/exo/src/Conversion/MatMulConverter.h b/compiler/exo/src/Conversion/MatMulConverter.h new file mode 100644 index 000000000..e64c4a0f2 --- /dev/null +++ b/compiler/exo/src/Conversion/MatMulConverter.h @@ -0,0 +1,41 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __CONVERSION_FULLY_CONNECTED_CONVERTER__ +#define __CONVERSION_FULLY_CONNECTED_CONVERTER__ + +#include "CanonicalNodeConverter.h" + +#include <loco.h> + +namespace exo +{ + +/** + * @brief Convert loco::MatMul to locoex::TFLFullyConnected + */ +class MatMulConverter : public CanonicalNodeConverter<loco::MatMul> +{ +public: + const char *name(void) const final { return "exo::MatMulConverter"; } + +public: + bool convert(loco::MatMul *origin) final; +}; + +} // namespace exo + +#endif // __CONVERSION_FULLY_CONNECTED_CONVERTER__ diff --git a/compiler/exo/src/Conversion/MaxPool2DConverter.cpp b/compiler/exo/src/Conversion/MaxPool2DConverter.cpp new file mode 100644 index 000000000..67e5ab833 --- /dev/null +++ b/compiler/exo/src/Conversion/MaxPool2DConverter.cpp @@ -0,0 +1,67 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "MaxPool2DConverter.h" + +#include "Dialect/IR/TFLNodes.h" +#include "GraphBlock.h" + +#include <loco.h> + +namespace exo +{ + +/** + * @brief Converts loco::MaxPool2D to locoex::TFLMaxPool2D + * + * @note This works similar to AvgPool2DConverter. Please refer to the comment in + * AvgPool2DConverter. + */ +bool MaxPool2DConverter::convert(loco::MaxPool2D *origin) +{ + auto *graph = origin->graph(); + + auto dec = make_feature_decode<FeatureLayout::NHWC>(origin->ifm()); + auto tfl_max = graph->nodes()->create<locoex::TFLMaxPool2D>(); + { + tfl_max->value(dec); + + // set attributes + tfl_max->stride()->w(origin->stride()->horizontal()); + tfl_max->stride()->h(origin->stride()->vertical()); + + tfl_max->filter()->w(origin->window()->horizontal()); + tfl_max->filter()->h(origin->window()->vertical()); + + auto pad = origin->pad(); + if (pad->bottom() == 0 && pad->top() == 0 && pad->left() == 0 && pad->right() == 0) + tfl_max->padding(locoex::Padding::VALID); + else + // TODO This is necessary, but not sufficient condition. More rigorous check required + tfl_max->padding(locoex::Padding::SAME); + + tfl_max->fusedActivationFunction(locoex::FusedActFunc::NONE); + } + + auto enc = make_feature_encode<FeatureLayout::NHWC>(tfl_max); + + loco::replace(origin).with(enc); + origin->ifm(nullptr); + + return true; +} + +} // namespace exo diff --git a/compiler/exo/src/Conversion/MaxPool2DConverter.h b/compiler/exo/src/Conversion/MaxPool2DConverter.h new file mode 100644 index 000000000..3f526d88f --- /dev/null +++ b/compiler/exo/src/Conversion/MaxPool2DConverter.h @@ -0,0 +1,41 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __CONVERSION_MAXPOOL2D_CONVERTER__ +#define __CONVERSION_MAXPOOL2D_CONVERTER__ + +#include "CanonicalNodeConverter.h" + +#include <loco.h> + +namespace exo +{ + +/** + * @brief Convert loco::MaxPool2D to locoex::TFLMaxPool2D + */ +class MaxPool2DConverter : public CanonicalNodeConverter<loco::MaxPool2D> +{ +public: + const char *name(void) const final { return "exo::MaxPool2DConverter"; } + +public: + bool convert(loco::MaxPool2D *origin) final; +}; + +} // namespace exo + +#endif // __CONVERSION_MAXPOOL2D_CONVERTER__ diff --git a/compiler/exo/src/Conversion/Relu6Converter.cpp b/compiler/exo/src/Conversion/Relu6Converter.cpp new file mode 100644 index 000000000..b694511f5 --- /dev/null +++ b/compiler/exo/src/Conversion/Relu6Converter.cpp @@ -0,0 +1,68 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "Relu6Converter.h" + +#include "GraphBlock.h" +#include "Check.h" + +#include "Dialect/IR/TFLNodes.h" + +#include <loco/Service/ShapeInference.h> + +namespace +{ + +class Relu6InputHandler : public exo::InputHandler<loco::ReLU6, locoex::TFLRelu6> +{ +public: + void handover(loco::ReLU6 *origin, locoex::TFLRelu6 *replacer) override + { + replacer->features(origin->input()); + } + + std::vector<loco::Node *> getInputsToConvert(loco::ReLU6 *origin) override + { + std::vector<loco::Node *> inputs({origin->input()}); + return inputs; + } + + void set(locoex::TFLRelu6 *replacer, std::vector<loco::Node *> &to) override + { + assert(to.size() == 1); + + replacer->features(to.at(0)); + } + + void nullify(loco::ReLU6 *origin) override { origin->input(nullptr); } +}; + +} // namespace + +namespace exo +{ + +bool Relu6Converter::convert(loco::ReLU6 *origin) +{ + Relu6InputHandler input_handler; + exo::DomainConverter<loco::ReLU6, locoex::TFLRelu6> domain_converter; + + auto tfl_node = domain_converter.convert<FeatureLayout::NHWC>(origin, input_handler); + + return (tfl_node != nullptr); +} + +} // namespace exo diff --git a/compiler/exo/src/Conversion/Relu6Converter.h b/compiler/exo/src/Conversion/Relu6Converter.h new file mode 100644 index 000000000..d987b42d0 --- /dev/null +++ b/compiler/exo/src/Conversion/Relu6Converter.h @@ -0,0 +1,41 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __CONVERSION_RELU6_CONVERTER_H__ +#define __CONVERSION_RELU6_CONVERTER_H__ + +#include "CanonicalNodeConverter.h" + +#include <loco.h> + +namespace exo +{ + +/** + * @brief Convert loco::Relu6 to TFLRelu6 + */ +class Relu6Converter : public CanonicalNodeConverter<loco::ReLU6> +{ +public: + const char *name(void) const final { return "exo::Relu6Converter"; } + +public: + bool convert(loco::ReLU6 *origin) final; +}; + +} // namespace exo + +#endif // __CONVERSION_RELU6_CONVERTER_H__ diff --git a/compiler/exo/src/Conversion/ReluConverter.cpp b/compiler/exo/src/Conversion/ReluConverter.cpp new file mode 100644 index 000000000..92adef94d --- /dev/null +++ b/compiler/exo/src/Conversion/ReluConverter.cpp @@ -0,0 +1,68 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ReluConverter.h" + +#include "GraphBlock.h" +#include "Check.h" + +#include "Dialect/IR/TFLNodes.h" + +#include <loco/Service/ShapeInference.h> + +namespace +{ + +class ReluInputHandler : public exo::InputHandler<loco::ReLU, locoex::TFLRelu> +{ +public: + void handover(loco::ReLU *origin, locoex::TFLRelu *replacer) override + { + replacer->features(origin->input()); + } + + std::vector<loco::Node *> getInputsToConvert(loco::ReLU *origin) override + { + std::vector<loco::Node *> inputs({origin->input()}); + return inputs; + } + + void set(locoex::TFLRelu *replacer, std::vector<loco::Node *> &to) override + { + assert(to.size() == 1); + + replacer->features(to.at(0)); + } + + void nullify(loco::ReLU *origin) override { origin->input(nullptr); } +}; + +} // namespace + +namespace exo +{ + +bool ReluConverter::convert(loco::ReLU *origin) +{ + ReluInputHandler input_handler; + exo::DomainConverter<loco::ReLU, locoex::TFLRelu> domain_converter; + + auto tfl_node = domain_converter.convert<FeatureLayout::NHWC>(origin, input_handler); + + return (tfl_node != nullptr); +} + +} // namespace exo diff --git a/compiler/exo/src/Conversion/ReluConverter.h b/compiler/exo/src/Conversion/ReluConverter.h new file mode 100644 index 000000000..e1e82ae4b --- /dev/null +++ b/compiler/exo/src/Conversion/ReluConverter.h @@ -0,0 +1,41 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __CONVERSION_RELU_CONVERTER_H__ +#define __CONVERSION_RELU_CONVERTER_H__ + +#include "CanonicalNodeConverter.h" + +#include <loco.h> + +namespace exo +{ + +/** + * @brief Convert loco::Relu to TFLRelu + */ +class ReluConverter : public CanonicalNodeConverter<loco::ReLU> +{ +public: + const char *name(void) const final { return "exo::ReluConverter"; } + +public: + bool convert(loco::ReLU *origin) final; +}; + +} // namespace exo + +#endif // __CONVERSION_RELU_CONVERTER_H__ diff --git a/compiler/exo/src/Conversion/ReluConverter.test.cpp b/compiler/exo/src/Conversion/ReluConverter.test.cpp new file mode 100644 index 000000000..f53d656b4 --- /dev/null +++ b/compiler/exo/src/Conversion/ReluConverter.test.cpp @@ -0,0 +1,97 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "ReluConverter.h" + +#include "GraphBlock.h" +#include "Dialect/IR/TFLNodes.h" + +#include "TestHelper.h" +#include "TestGraph.h" + +#include <gtest/gtest.h> + +TEST(ReluConverterTest, relu_tensor_inout) +{ + exo::test::TestGraph graph; + { + auto tanh = graph.append<loco::Tanh>(graph.pull); + auto relu = graph.append<loco::ReLU>(tanh); + auto relu6 = graph.append<loco::ReLU6>(relu); + graph.complete(); + + auto pull = graph.pull; + { + pull->dtype(loco::DataType::FLOAT32); + pull->shape({2, 2}); + } + } + + // let's convert + exo::test::TypeShapeReadyPhase test_phase; + { + test_phase.add_pass<exo::ReluConverter>(); + test_phase.run(graph.g.get()); + } + + loco::Node *node = exo::test::find_first_node_bytype<loco::Tanh>(graph.g.get()); + ASSERT_TRUE(node != nullptr); + node = exo::test::get_only_succ<locoex::TFLRelu>(node); + ASSERT_TRUE(node != nullptr); + node = exo::test::get_only_succ<loco::ReLU6>(node); + ASSERT_TRUE(node != nullptr); +} + +TEST(ReluConverterTest, relu_feature_inout) +{ + // g = Pull - FeatureEncode - Relu - FeatureDecode - Push + exo::test::TestGraph graph; + { + auto enc = exo::make_feature_encode<exo::FeatureLayout::NHWC>(graph.pull); + auto relu = graph.append<loco::ReLU>(enc); + auto dec = exo::make_feature_decode<exo::FeatureLayout::NHWC>(relu); + graph.complete(dec); + } + + auto pull = graph.pull; + { + pull->dtype(loco::DataType::FLOAT32); + pull->shape({1, 2, 3, 4}); + } + + exo::test::TypeShapeReadyPhase test_phase; + { + test_phase.add_pass<exo::ReluConverter>(); + test_phase.run(graph.g.get()); + } + + // now, g = Pull - FeatureEncode - FeatureDecode - TFLRelu - FeatureEncode - FeatureDecode - Push + + // Check + EXO_TEST_ASSERT_NODE_COUNT({graph.push}, 7); + + // Check [FeatureEncode - FeatureDecode - TFLRelu - FeatureEncode - FeatureDecode] chunk + loco::Node *node = exo::test::find_first_node_bytype<loco::FeatureEncode>(graph.g.get()); + ASSERT_TRUE(node != nullptr); + node = exo::test::get_only_succ<loco::FeatureDecode>(node); + ASSERT_TRUE(node != nullptr); + node = exo::test::get_only_succ<locoex::TFLRelu>(node); + ASSERT_TRUE(node != nullptr); + node = exo::test::get_only_succ<loco::FeatureEncode>(node); + ASSERT_TRUE(node != nullptr); + node = exo::test::get_only_succ<loco::FeatureDecode>(node); + ASSERT_TRUE(node != nullptr); +} diff --git a/compiler/exo/src/Conversion/TensorBroadcastConverter.cpp b/compiler/exo/src/Conversion/TensorBroadcastConverter.cpp new file mode 100644 index 000000000..532332742 --- /dev/null +++ b/compiler/exo/src/Conversion/TensorBroadcastConverter.cpp @@ -0,0 +1,189 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "TensorBroadcastConverter.h" + +#include "Dialect/IR/TFLDialect.h" +#include "Dialect/IR/TFLNodeVisitor.h" + +#include <loco.h> +#include <loco/IR/CanonicalDialect.h> +#include <loco/IR/CanonicalNode.h> + +#include <set> + +namespace +{ + +template <class T> loco::TensorBroadcast *input_as_tbc(T *node) +{ + loco::TensorBroadcast *tbc = dynamic_cast<loco::TensorBroadcast *>(node->x()); + if (tbc == nullptr) + tbc = dynamic_cast<loco::TensorBroadcast *>(node->y()); + + return tbc; +} + +struct Collector final : public locoex::TFLNodeMutableVisitor<void> +{ + using NodePair = std::pair<loco::TensorBroadcast *, loco::Node *>; + + void visit(locoex::TFLAdd *node) final + { + if (auto tbc = input_as_tbc<locoex::TFLAdd>(node)) + { + NodePair pair(tbc, node); + candidates.insert(pair); + } + } + + void visit(locoex::TFLDiv *node) final + { + if (auto tbc = input_as_tbc<locoex::TFLDiv>(node)) + { + NodePair pair(tbc, node); + candidates.insert(pair); + } + } + + void visit(locoex::TFLMul *node) final + { + if (auto tbc = input_as_tbc<locoex::TFLMul>(node)) + { + NodePair pair(tbc, node); + candidates.insert(pair); + } + } + + void visit(locoex::TFLSub *node) final + { + if (auto tbc = input_as_tbc<locoex::TFLSub>(node)) + { + NodePair pair(tbc, node); + candidates.insert(pair); + } + } + + void visit(locoex::TFLMaximum *node) final + { + if (auto tbc = input_as_tbc<locoex::TFLMaximum>(node)) + { + NodePair pair(tbc, node); + candidates.insert(pair); + } + } + + void visit(locoex::TFLNode *) final { return; } + + std::set<NodePair> candidates; +}; + +bool mapping_condition(Collector::NodePair &) +{ + // TODO fill condition + + return true; +} + +template <class T> void jump_connection(loco::TensorBroadcast *tbc, T *tflnode) +{ + if (tflnode->x() == tbc) + tflnode->x(tbc->input()); + else if (tflnode->y() == tbc) + tflnode->y(tbc->input()); + else + assert(false); + + tbc->input(nullptr); +} + +} // namespace + +namespace exo +{ + +/** + * @brief Disconnects loco::TensorBroadcast from the graph if following node + * is one of binary node: TFLAdd, TFLSub, TFLMul, TFLDiv, TFLMaximum + * and meets condition (TBA) + * @note + * Before: + * x --- TensorBroadcast --- TFLXXX --- output + * y ----------------------/ + * + * After: + * --- TensorBroadcast --- + * x --- TFLXXX --- output + * y --/ + */ +bool TensorBroadcastConverter::run(loco::Graph *graph) +{ + Collector collector; + + auto active_nodes = loco::active_nodes(loco::output_nodes(graph)); + + for (auto node : active_nodes) + { + if (node->dialect() == locoex::TFLDialect::get()) + { + auto tfl_node = dynamic_cast<locoex::TFLNode *>(node); + tfl_node->accept(&collector); + } + } + + bool changed = false; + + for (auto pair : collector.candidates) + { + if (mapping_condition(pair)) + { + loco::TensorBroadcast *tensorbroadcast = pair.first; + if (auto tfladd = dynamic_cast<locoex::TFLAdd *>(pair.second)) + { + jump_connection<locoex::TFLAdd>(tensorbroadcast, tfladd); + changed = true; + } + else if (auto tfldiv = dynamic_cast<locoex::TFLDiv *>(pair.second)) + { + jump_connection<locoex::TFLDiv>(tensorbroadcast, tfldiv); + changed = true; + } + else if (auto tflmul = dynamic_cast<locoex::TFLMul *>(pair.second)) + { + jump_connection<locoex::TFLMul>(tensorbroadcast, tflmul); + changed = true; + } + else if (auto tflsub = dynamic_cast<locoex::TFLSub *>(pair.second)) + { + jump_connection<locoex::TFLSub>(tensorbroadcast, tflsub); + changed = true; + } + else if (auto tflmaximum = dynamic_cast<locoex::TFLMaximum *>(pair.second)) + { + 
jump_connection<locoex::TFLMaximum>(tensorbroadcast, tflmaximum); + changed = true; + } + else + { + assert(false); + } + } + } + + return changed; +} + +} // namespace exo diff --git a/compiler/exo/src/Conversion/TensorBroadcastConverter.h b/compiler/exo/src/Conversion/TensorBroadcastConverter.h new file mode 100644 index 000000000..3cf79b0ba --- /dev/null +++ b/compiler/exo/src/Conversion/TensorBroadcastConverter.h @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __TENSOR_BROADCAST_CONVERTER_H__ +#define __TENSOR_BROADCAST_CONVERTER_H__ + +#include <loco.h> +#include <logo/Pass.h> + +namespace exo +{ + +/** + * @brief Pass to resolve TensorBroadcast IR + */ +class TensorBroadcastConverter : public logo::Pass +{ +public: + virtual const char *name(void) const { return "exo::TensorBroadcastConverter"; } + +public: + bool run(loco::Graph *graph); +}; + +} // namespace exo + +#endif //__TENSOR_BROADCAST_CONVERTER_H__ diff --git a/compiler/exo/src/Conversion/TensorConcatConverter.cpp b/compiler/exo/src/Conversion/TensorConcatConverter.cpp new file mode 100644 index 000000000..1c36b11f8 --- /dev/null +++ b/compiler/exo/src/Conversion/TensorConcatConverter.cpp @@ -0,0 +1,66 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "TensorConcatConverter.h" + +#include "GraphBlock.h" +#include "Check.h" + +#include "Dialect/IR/TFLNodes.h" + +#include <loco/Service/ShapeInference.h> + +namespace exo +{ +/** + * @brief Converts loco::TensorConcat to locoex::TFLConcatenate + * + * Before: + * input:0 ----- loco::TensorConcat ------- C + * input:1 ----/ + * + * After: + * input:0 ----- locoex::TFLConcatenate --- C + * input:1 ----/ + * + * input:0 ----- loco::TensorConcat --- + * input:1 ----/ + * + */ +bool TensorConcatConverter::convert(loco::TensorConcat *origin) +{ + assert(loco::shape_get(origin).domain() == loco::Domain::Tensor); + + if (!loco::shape_known(origin)) + { + return false; + } + + auto tfl_concat = origin->graph()->nodes()->create<locoex::TFLConcatenation>(2); + tfl_concat->values(0, origin->lhs()); + tfl_concat->values(1, origin->rhs()); + tfl_concat->axis(origin->axis()); + tfl_concat->fusedActivationFunction(locoex::FusedActFunc::NONE); + + loco::replace(origin).with(tfl_concat); + + origin->lhs(nullptr); + origin->rhs(nullptr); + + return true; +} + +} // namespace exo diff --git a/compiler/exo/src/Conversion/TensorConcatConverter.h b/compiler/exo/src/Conversion/TensorConcatConverter.h new file mode 100644 index 000000000..6b90f4731 --- /dev/null +++ b/compiler/exo/src/Conversion/TensorConcatConverter.h @@ -0,0 +1,41 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __CONVERSION_TENSORCONCAT_CONVERTER_H__ +#define __CONVERSION_TENSORCONCAT_CONVERTER_H__ + +#include "CanonicalNodeConverter.h" + +#include <loco.h> + +namespace exo +{ + +/** + * @brief Convert loco::TensorConcat to TFLConcatenate + */ +class TensorConcatConverter : public CanonicalNodeConverter<loco::TensorConcat> +{ +public: + const char *name(void) const final { return "exo::TensorConcatConverter"; } + +public: + bool convert(loco::TensorConcat *origin) final; +}; + +} // namespace exo + +#endif // __CONVERSION_TENSORCONCAT_CONVERTER_H__ diff --git a/compiler/exo/src/Conversion/TensorReduceConverter.cpp b/compiler/exo/src/Conversion/TensorReduceConverter.cpp new file mode 100644 index 000000000..8fcb1682d --- /dev/null +++ b/compiler/exo/src/Conversion/TensorReduceConverter.cpp @@ -0,0 +1,95 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#include "TensorReduceConverter.h"
+
+#include "Dialect/IR/TFLNodes.h"
+#include "Check.h"
+
+#include <oops/InternalExn.h>
+
+#include <loco.h>
+#include <loco/Service/ShapeInference.h>
+
+namespace
+{
+
+/**
+ * @brief Convert the given TensorReduce to TFLMean
+ *
+ * <Before>
+ *    In --- loco::TensorReduce --- Out(s)
+ *
+ * <After>
+ *    In -------- locoex::TFLMean --- Out(s)
+ *                  /
+ *    TFLConst ---
+ *    (reduction indices)
+ */
+bool convert_as_mean(loco::TensorReduce *origin)
+{
+  EXO_ASSERT(origin->func() == loco::ReduceFunc::Mean, "func should be Mean for this helper");
+  EXO_ASSERT(origin->input(), "TensorReduce has no input");
+
+  auto *graph = origin->graph();
+
+  // Make reduction indices TFLConst node
+  auto reduction = graph->nodes()->create<locoex::TFLConst>();
+  {
+    auto input_rank = loco::shape_get(origin->input()).as<loco::TensorShape>().rank();
+
+    std::vector<int32_t> red_vec;
+    for (uint32_t axis = 0; axis < input_rank; ++axis)
+      if (origin->axes()->defined(axis))
+        red_vec.push_back(static_cast<int32_t>(axis));
+
+    const loco::DataType S32 = loco::DataType::S32;
+
+    reduction->dtype(S32);
+    reduction->rank(1);
+    reduction->dim(0) = red_vec.size();
+    reduction->size<S32>(red_vec.size());
+    for (uint32_t i = 0; i < red_vec.size(); ++i)
+      reduction->at<S32>(i) = red_vec.at(i);
+  }
+
+  // Make TFLMean node to replace
+  auto mean = graph->nodes()->create<locoex::TFLMean>();
+  mean->input(origin->input());
+  mean->reduction_indices(reduction);
+  mean->keep_dims(true); // Canonical TensorReduce always keeps dimensions
+
+  // replace canonical node
+  loco::replace(origin).with(mean);
+  origin->input(nullptr);
+
+  return true;
+}
+
+} // namespace
+
+namespace exo
+{
+
+bool TensorReduceConverter::convert(loco::TensorReduce *origin)
+{
+  if (origin->func() == loco::ReduceFunc::Mean)
+    return convert_as_mean(origin);
+  else
+    INTERNAL_EXN_V("Unsupported ReduceFunc", oops::to_uint32(origin->func()));
+}
+
+} // namespace exo
diff --git a/compiler/exo/src/Conversion/TensorReduceConverter.h b/compiler/exo/src/Conversion/TensorReduceConverter.h
new file mode 100644
index 000000000..dfd65ad2d
--- /dev/null
+++ b/compiler/exo/src/Conversion/TensorReduceConverter.h
@@ -0,0 +1,46 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#ifndef __TENSOR_REDUCE_CONVERTER__ +#define __TENSOR_REDUCE_CONVERTER__ + +#include "CanonicalNodeConverter.h" + +#include <loco.h> + +namespace exo +{ + +/** + * @brief Convert loco::TensorReduce to appropriate TFL reduce operation + * @note loco::TensorReduce always keep dimensions + * + * Currently support: + * - When loco::TensorReduce::func() == Mean, convert to TFLMean + TFLConst + * - TODO Support other cases + */ +class TensorReduceConverter : public CanonicalNodeConverter<loco::TensorReduce> +{ +public: + const char *name(void) const final { return "exo::TensorReduceConverter"; } + +public: + bool convert(loco::TensorReduce *origin) final; +}; + +} // namespace exo + +#endif // __TENSOR_REDUCE_CONVERTER__ diff --git a/compiler/exo/src/Conversion/TensorTransposeConverter.cpp b/compiler/exo/src/Conversion/TensorTransposeConverter.cpp new file mode 100644 index 000000000..25c27fe7e --- /dev/null +++ b/compiler/exo/src/Conversion/TensorTransposeConverter.cpp @@ -0,0 +1,102 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "TensorTransposeConverter.h" + +#include "Dialect/IR/TFLNodes.h" + +#include <loco.h> +#include <loco/Service/ShapeInference.h> + +#include <oops/InternalExn.h> + +#include <algorithm> +#include <cassert> +#include <vector> + +namespace +{ + +void validate_perm(loco::TensorTranspose *origin) +{ + // check perm values are correct + std::vector<uint32_t> base_perms; // such as {0, 1, 2, 3, ... 
} + std::vector<uint32_t> perms; // perm values in TensorTranspose + + base_perms.resize(origin->perm()->size()); + perms.resize(origin->perm()->size()); + for (loco::TensorAxis x = 0; x < origin->perm()->size(); x++) + { + base_perms[x] = x; + perms[x] = origin->perm()->axis(x); + } + + if (!std::is_permutation(base_perms.begin(), base_perms.end(), perms.begin())) + INTERNAL_EXN("wrong perm value"); +} + +} // namespace + +namespace exo +{ +/** + * @brief Converts loco::TensorTranspose to locoex::TFLTranspose + */ +bool TensorTransposeConverter::convert(loco::TensorTranspose *origin) +{ + auto *graph = origin->graph(); + + auto tfl_transpose = graph->nodes()->create<locoex::TFLTranspose>(); + { + // validation + { + assert(origin->input() != nullptr); + + auto input_rank = loco::shape_get(origin->input()).as<loco::TensorShape>().rank(); + if (input_rank != origin->perm()->size()) + INTERNAL_EXN_V("perm size should be same with input rank", + oops::to_uint32(origin->perm()->size())); + + validate_perm(origin); + } + + tfl_transpose->a(origin->input()); + + // perm : set TFLConst + auto perm_const = graph->nodes()->create<locoex::TFLConst>(); + { + perm_const->dtype(loco::DataType::S32); + perm_const->rank(1); + perm_const->dim(0) = origin->perm()->size(); + perm_const->size<loco::DataType::S32>(origin->perm()->size()); + + // add perm values into perm TFLConst + for (loco::TensorAxis x = 0; x < origin->perm()->size(); x++) + { + perm_const->at<loco::DataType::S32>(x) = origin->perm()->axis(x); + } + } + tfl_transpose->perm(perm_const); + } + + // replace canonical node + loco::replace(origin).with(tfl_transpose); + origin->input(nullptr); + + return true; +} + +} // namespace exo diff --git a/compiler/exo/src/Conversion/TensorTransposeConverter.h b/compiler/exo/src/Conversion/TensorTransposeConverter.h new file mode 100644 index 000000000..9b61ff38d --- /dev/null +++ b/compiler/exo/src/Conversion/TensorTransposeConverter.h @@ -0,0 +1,41 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __CONVERSION_TENSORTRANSPOSE_CONVERTER__ +#define __CONVERSION_TENSORTRANSPOSE_CONVERTER__ + +#include "CanonicalNodeConverter.h" + +#include <loco.h> + +namespace exo +{ + +/** + * @brief Convert loco::TensorTranspose to locoex::TFLTranspose + */ +class TensorTransposeConverter : public CanonicalNodeConverter<loco::TensorTranspose> +{ +public: + const char *name(void) const final { return "exo::TensorTransposeConverter"; } + +public: + bool convert(loco::TensorTranspose *origin) final; +}; + +} // namespace exo + +#endif // __CONVERSION_TENSORTRANSPOSE_CONVERTER__ diff --git a/compiler/exo/src/Conversion/TransposedConv2DConverter.cpp b/compiler/exo/src/Conversion/TransposedConv2DConverter.cpp new file mode 100644 index 000000000..c03b64f48 --- /dev/null +++ b/compiler/exo/src/Conversion/TransposedConv2DConverter.cpp @@ -0,0 +1,92 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. 
All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "TransposedConv2DConverter.h" + +#include "Dialect/IR/TFLNodes.h" + +#include "GraphBlock.h" + +#include <loco.h> +#include <loco/Service/ShapeInference.h> + +namespace exo +{ + +bool TransposedConv2DConverter::convert(loco::TransposedConv2D *origin) +{ + // Shape is required to set origin->inputSizes() + if (not loco::shape_known(origin)) + return false; + + if ((origin->ifm() == nullptr) or (origin->ker() == nullptr)) + return false; + + auto *graph = origin->graph(); + + auto tfl_tr_conv = graph->nodes()->create<locoex::TFLTransposeConv>(); + { + tfl_tr_conv->stride()->w(origin->stride()->horizontal()); + tfl_tr_conv->stride()->h(origin->stride()->vertical()); + + auto pad = origin->pad(); + if (pad->bottom() == 0 && pad->top() == 0 && pad->left() == 0 && pad->right() == 0) + tfl_tr_conv->padding(locoex::Padding::VALID); + else + // TODO This is necessary, but not sufficient condition. More rigorous check required + tfl_tr_conv->padding(locoex::Padding::SAME); + } + + // let's create a new graph connection with tfl_tr_conv + { + // Make inputSizes from shape of origin + auto input_sizes_const = graph->nodes()->create<locoex::TFLConst>(); + auto origin_shape = loco::shape_get(origin).as<loco::FeatureShape>(); + + const loco::DataType S32 = loco::DataType::S32; + + input_sizes_const->dtype(S32); + input_sizes_const->rank(1); + input_sizes_const->dim(0) = 4; + input_sizes_const->size<S32>(4); + // Note that NHWC is layout for inputSizes determined by tflite format + input_sizes_const->at<S32>(0) = origin_shape.count().value(); // N + input_sizes_const->at<S32>(1) = origin_shape.height().value(); // H + input_sizes_const->at<S32>(2) = origin_shape.width().value(); // W + input_sizes_const->at<S32>(3) = origin_shape.depth().value(); // C + + tfl_tr_conv->inputSizes(input_sizes_const); + + // filter + auto filter_dec = make_filter_decode<FilterLayout::OHWI>(origin->ker()); + tfl_tr_conv->filter(filter_dec); + + // outBackprop + auto feature_dec = make_feature_decode<FeatureLayout::NHWC>(origin->ifm()); + tfl_tr_conv->outBackprop(feature_dec); + + // output + auto feature_enc = make_feature_encode<FeatureLayout::NHWC>(tfl_tr_conv); + + // replace canonical node + loco::replace(origin).with(feature_enc); + origin->ifm(nullptr); + } + + return true; +} + +} // namespace exo diff --git a/compiler/exo/src/Conversion/TransposedConv2DConverter.h b/compiler/exo/src/Conversion/TransposedConv2DConverter.h new file mode 100644 index 000000000..f51e0a5bc --- /dev/null +++ b/compiler/exo/src/Conversion/TransposedConv2DConverter.h @@ -0,0 +1,62 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __CONVERSION_TRANSPOSEDCONV2D_CONVERTER__ +#define __CONVERSION_TRANSPOSEDCONV2D_CONVERTER__ + +#include "CanonicalNodeConverter.h" + +#include <loco.h> + +namespace exo +{ + +/** + * @brief Convert loco::TransposedConv2D to locoex::TFLTransposeConv and auxiliary + * + * + * <BEFORE> + * + * IFM ------- TransposedConv2D --- OFM + * (Feature) / (Feature) + * / + * KER ------ + * (Filter) + * + * + * <AFTER> + * + * out_backprop : IFM ------- FeatureDecode --- TFLTransposeConv --- FeatureEncode --- OFM + * [Feature] [Tensor] / / [Tensor] [Feature] + * / / + * filter: KER ------- FilterDecode --- / + * [Filter] [Tensor] / + * / + * input_sizes : TFLConst (new) ------------ + * [Tensor] + */ +class TransposedConv2DConverter : public CanonicalNodeConverter<loco::TransposedConv2D> +{ +public: + const char *name(void) const final { return "exo::TransposedConv2DConverter"; } + +public: + bool convert(loco::TransposedConv2D *origin) final; +}; + +} // namespace exo + +#endif // __CONVERSION_TRANSPOSEDCONV2D_CONVERTER__ diff --git a/compiler/exo/src/Conversions.h b/compiler/exo/src/Conversions.h new file mode 100644 index 000000000..8eb4ed2e4 --- /dev/null +++ b/compiler/exo/src/Conversions.h @@ -0,0 +1,46 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#ifndef __CONVERSIONS_H__
+#define __CONVERSIONS_H__
+
+#include "Conversion/AvgPool2DConverter.h"
+#include "Conversion/ConstGenConverter.h"
+#include "Conversion/Conv2DConverter.h"
+#include "Conversion/DepthwiseConv2DConverter.h"
+// TODO loco::DepthwiseFilterEncode
+#include "Conversion/EltwiseAddConverter.h"
+#include "Conversion/EltwiseDivConverter.h"
+#include "Conversion/EltwiseMaxConverter.h"
+#include "Conversion/EltwiseMulConverter.h"
+#include "Conversion/EltwiseSqrtConverter.h"
+#include "Conversion/EltwiseSubConverter.h"
+#include "Conversion/FeatureBiasAddConverter.h"
+// TODO loco::FixedReshape
+#include "Conversion/MatMulConverter.h"
+#include "Conversion/MaxPool2DConverter.h"
+#include "Conversion/ReluConverter.h"
+#include "Conversion/Relu6Converter.h"
+// TODO loco::Tanh
+#include "Conversion/TensorConcatConverter.h"
+// TODO loco::TensorBiasAdd
+#include "Conversion/TensorBroadcastConverter.h"
+#include "Conversion/TensorReduceConverter.h"
+// TODO loco::TensorSoftmax
+#include "Conversion/TensorTransposeConverter.h"
+#include "Conversion/TransposedConv2DConverter.h"
+
+#endif // __CONVERSIONS_H__
diff --git a/compiler/exo/src/Convert.cpp b/compiler/exo/src/Convert.cpp
new file mode 100644
index 000000000..45f0481f4
--- /dev/null
+++ b/compiler/exo/src/Convert.cpp
@@ -0,0 +1,97 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Convert.h"
+
+#include "Conversions.h"
+#include "Pass/ShapeInferencePass.h"
+#include "Pass/TypeInferencePass.h"
+#include "ProgressReporter.h"
+#include "Knob.h"
+
+#include <loco.h>
+#include <loco/Service/ShapeInference.h>
+#include <loco/Service/CanonicalShapeInferenceRule.h>
+#include <loco/Service/TypeInference.h>
+
+#include <logo/SimplifyDomainConversionPass.h>
+#include <logo/RemoveDeadNodePass.h>
+#include <logo/RemoveForwardNodePass.h>
+
+#include <logo/Phase.h>
+#include <stdex/Memory.h>
+
+namespace exo
+{
+
+void convert_to_TFLNodes(loco::Graph *graph)
+{
+  // Shape and Type inference must be run before conversion
+  loco::CanonicalShapeInferenceRule shape_rule;
+  loco::apply(&shape_rule).to(graph);
+
+  loco::CanonicalTypeInferenceRule type_rule;
+  loco::apply(&type_rule).to(graph);
+
+  logo::Phase phase;
+  {
+    // prepare type and shape before conversion
+    phase.emplace_back(stdex::make_unique<TypeInferencePass>());
+    phase.emplace_back(stdex::make_unique<ShapeInferencePass>());
+
+    // Add converters for canonical nodes. Note: Not all loco canonical nodes are listed.
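+    //
+    // Each converter below rewrites one canonical loco node into its TFL
+    // dialect counterpart. As a sketch (the real ReluConverter body lives in
+    // Conversion/ReluConverter.cpp, which is not shown here), a converter
+    // typically does:
+    //
+    //   bool ReluConverter::convert(loco::ReLU *origin)
+    //   {
+    //     auto tfl_relu = origin->graph()->nodes()->create<locoex::TFLRelu>();
+    //     tfl_relu->features(origin->input());
+    //     loco::replace(origin).with(tfl_relu);
+    //     return true;
+    //   }
+    //
+    // With PhaseStrategy::Restart (used below), the whole phase reruns
+    // whenever any pass changes the graph, so conversion runs to a fixed point.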
+ phase.emplace_back(stdex::make_unique<AvgPool2DConverter>()); + phase.emplace_back(stdex::make_unique<ConstGenConverter>()); + phase.emplace_back(stdex::make_unique<Conv2DConverter>()); + phase.emplace_back(stdex::make_unique<DepthwiseConv2DConverter>()); + // TODO loco::DepthwiseFilterEncode + phase.emplace_back(stdex::make_unique<EltwiseAddConverter>()); + phase.emplace_back(stdex::make_unique<EltwiseDivConverter>()); + phase.emplace_back(stdex::make_unique<EltwiseMaxConverter>()); + phase.emplace_back(stdex::make_unique<EltwiseMulConverter>()); + phase.emplace_back(stdex::make_unique<EltwiseSqrtConverter>()); + phase.emplace_back(stdex::make_unique<EltwiseSubConverter>()); + phase.emplace_back(stdex::make_unique<FeatureBiasAddConverter>()); + // TODO loco::FixedReshape + phase.emplace_back(stdex::make_unique<MatMulConverter>()); + phase.emplace_back(stdex::make_unique<MaxPool2DConverter>()); + phase.emplace_back(stdex::make_unique<ReluConverter>()); + phase.emplace_back(stdex::make_unique<Relu6Converter>()); + // TODO loco::Tanh + phase.emplace_back(stdex::make_unique<TensorConcatConverter>()); + // TODO loco::TensorBiasAdd + phase.emplace_back(stdex::make_unique<TensorBroadcastConverter>()); + phase.emplace_back(stdex::make_unique<TensorReduceConverter>()); + // TODO loco::TensorSoftmax + phase.emplace_back(stdex::make_unique<TensorTransposeConverter>()); + phase.emplace_back(stdex::make_unique<TransposedConv2DConverter>()); + + // Add optimization below + phase.emplace_back(stdex::make_unique<logo::SimplifyDomainConversionPass>()); + phase.emplace_back(stdex::make_unique<logo::RemoveForwardNodePass>()); + phase.emplace_back(stdex::make_unique<logo::RemoveDeadNodePass>()); + } + + logo::PhaseRunner<logo::PhaseStrategy::Restart> phase_runner{graph}; + + ProgressReporter prog(graph, logo::PhaseStrategy::Restart); + phase_runner.attach(&prog); + phase_runner.run(phase); + + // TODO Assert if all canonical nodes are converted to TFL node +} + +} // namespace exo diff --git a/compiler/exo/src/Convert.h b/compiler/exo/src/Convert.h new file mode 100644 index 000000000..7038f9cf7 --- /dev/null +++ b/compiler/exo/src/Convert.h @@ -0,0 +1,29 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __CONVERT_H__ +#define __CONVERT_H__ + +#include <loco.h> + +namespace exo +{ + +void convert_to_TFLNodes(loco::Graph *graph); + +} // namespace exo + +#endif // __CONVERT_H__ diff --git a/compiler/exo/src/Dialect/IR/CircleDialect.cpp b/compiler/exo/src/Dialect/IR/CircleDialect.cpp new file mode 100644 index 000000000..ecd43b0a3 --- /dev/null +++ b/compiler/exo/src/Dialect/IR/CircleDialect.cpp @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleDialect.h" + +namespace locoex +{ + +loco::Dialect *CircleDialect::get(void) +{ + static CircleDialect d; + return &d; +} + +} // namespace locoex diff --git a/compiler/exo/src/Dialect/IR/CircleDialect.h b/compiler/exo/src/Dialect/IR/CircleDialect.h new file mode 100644 index 000000000..9857d9e6d --- /dev/null +++ b/compiler/exo/src/Dialect/IR/CircleDialect.h @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LOCOEX_IR_CIRCLEDIALECT_H__ +#define __LOCOEX_IR_CIRCLEDIALECT_H__ + +#include <loco/IR/Dialect.h> + +namespace locoex +{ + +class CircleDialect final : public loco::Dialect +{ +private: + CircleDialect() = default; + +public: + CircleDialect(const CircleDialect &) = delete; + CircleDialect(CircleDialect &&) = delete; + +public: + static loco::Dialect *get(void); +}; + +} // namespace locoex + +#endif // __LOCOEX_IR_CIRCLEDIALECT_H__ diff --git a/compiler/exo/src/Dialect/IR/CircleDialect.test.cpp b/compiler/exo/src/Dialect/IR/CircleDialect.test.cpp new file mode 100644 index 000000000..6132eb361 --- /dev/null +++ b/compiler/exo/src/Dialect/IR/CircleDialect.test.cpp @@ -0,0 +1,31 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleDialect.h" + +#include <gtest/gtest.h> + +TEST(CircleDialectTest, get) +{ + using locoex::CircleDialect; + + auto d = CircleDialect::get(); + + // get() SHOULD return a valid(non-null) pointer + ASSERT_NE(d, nullptr); + // The return value SHOULD be stable across multiple invocations + ASSERT_EQ(d, CircleDialect::get()); +} diff --git a/compiler/exo/src/Dialect/IR/CircleNode.cpp b/compiler/exo/src/Dialect/IR/CircleNode.cpp new file mode 100644 index 000000000..cdcd434ea --- /dev/null +++ b/compiler/exo/src/Dialect/IR/CircleNode.cpp @@ -0,0 +1,26 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. 
All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleNode.h" + +#include "CircleDialect.h" + +namespace locoex +{ + +const loco::Dialect *CircleNode::dialect(void) const { return CircleDialect::get(); } + +} // namespace locoex diff --git a/compiler/exo/src/Dialect/IR/CircleNode.h b/compiler/exo/src/Dialect/IR/CircleNode.h new file mode 100644 index 000000000..1ae9d38bd --- /dev/null +++ b/compiler/exo/src/Dialect/IR/CircleNode.h @@ -0,0 +1,23 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LOCOEX_IR_CIRCLENODE_H__ +#define __LOCOEX_IR_CIRCLENODE_H__ + +#include "CircleNodeDecl.h" +#include "CircleNodeImpl.h" + +#endif // __LOCOEX_IR_CIRCLENODE_H__ diff --git a/compiler/exo/src/Dialect/IR/CircleNodeDecl.h b/compiler/exo/src/Dialect/IR/CircleNodeDecl.h new file mode 100644 index 000000000..358b1f0ce --- /dev/null +++ b/compiler/exo/src/Dialect/IR/CircleNodeDecl.h @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef __LOCOEX_IR_CIRCLENODEDECL_H__ +#define __LOCOEX_IR_CIRCLENODEDECL_H__ + +#include <loco/IR/Node.h> +#include <loco/IR/Dialect.h> + +#include "CircleOpcode.h" +#include "CircleNodeVisitor.forward.h" + +namespace locoex +{ + +struct CircleNode : public loco::Node +{ + virtual ~CircleNode() = default; + + const loco::Dialect *dialect(void) const final; + virtual CircleOpcode opcode(void) const = 0; + + template <typename T> T accept(CircleNodeVisitorBase<T> *) const; + template <typename T> T accept(CircleNodeMutableVisitorBase<T> *); +}; + +template <CircleOpcode Code> struct CircleNodeImpl : public CircleNode +{ + virtual ~CircleNodeImpl() = default; + + uint32_t opnum(void) const final { return static_cast<uint32_t>(Code); } + CircleOpcode opcode(void) const final { return Code; } +}; + +} // namespace locoex + +#endif // __LOCOEX_IR_CIRCLENODEDECL_H__ diff --git a/compiler/exo/src/Dialect/IR/CircleNodeImpl.h b/compiler/exo/src/Dialect/IR/CircleNodeImpl.h new file mode 100644 index 000000000..d9f487111 --- /dev/null +++ b/compiler/exo/src/Dialect/IR/CircleNodeImpl.h @@ -0,0 +1,70 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LOCOEX_IR_CIRCLENODEIMPL_H__ +#define __LOCOEX_IR_CIRCLENODEIMPL_H__ + +#include "CircleNodes.h" +#include "CircleNodeVisitor.h" + +#include <oops/InternalExn.h> + +#include <cassert> + +namespace locoex +{ + +template <typename T> T CircleNode::accept(CircleNodeVisitorBase<T> *v) const +{ + switch (this->opcode()) + { +#define CIRCLE_NODE(OPCODE, CLASS) \ + \ + case CircleOpcode::OPCODE: \ + return v->visit(dynamic_cast<const CLASS *>(this)); + +#include "CircleNodes.lst" +#undef CIRCLE_NODE + + default: + break; + } + + INTERNAL_EXN("CircleNode::accept(CircleNodeVisitorBase) not handled"); +} + +template <typename T> T CircleNode::accept(CircleNodeMutableVisitorBase<T> *v) +{ + switch (this->opcode()) + { +#define CIRCLE_NODE(OPCODE, CLASS) \ + \ + case CircleOpcode::OPCODE: \ + return v->visit(dynamic_cast<CLASS *>(this)); + +#include "CircleNodes.lst" +#undef CIRCLE_NODE + + default: + break; + } + + INTERNAL_EXN("CircleNode::accept(CircleNodeMutableVisitorBase) not handled"); +} + +} // namespace locoex + +#endif // __LOCOEX_IR_CIRCLENODEIMPL_H__ diff --git a/compiler/exo/src/Dialect/IR/CircleNodeVisitor.forward.h b/compiler/exo/src/Dialect/IR/CircleNodeVisitor.forward.h new file mode 100644 index 000000000..8ae28abf3 --- /dev/null +++ b/compiler/exo/src/Dialect/IR/CircleNodeVisitor.forward.h @@ -0,0 +1,30 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LOCOEX_IR_CIRCLENODE_VISITOR_FORWARD_H__
+#define __LOCOEX_IR_CIRCLENODE_VISITOR_FORWARD_H__
+
+namespace locoex
+{
+
+// NOTE These forward declarations SHOULD BE aligned with Node declarations in
+//      "CircleNodeVisitor.h"
+template <typename T> struct CircleNodeVisitorBase;
+template <typename T> struct CircleNodeMutableVisitorBase;
+
+} // namespace locoex
+
+#endif // __LOCOEX_IR_CIRCLENODE_VISITOR_FORWARD_H__
diff --git a/compiler/exo/src/Dialect/IR/CircleNodeVisitor.h b/compiler/exo/src/Dialect/IR/CircleNodeVisitor.h
new file mode 100644
index 000000000..fc70c9ebc
--- /dev/null
+++ b/compiler/exo/src/Dialect/IR/CircleNodeVisitor.h
@@ -0,0 +1,86 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LOCOEX_IR_CIRCLENODE_VISITOR_H__
+#define __LOCOEX_IR_CIRCLENODE_VISITOR_H__
+
+#include "CircleNode.h"
+#include "CircleNodes.h"
+
+#include <oops/InternalExn.h>
+
+namespace locoex
+{
+
+/**
+ * DO NOT use this class. Use CircleNodeVisitor instead.
+ */
+template <typename T> struct CircleNodeVisitorBase
+{
+  virtual ~CircleNodeVisitorBase() = default;
+
+#define CIRCLE_NODE(OPCODE, Circle_CLASS) virtual T visit(const Circle_CLASS *) = 0;
+
+#include "CircleNodes.lst"
+#undef CIRCLE_NODE
+};
+
+template <typename T> struct CircleNodeVisitor : public CircleNodeVisitorBase<T>
+{
+  virtual ~CircleNodeVisitor() = default;
+
+#define CIRCLE_NODE(OPCODE, Circle_CLASS) \
+  \
+  virtual T visit(const Circle_CLASS *node) { return visit(static_cast<const CircleNode *>(node)); }
+
+#include "CircleNodes.lst"
+#undef CIRCLE_NODE
+
+  /// @brief Default fallback
+  virtual T visit(const CircleNode *) { INTERNAL_EXN("CircleNodeVisitor: NYI node"); }
+};
+
+/**
+ * DO NOT use this class. Use CircleNodeMutableVisitor instead.
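+ * (Implement CircleNodeMutableVisitor<T> when the visitor has to rewrite the
+ * node it visits; it receives non-const pointers, whereas CircleNodeVisitor<T>
+ * above receives const pointers and suits read-only analyses.)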
+ */
+template <typename T> struct CircleNodeMutableVisitorBase
+{
+  virtual ~CircleNodeMutableVisitorBase() = default;
+
+#define CIRCLE_NODE(OPCODE, Circle_CLASS) virtual T visit(Circle_CLASS *) = 0;
+
+#include "CircleNodes.lst"
+#undef CIRCLE_NODE
+};
+
+template <typename T> struct CircleNodeMutableVisitor : public CircleNodeMutableVisitorBase<T>
+{
+  virtual ~CircleNodeMutableVisitor() = default;
+
+#define CIRCLE_NODE(OPCODE, Circle_CLASS) \
+  \
+  virtual T visit(Circle_CLASS *node) { return visit(static_cast<CircleNode *>(node)); }
+
+#include "CircleNodes.lst"
+#undef CIRCLE_NODE
+
+  /// @brief Default fallback
+  virtual T visit(CircleNode *) { INTERNAL_EXN("CircleNodeMutableVisitor: NYI node"); }
+};
+
+} // namespace locoex
+
+#endif // __LOCOEX_IR_CIRCLENODE_VISITOR_H__
diff --git a/compiler/exo/src/Dialect/IR/CircleNodes.cpp b/compiler/exo/src/Dialect/IR/CircleNodes.cpp
new file mode 100644
index 000000000..bba59ff4d
--- /dev/null
+++ b/compiler/exo/src/Dialect/IR/CircleNodes.cpp
@@ -0,0 +1,18 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// This is to validate CircleNodes.h
+#include "CircleNodes.h"
diff --git a/compiler/exo/src/Dialect/IR/CircleNodes.h b/compiler/exo/src/Dialect/IR/CircleNodes.h
new file mode 100644
index 000000000..7be093103
--- /dev/null
+++ b/compiler/exo/src/Dialect/IR/CircleNodes.h
@@ -0,0 +1,79 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#ifndef __LOCOEX_IR_CIRCLENODES_H__ +#define __LOCOEX_IR_CIRCLENODES_H__ + +#include "CircleNodeDecl.h" +#include "CircleOpcode.h" + +#include "FusedActFunc.h" +#include "NodeMixins.h" // FixedArityNode + +#include <loco/IR/Node.h> + +namespace locoex +{ + +/// @brief enumeration of mixin class +enum class CircleNodeTrait +{ + FusedActFunc, +}; + +template <CircleNodeTrait T> class CircleNodeMixin; + +template <> class CircleNodeMixin<CircleNodeTrait::FusedActFunc> +{ +public: + CircleNodeMixin() = default; + +public: + FusedActFunc fusedActivationFunction() const { return _fused_act_fun; } + void fusedActivationFunction(FusedActFunc fused_act_fun) { _fused_act_fun = fused_act_fun; } + +private: + FusedActFunc _fused_act_fun = FusedActFunc::UNDEFINED; +}; + +/** + * @brief INSTANCE_NORM in circle + */ +class CircleInstanceNorm final + : public FixedArityNode<3, CircleNodeImpl<CircleOpcode::INSTANCE_NORM>>, + public CircleNodeMixin<CircleNodeTrait::FusedActFunc> +{ +public: + /// @note Currently only support FLOAT32 as input node + loco::Node *input(void) const { return at(0)->node(); } + void input(loco::Node *node) { at(0)->node(node); } + + loco::Node *gamma(void) const { return at(1)->node(); } + void gamma(loco::Node *node) { at(1)->node(node); } + + loco::Node *beta(void) const { return at(2)->node(); } + void beta(loco::Node *node) { at(2)->node(node); } + + float epsilon() const { return _epsilon; } + void epsilon(float epsilon) { _epsilon = epsilon; } + +private: + float _epsilon = 1e-05; +}; + +} // namespace locoex + +#endif // __LOCOEX_IR_CIRCLENODES_H__ diff --git a/compiler/exo/src/Dialect/IR/CircleNodes.lst b/compiler/exo/src/Dialect/IR/CircleNodes.lst new file mode 100644 index 000000000..96baf2917 --- /dev/null +++ b/compiler/exo/src/Dialect/IR/CircleNodes.lst @@ -0,0 +1,8 @@ +#ifndef CIRCLE_NODE +#error "Define CIRCLE_NODE" +#endif // CIRCLE_NODE + +// +// PLEASE SORT NODE DECLS IN ALPHABETICAL ORDER +// +CIRCLE_NODE(INSTANCE_NORM, locoex::CircleInstanceNorm) diff --git a/compiler/exo/src/Dialect/IR/CircleNodes.test.cpp b/compiler/exo/src/Dialect/IR/CircleNodes.test.cpp new file mode 100644 index 000000000..b63e7ccae --- /dev/null +++ b/compiler/exo/src/Dialect/IR/CircleNodes.test.cpp @@ -0,0 +1,36 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "CircleNodes.h" + +#include "CircleDialect.h" +#include "CircleOpcode.h" + +#include <gtest/gtest.h> + +TEST(CircleInstanceNormTest, constructor) +{ + locoex::CircleInstanceNorm instance_norm; + + ASSERT_EQ(instance_norm.dialect(), locoex::CircleDialect::get()); + ASSERT_EQ(instance_norm.opcode(), locoex::CircleOpcode::INSTANCE_NORM); + + ASSERT_EQ(instance_norm.input(), nullptr); + ASSERT_EQ(instance_norm.gamma(), nullptr); + ASSERT_EQ(instance_norm.beta(), nullptr); + ASSERT_FLOAT_EQ(instance_norm.epsilon(), 1e-05); + ASSERT_EQ(instance_norm.fusedActivationFunction(), locoex::FusedActFunc::UNDEFINED); +} diff --git a/compiler/exo/src/Dialect/IR/CircleOpcode.h b/compiler/exo/src/Dialect/IR/CircleOpcode.h new file mode 100644 index 000000000..264304049 --- /dev/null +++ b/compiler/exo/src/Dialect/IR/CircleOpcode.h @@ -0,0 +1,32 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LOCOEX_IR_CIRCLEOPCODE_H__ +#define __LOCOEX_IR_CIRCLEOPCODE_H__ + +namespace locoex +{ + +enum class CircleOpcode +{ +#define CIRCLE_NODE(OPCODE, CLASS) OPCODE, +#include "CircleNodes.lst" +#undef CIRCLE_NODE +}; + +} // namespace locoex + +#endif // __LOCOEX_IR_CIRCLEOPCODE_H__ diff --git a/compiler/exo/src/Dialect/IR/FusedActFunc.h b/compiler/exo/src/Dialect/IR/FusedActFunc.h new file mode 100644 index 000000000..b73a0799e --- /dev/null +++ b/compiler/exo/src/Dialect/IR/FusedActFunc.h @@ -0,0 +1,35 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __DIALECT_IR_FUSEDACTFUNC_H__ +#define __DIALECT_IR_FUSEDACTFUNC_H__ + +namespace locoex +{ + +// TODO Divide into TFL version and Circle version when they go different approach +enum class FusedActFunc +{ + UNDEFINED, // This is not defined by TFLite or Circle. This was added to + // prevent programming error. + NONE, + RELU, + RELU6 +}; + +} // namespace locoex + +#endif // __DIALECT_IR_FUSEDACTFUNC_H__ diff --git a/compiler/exo/src/Dialect/IR/NodeMixins.cpp b/compiler/exo/src/Dialect/IR/NodeMixins.cpp new file mode 100644 index 000000000..cdfe0d8d1 --- /dev/null +++ b/compiler/exo/src/Dialect/IR/NodeMixins.cpp @@ -0,0 +1,18 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// This is to validate NodeMixins.h
+#include "NodeMixins.h"
diff --git a/compiler/exo/src/Dialect/IR/NodeMixins.h b/compiler/exo/src/Dialect/IR/NodeMixins.h
new file mode 100644
index 000000000..c35daebc6
--- /dev/null
+++ b/compiler/exo/src/Dialect/IR/NodeMixins.h
@@ -0,0 +1,66 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __DIALECT_IR_NODEMIXINS_H__
+#define __DIALECT_IR_NODEMIXINS_H__
+
+#include <loco/IR/Node.h>
+
+#include <array>  // for std::array used by _args below
+#include <memory> // for std::unique_ptr
+
+namespace locoex
+{
+
+/**
+ * @brief Nodes with a fixed number of inputs
+ *
+ * TODO Deprecate this class and use loco::FixedArity instead
+ */
+template <unsigned N, typename Base> class FixedArityNode : public Base
+{
+public:
+  FixedArityNode()
+  {
+    for (uint32_t n = 0; n < N; ++n)
+    {
+      _args[n] = std::unique_ptr<loco::Use>(new loco::Use{this});
+    }
+  }
+
+  virtual ~FixedArityNode() = default;
+
+public:
+  unsigned arity(void) const final { return N; }
+
+  loco::Node *arg(uint32_t n) const final { return _args.at(n)->node(); }
+
+  void drop(void) final
+  {
+    for (uint32_t n = 0; n < N; ++n)
+    {
+      _args.at(n)->node(nullptr);
+    }
+  }
+
+protected:
+  // This API allows inherited classes to access "_args" field.
+  loco::Use *at(unsigned n) const { return _args.at(n).get(); }
+
+private:
+  std::array<std::unique_ptr<loco::Use>, N> _args;
+};
+
+} // namespace locoex
+
+#endif // __DIALECT_IR_NODEMIXINS_H__
diff --git a/compiler/exo/src/Dialect/IR/TFLDialect.cpp b/compiler/exo/src/Dialect/IR/TFLDialect.cpp
new file mode 100644
index 000000000..8cbf9a364
--- /dev/null
+++ b/compiler/exo/src/Dialect/IR/TFLDialect.cpp
@@ -0,0 +1,28 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#include "TFLDialect.h" + +namespace locoex +{ + +loco::Dialect *TFLDialect::get(void) +{ + static TFLDialect d; + return &d; +} + +} // namespace locoex diff --git a/compiler/exo/src/Dialect/IR/TFLDialect.h b/compiler/exo/src/Dialect/IR/TFLDialect.h new file mode 100644 index 000000000..96463a9f9 --- /dev/null +++ b/compiler/exo/src/Dialect/IR/TFLDialect.h @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LOCOEX_IR_TFLDIALECT_H__ +#define __LOCOEX_IR_TFLDIALECT_H__ + +#include <loco/IR/Dialect.h> + +namespace locoex +{ + +class TFLDialect final : public loco::Dialect +{ +private: + TFLDialect() = default; + +public: + TFLDialect(const TFLDialect &) = delete; + TFLDialect(TFLDialect &&) = delete; + +public: + static loco::Dialect *get(void); +}; + +} // namespace locoex + +#endif // __LOCOEX_IR_TFLDIALECT_H__ diff --git a/compiler/exo/src/Dialect/IR/TFLDialect.test.cpp b/compiler/exo/src/Dialect/IR/TFLDialect.test.cpp new file mode 100644 index 000000000..136721e2d --- /dev/null +++ b/compiler/exo/src/Dialect/IR/TFLDialect.test.cpp @@ -0,0 +1,31 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "TFLDialect.h" + +#include <gtest/gtest.h> + +TEST(TFLDialectTest, get) +{ + using locoex::TFLDialect; + + auto d = TFLDialect::get(); + + // get() SHOULD return a valid(non-null) pointer + ASSERT_NE(d, nullptr); + // The return value SHOULD be stable across multiple invocations + ASSERT_EQ(d, TFLDialect::get()); +} diff --git a/compiler/exo/src/Dialect/IR/TFLNode.cpp b/compiler/exo/src/Dialect/IR/TFLNode.cpp new file mode 100644 index 000000000..82d5f1eba --- /dev/null +++ b/compiler/exo/src/Dialect/IR/TFLNode.cpp @@ -0,0 +1,26 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "TFLNode.h" + +#include "TFLDialect.h" + +namespace locoex +{ + +const loco::Dialect *TFLNode::dialect(void) const { return TFLDialect::get(); } + +} // namespace locoex diff --git a/compiler/exo/src/Dialect/IR/TFLNode.h b/compiler/exo/src/Dialect/IR/TFLNode.h new file mode 100644 index 000000000..eff69b1a5 --- /dev/null +++ b/compiler/exo/src/Dialect/IR/TFLNode.h @@ -0,0 +1,23 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LOCOEX_IR_TFLNODE_H__ +#define __LOCOEX_IR_TFLNODE_H__ + +#include "TFLNodeDecl.h" +#include "TFLNodeImpl.h" + +#endif // __LOCOEX_IR_TFLNODE_H__ diff --git a/compiler/exo/src/Dialect/IR/TFLNodeDecl.h b/compiler/exo/src/Dialect/IR/TFLNodeDecl.h new file mode 100644 index 000000000..d13900ab3 --- /dev/null +++ b/compiler/exo/src/Dialect/IR/TFLNodeDecl.h @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LOCOEX_IR_TFLNODEDECL_H__ +#define __LOCOEX_IR_TFLNODEDECL_H__ + +#include <loco/IR/Node.h> +#include <loco/IR/Dialect.h> + +#include "TFLOpcode.h" +#include "TFLNodeVisitor.forward.h" + +namespace locoex +{ + +struct TFLNode : public loco::Node +{ + virtual ~TFLNode() = default; + + const loco::Dialect *dialect(void) const final; + virtual TFLOpcode opcode(void) const = 0; + + template <typename T> T accept(TFLNodeVisitorBase<T> *) const; + template <typename T> T accept(TFLNodeMutableVisitorBase<T> *); +}; + +template <TFLOpcode Code> struct TFLNodeImpl : public TFLNode +{ + virtual ~TFLNodeImpl() = default; + + uint32_t opnum(void) const final { return static_cast<uint32_t>(Code); } + TFLOpcode opcode(void) const final { return Code; } +}; + +} // namespace locoex + +#endif // __LOCOEX_IR_TFLNODEDECL_H__ diff --git a/compiler/exo/src/Dialect/IR/TFLNodeImpl.h b/compiler/exo/src/Dialect/IR/TFLNodeImpl.h new file mode 100644 index 000000000..63388279a --- /dev/null +++ b/compiler/exo/src/Dialect/IR/TFLNodeImpl.h @@ -0,0 +1,70 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LOCOEX_IR_TFLNODEIMPL_H__
+#define __LOCOEX_IR_TFLNODEIMPL_H__
+
+#include "TFLNodes.h"
+#include "TFLNodeVisitor.h"
+
+#include <oops/InternalExn.h>
+
+#include <cassert>
+
+namespace locoex
+{
+
+template <typename T> T TFLNode::accept(TFLNodeVisitorBase<T> *v) const
+{
+  switch (this->opcode())
+  {
+#define TFL_NODE(OPCODE, CLASS) \
+  \
+  case TFLOpcode::OPCODE:       \
+    return v->visit(dynamic_cast<const CLASS *>(this));
+
+#include "TFLNodes.lst"
+#undef TFL_NODE
+
+    default:
+      break;
+  }
+
+  INTERNAL_EXN("TFLNode::accept(TFLNodeVisitorBase) not handled");
+}
+
+template <typename T> T TFLNode::accept(TFLNodeMutableVisitorBase<T> *v)
+{
+  switch (this->opcode())
+  {
+#define TFL_NODE(OPCODE, CLASS) \
+  \
+  case TFLOpcode::OPCODE:       \
+    return v->visit(dynamic_cast<CLASS *>(this));
+
+#include "TFLNodes.lst"
+#undef TFL_NODE
+
+    default:
+      break;
+  }
+
+  INTERNAL_EXN("TFLNode::accept(TFLNodeMutableVisitorBase) not handled");
+}
+
+} // namespace locoex
+
+#endif // __LOCOEX_IR_TFLNODEIMPL_H__
diff --git a/compiler/exo/src/Dialect/IR/TFLNodeVisitor.forward.h b/compiler/exo/src/Dialect/IR/TFLNodeVisitor.forward.h
new file mode 100644
index 000000000..e98057bc3
--- /dev/null
+++ b/compiler/exo/src/Dialect/IR/TFLNodeVisitor.forward.h
@@ -0,0 +1,30 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LOCOEX_IR_TFLNODE_VISITOR_FORWARD_H__
+#define __LOCOEX_IR_TFLNODE_VISITOR_FORWARD_H__
+
+namespace locoex
+{
+
+// NOTE These forward declarations SHOULD BE aligned with Node declarations in
+//      "TFLNodeVisitor.h"
+template <typename T> struct TFLNodeVisitorBase;
+template <typename T> struct TFLNodeMutableVisitorBase;
+
+} // namespace locoex
+
+#endif // __LOCOEX_IR_TFLNODE_VISITOR_FORWARD_H__
diff --git a/compiler/exo/src/Dialect/IR/TFLNodeVisitor.h b/compiler/exo/src/Dialect/IR/TFLNodeVisitor.h
new file mode 100644
index 000000000..e1f5959c0
--- /dev/null
+++ b/compiler/exo/src/Dialect/IR/TFLNodeVisitor.h
@@ -0,0 +1,86 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LOCOEX_IR_TFLNODE_VISITOR_H__ +#define __LOCOEX_IR_TFLNODE_VISITOR_H__ + +#include "TFLNode.h" +#include "TFLNodes.h" + +#include <oops/InternalExn.h> + +namespace locoex +{ + +/** + * DO NOT use this class. Use TFLNodeVisitor instead. + */ +template <typename T> struct TFLNodeVisitorBase +{ + virtual ~TFLNodeVisitorBase() = default; + +#define TFL_NODE(OPCODE, TFL_CLASS) virtual T visit(const TFL_CLASS *) = 0; + +#include "TFLNodes.lst" +#undef TFL_NODE +}; + +template <typename T> struct TFLNodeVisitor : public TFLNodeVisitorBase<T> +{ + virtual ~TFLNodeVisitor() = default; + +#define TFL_NODE(OPCODE, TFL_CLASS) \ + \ + virtual T visit(const TFL_CLASS *node) { return visit(static_cast<const TFLNode *>(node)); } + +#include "TFLNodes.lst" +#undef TFL_NODE + + /// @brief Default fallback + virtual T visit(const TFLNode *) { INTERNAL_EXN("TFLNodeVisitor: NYI node"); } +}; + +/** + * DO NOT use this class. Use TFLNodeMutableVisitor instead. + */ +template <typename T> struct TFLNodeMutableVisitorBase +{ + virtual ~TFLNodeMutableVisitorBase() = default; + +#define TFL_NODE(OPCODE, TFL_CLASS) virtual T visit(TFL_CLASS *) = 0; + +#include "TFLNodes.lst" +#undef TFL_NODE +}; + +template <typename T> struct TFLNodeMutableVisitor : public TFLNodeMutableVisitorBase<T> +{ + virtual ~TFLNodeMutableVisitor() = default; + +#define TFL_NODE(OPCODE, TFL_CLASS) \ + \ + virtual T visit(TFL_CLASS *node) { return visit(static_cast<TFLNode *>(node)); } + +#include "TFLNodes.lst" +#undef TFL_NODE + + /// @brief Default fallback + virtual T visit(TFLNode *) { INTERNAL_EXN("TFLNodeMutableVisitor: NYI node"); } +}; + +} // namespace locoex + +#endif // __LOCOEX_IR_TFLNODE_VISITOR_H__ diff --git a/compiler/exo/src/Dialect/IR/TFLNodes.cpp b/compiler/exo/src/Dialect/IR/TFLNodes.cpp new file mode 100644 index 000000000..f385ce0d9 --- /dev/null +++ b/compiler/exo/src/Dialect/IR/TFLNodes.cpp @@ -0,0 +1,91 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "TFLNodes.h" + +#include "Check.h" + +#include <loco.h> + +#include <cassert> + +namespace locoex +{ + +template <loco::DataType DT> uint32_t TFLConst::size(void) const +{ + assert(dtype() == DT); + assert(_data.size() % sizeof(typename loco::DataTypeImpl<DT>::Type) == 0); + return _data.size() / sizeof(typename loco::DataTypeImpl<DT>::Type); +} + +template <loco::DataType DT> void TFLConst::size(uint32_t l) +{ + assert(dtype() == DT); + _data.resize(l * sizeof(typename loco::DataTypeImpl<DT>::Type)); +} + +template <loco::DataType DT> +const typename loco::DataTypeImpl<DT>::Type &TFLConst::at(uint32_t n) const +{ + assert(dtype() == DT); + assert(n < size<DT>()); + return *(reinterpret_cast<const typename loco::DataTypeImpl<DT>::Type *>(_data.data()) + n); +} + +template <loco::DataType DT> typename loco::DataTypeImpl<DT>::Type &TFLConst::at(uint32_t n) +{ + assert(dtype() == DT); + assert(n < size<DT>()); + return *(reinterpret_cast<typename loco::DataTypeImpl<DT>::Type *>(_data.data()) + n); +} + +#define INSTANTIATE(DT) \ + template uint32_t TFLConst::size<DT>(void) const; \ + template void TFLConst::size<DT>(uint32_t); \ + template const typename loco::DataTypeImpl<DT>::Type &TFLConst::at<DT>(uint32_t) const; \ + template typename loco::DataTypeImpl<DT>::Type &TFLConst::at<DT>(uint32_t); + +INSTANTIATE(loco::DataType::S32); +INSTANTIATE(loco::DataType::FLOAT32); + +#undef INSTANTIATE + +void set_new_shape(locoex::TFLReshape *node, int32_t *base, uint32_t size) +{ + // Check node does not have both of new shape infos + EXO_ASSERT(node->shape() == nullptr, "node already has shape input"); + EXO_ASSERT(node->newShape()->rank() == 0, "node already has newShape attribute"); + + const loco::DataType S32 = loco::DataType::S32; + + // Set 2nd input as TFLConst + auto const_shape_node = node->graph()->nodes()->create<locoex::TFLConst>(); + const_shape_node->rank(1); + const_shape_node->dim(0) = size; + const_shape_node->dtype(S32); + const_shape_node->size<S32>(size); + for (uint32_t axis = 0; axis < size; ++axis) + const_shape_node->at<S32>(axis) = base[axis]; + node->shape(const_shape_node); + + // Set newShape attribute + node->newShape()->rank(size); + for (uint32_t axis = 0; axis < size; ++axis) + node->newShape()->dim(axis) = base[axis]; +} + +} // namespace locoex diff --git a/compiler/exo/src/Dialect/IR/TFLNodes.h b/compiler/exo/src/Dialect/IR/TFLNodes.h new file mode 100644 index 000000000..5f521a0a6 --- /dev/null +++ b/compiler/exo/src/Dialect/IR/TFLNodes.h @@ -0,0 +1,551 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef __LOCOEX_IR_TFLNODES_H__ +#define __LOCOEX_IR_TFLNODES_H__ + +#include "TFLNodeDecl.h" +#include "TFLOpcode.h" + +#include "FusedActFunc.h" +#include "NodeMixins.h" + +#include <loco/IR/Node.h> +#include <loco/IR/NodeMixins.h> +#include <loco/IR/DataTypeTraits.h> + +#include <locoex/VariadicArityNode.h> + +#include <array> + +namespace locoex +{ + +enum class Padding +{ + UNDEFINED, // This is not defined by TFLite. This was added to prevent programming error. + SAME, + VALID, +}; + +class Filter final +{ +public: + Filter() : _w(1), _h(1) {} + + int32_t w() const { return _w; } + void w(int32_t w) { _w = w; } + + int32_t h() const { return _h; } + void h(int32_t h) { _h = h; } + +private: + int32_t _w; + int32_t _h; +}; + +class Stride final +{ +public: + Stride() : _w(1), _h(1) {} + + int32_t w() const { return _w; } + void w(int32_t w) { _w = w; } + + int32_t h() const { return _h; } + void h(int32_t h) { _h = h; } + +private: + int32_t _w; + int32_t _h; +}; + +/// @brief enumeration of mixin class +enum class TFLNodeTrait +{ + FusedActFunc, + Bias +}; + +template <TFLNodeTrait T> class TFLNodeMixin; + +template <> class TFLNodeMixin<TFLNodeTrait::FusedActFunc> +{ +public: + TFLNodeMixin() = default; + +public: + FusedActFunc fusedActivationFunction() const { return _fused_act_fun; } + void fusedActivationFunction(FusedActFunc fused_act_fun) { _fused_act_fun = fused_act_fun; } + +private: + FusedActFunc _fused_act_fun = FusedActFunc::UNDEFINED; +}; + +/** + * @brief Mixin class for nodes that has a bias input + */ +template <> class TFLNodeMixin<TFLNodeTrait::Bias> +{ +public: + TFLNodeMixin() = default; + +public: + virtual loco::Node *bias(void) const = 0; /// @brief get the input for bias. + virtual void bias(loco::Node *node) = 0; /// @brief set the input for bias. 
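+
+  // Concrete nodes satisfy this mixin with one of their fixed-arity inputs;
+  // TFLConv2D below, for example, implements bias() as at(2)->node(). A pass
+  // can then handle any biased node uniformly, e.g. (sketch; 'node' and
+  // 'bias_const' are hypothetical):
+  //
+  //   if (auto biased = dynamic_cast<TFLNodeMixin<TFLNodeTrait::Bias> *>(node))
+  //     biased->bias(bias_const);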
+}; + +/** + * @brief ADD in TensorFlow Lite + */ +class TFLAdd final : public FixedArityNode<2, TFLNodeImpl<TFLOpcode::ADD>>, + public TFLNodeMixin<TFLNodeTrait::FusedActFunc> +{ +public: + loco::Node *x(void) const { return at(0)->node(); } + void x(loco::Node *node) { at(0)->node(node); } + + loco::Node *y(void) const { return at(1)->node(); } + void y(loco::Node *node) { at(1)->node(node); } +}; + +/** + * @brief AVERAGE_POOL_2D in TensorFlow Lite + */ +class TFLAveragePool2D final : public FixedArityNode<1, TFLNodeImpl<TFLOpcode::AVERAGE_POOL_2D>>, + public TFLNodeMixin<TFLNodeTrait::FusedActFunc> +{ +public: + TFLAveragePool2D() : _padding(Padding::UNDEFINED) { /* empty */} + +public: + loco::Node *value(void) const { return at(0)->node(); } + void value(loco::Node *node) { at(0)->node(node); } + + Padding padding() const { return _padding; } + void padding(Padding padding) { _padding = padding; } + + const Filter *filter(void) const { return &_filter; } + Filter *filter(void) { return &_filter; } + + const Stride *stride(void) const { return &_stride; } + Stride *stride(void) { return &_stride; } + +private: + Padding _padding; + Stride _stride; + Filter _filter; +}; + +/** + * @brief CONCATENATION in TensorFlow Lite + */ +class TFLConcatenation final : public VariadicArityNode<TFLNodeImpl<TFLOpcode::CONCATENATION>>, + public TFLNodeMixin<TFLNodeTrait::FusedActFunc> +{ +public: + TFLConcatenation(uint32_t arity) : VariadicArityNode<TFLNodeImpl<TFLOpcode::CONCATENATION>>(arity) + { + // TODO Support when arity is 0 + assert(arity >= 1); + } + +public: + uint32_t numValues(void) const { return arity(); } + +public: + Node *values(uint32_t index) const + { + assert(index < numValues()); + return at(index)->node(); + } + void values(uint32_t index, Node *node) + { + assert(index < numValues()); + at(index)->node(node); + } + +public: + uint32_t axis(void) const { return _axis; } + void axis(uint32_t axis) { _axis = axis; } + +private: + uint32_t _axis; +}; + +/** + * @brief Class to build tensor data + * @note This will not be exported as a specific op + */ +class TFLConst final : public FixedArityNode<0, TFLNodeImpl<TFLOpcode::CONST>>, + public loco::NodeMixin<loco::NodeTrait::DataType>, + public loco::NodeMixin<loco::NodeTrait::TensorShape> +{ +public: + TFLConst() = default; + +public: + template <loco::DataType DT> uint32_t size(void) const; + template <loco::DataType DT> void size(uint32_t size); + template <loco::DataType DT> const typename loco::DataTypeImpl<DT>::Type &at(uint32_t n) const; + template <loco::DataType DT> typename loco::DataTypeImpl<DT>::Type &at(uint32_t n); + +private: + std::vector<uint8_t> _data; +}; + +/** + * @brief CONV_2D in TensorFlow Lite + */ +class TFLConv2D final : public FixedArityNode<3, TFLNodeImpl<TFLOpcode::CONV_2D>>, + public TFLNodeMixin<TFLNodeTrait::FusedActFunc>, + public TFLNodeMixin<TFLNodeTrait::Bias> +{ +public: + loco::Node *input(void) const { return at(0)->node(); } + void input(loco::Node *node) { at(0)->node(node); } + + loco::Node *filter(void) const { return at(1)->node(); } + void filter(loco::Node *node) { at(1)->node(node); } + + loco::Node *bias(void) const override { return at(2)->node(); } + void bias(loco::Node *node) override { at(2)->node(node); } + +public: + Padding padding() const { return _padding; } + void padding(Padding padding) { _padding = padding; } + + const Stride *stride(void) const { return &_stride; } + Stride *stride(void) { return &_stride; } + +private: + Padding _padding = Padding::UNDEFINED; + Stride 
_stride; +}; + +/** + * @brief DEPTHWISE_CONV_2D in TensorFlow Lite + */ +class TFLDepthwiseConv2D final + : public FixedArityNode<3, TFLNodeImpl<TFLOpcode::DEPTHWISE_CONV_2D>>, + public TFLNodeMixin<TFLNodeTrait::FusedActFunc>, + public TFLNodeMixin<TFLNodeTrait::Bias> +{ +public: + loco::Node *input(void) const { return at(0)->node(); } + void input(loco::Node *node) { at(0)->node(node); } + + loco::Node *filter(void) const { return at(1)->node(); } + void filter(loco::Node *node) { at(1)->node(node); } + + loco::Node *bias(void) const override { return at(2)->node(); } + void bias(loco::Node *node) override { at(2)->node(node); } + +public: + Padding padding() const { return _padding; } + void padding(Padding padding) { _padding = padding; } + + const Stride *stride(void) const { return &_stride; } + Stride *stride(void) { return &_stride; } + + int32_t depthMultiplier(void) const { return _depth_multiplier; } + void depthMultiplier(int32_t arg) { _depth_multiplier = arg; } + +private: + Padding _padding = Padding::UNDEFINED; + Stride _stride; + int32_t _depth_multiplier = 0; +}; + +/** + * @brief DIV in TensorFlow Lite + */ +class TFLDiv final : public FixedArityNode<2, TFLNodeImpl<TFLOpcode::DIV>>, + public TFLNodeMixin<TFLNodeTrait::FusedActFunc> +{ +public: + TFLDiv() = default; + +public: + loco::Node *x(void) const { return at(0)->node(); } + void x(loco::Node *node) { at(0)->node(node); } + + loco::Node *y(void) const { return at(1)->node(); } + void y(loco::Node *node) { at(1)->node(node); } +}; + +/** + * @brief FULLY_CONNECTED in TensorFlow Lite + */ +class TFLFullyConnected final : public FixedArityNode<3, TFLNodeImpl<TFLOpcode::FULLY_CONNECTED>>, + public TFLNodeMixin<TFLNodeTrait::FusedActFunc>, + public TFLNodeMixin<TFLNodeTrait::Bias> +{ +public: + loco::Node *input(void) const { return at(0)->node(); } + void input(loco::Node *node) { at(0)->node(node); } + + loco::Node *weights(void) const { return at(1)->node(); } + void weights(loco::Node *node) { at(1)->node(node); } + + loco::Node *bias(void) const override { return at(2)->node(); } + void bias(loco::Node *node) override { at(2)->node(node); } +}; + +/** + * @brief MAXIMUM in TensorFlow Lite + */ +class TFLMaximum final : public FixedArityNode<2, TFLNodeImpl<TFLOpcode::MAXIMUM>> +{ +public: + loco::Node *x(void) const { return at(0)->node(); } + void x(loco::Node *node) { at(0)->node(node); } + + loco::Node *y(void) const { return at(1)->node(); } + void y(loco::Node *node) { at(1)->node(node); } +}; + +/** + * @brief MAX_POOL_2D in TensorFlow Lite + */ +class TFLMaxPool2D final : public FixedArityNode<1, TFLNodeImpl<TFLOpcode::MAX_POOL_2D>>, + public TFLNodeMixin<TFLNodeTrait::FusedActFunc> +{ +public: + TFLMaxPool2D() : _padding(Padding::UNDEFINED) { /* empty */} + +public: + loco::Node *value(void) const { return at(0)->node(); } + void value(loco::Node *node) { at(0)->node(node); } + + Padding padding() const { return _padding; } + void padding(Padding padding) { _padding = padding; } + + const Filter *filter(void) const { return &_filter; } + Filter *filter(void) { return &_filter; } + + const Stride *stride(void) const { return &_stride; } + Stride *stride(void) { return &_stride; } + +private: + Padding _padding; + Stride _stride; + Filter _filter; +}; + +class TFLMean final : public FixedArityNode<2, TFLNodeImpl<TFLOpcode::MEAN>> +{ +public: + loco::Node *input(void) const { return at(0)->node(); } + void input(loco::Node *node) { at(0)->node(node); } + + loco::Node *reduction_indices(void) const { return 
at(1)->node(); } + void reduction_indices(loco::Node *node) { at(1)->node(node); } + +public: + bool keep_dims(void) const { return _keep_dims; } + void keep_dims(bool keep_dims) { _keep_dims = keep_dims; } + +private: + bool _keep_dims = false; +}; + +/** + * @brief MUL in TensorFlow Lite + */ +class TFLMul final : public FixedArityNode<2, TFLNodeImpl<TFLOpcode::MUL>>, + public TFLNodeMixin<TFLNodeTrait::FusedActFunc> +{ +public: + loco::Node *x(void) const { return at(0)->node(); } + void x(loco::Node *node) { at(0)->node(node); } + + loco::Node *y(void) const { return at(1)->node(); } + void y(loco::Node *node) { at(1)->node(node); } +}; + +class TFLRelu final : public FixedArityNode<1, TFLNodeImpl<TFLOpcode::RELU>> +{ +public: + TFLRelu() = default; + +public: + loco::Node *features(void) const { return at(0)->node(); } + void features(loco::Node *node) { at(0)->node(node); } +}; + +class TFLRelu6 final : public FixedArityNode<1, TFLNodeImpl<TFLOpcode::RELU6>> +{ +public: + TFLRelu6() = default; + +public: + loco::Node *features(void) const { return at(0)->node(); } + void features(loco::Node *node) { at(0)->node(node); } +}; + +class TFLReshape final : public FixedArityNode<2, TFLNodeImpl<TFLOpcode::RESHAPE>> +{ +public: + TFLReshape() = default; + +public: + loco::Node *tensor(void) const { return at(0)->node(); } + void tensor(loco::Node *node) { at(0)->node(node); } + + // TODO Make this input optional. That is, loco system does not emit error + // with this input being null + loco::Node *shape(void) const { return at(1)->node(); } + void shape(loco::Node *node) { at(1)->node(node); } + +public: + class Shape + { + public: + uint32_t rank(void) const { return _shape.size(); } + void rank(uint32_t rank) { _shape.resize(rank); } + + int32_t dim(uint32_t n) const { return _shape.at(n); } + int32_t &dim(uint32_t n) { return _shape.at(n); } + + private: + std::vector<int32_t> _shape; + }; + + const Shape *newShape(void) const { return &_new_shape; } + Shape *newShape(void) { return &_new_shape; } + +private: + Shape _new_shape; +}; + +/** + * @brief Set both TFLReshape's 2nd input as TFLConst, and newShape attribute + * with same value + * @note Shape inference for TFLReshape forces them to be same + * TODO find better place for this helper + */ +void set_new_shape(locoex::TFLReshape *node, int32_t *base, uint32_t size); + +class TFLRsqrt final : public FixedArityNode<1, TFLNodeImpl<TFLOpcode::RSQRT>> +{ +public: + TFLRsqrt() = default; + +public: + loco::Node *x(void) const { return at(0)->node(); } + void x(loco::Node *node) { at(0)->node(node); } +}; + +// TODO TFLSoftmax + +class TFLSqrt final : public FixedArityNode<1, TFLNodeImpl<TFLOpcode::SQRT>> +{ +public: + TFLSqrt() = default; + +public: + loco::Node *x(void) const { return at(0)->node(); } + void x(loco::Node *node) { at(0)->node(node); } +}; + +class TFLSquaredDifference final + : public FixedArityNode<2, TFLNodeImpl<TFLOpcode::SQUARED_DIFFERENCE>> +{ +public: + TFLSquaredDifference() = default; + +public: + loco::Node *x(void) const { return at(0)->node(); } + void x(loco::Node *node) { at(0)->node(node); } + + loco::Node *y(void) const { return at(1)->node(); } + void y(loco::Node *node) { at(1)->node(node); } +}; + +/** + * @brief SUB in TensorFlow Lite + */ +class TFLSub final : public FixedArityNode<2, TFLNodeImpl<TFLOpcode::SUB>>, + public TFLNodeMixin<TFLNodeTrait::FusedActFunc> +{ +public: + TFLSub() = default; + +public: + loco::Node *x(void) const { return at(0)->node(); } + void x(loco::Node *node) { 
at(0)->node(node); } + + loco::Node *y(void) const { return at(1)->node(); } + void y(loco::Node *node) { at(1)->node(node); } +}; + +// TODO TFLTanh + +/** + * @brief TRANSPOSE in TensorFlow Lite + */ +class TFLTranspose final : public FixedArityNode<2, TFLNodeImpl<TFLOpcode::TRANSPOSE>> +{ +public: + TFLTranspose() = default; + +public: + /// @brief Get the input node to transpose + loco::Node *a(void) const { return at(0)->node(); } + + /// @brief Set the input node to transpose + void a(loco::Node *node) { at(0)->node(node); } + + loco::Node *perm(void) const { return at(1)->node(); } + void perm(loco::Node *node) { at(1)->node(node); } +}; + +/** + * @brief TRANSPOSE_CONV in TensorFlow Lite + * + * @note Argument node function names are from TensorFlow. So refering 'in' and + * 'out' acutally means 'out' and 'in' of the this node. + */ +class TFLTransposeConv final : public FixedArityNode<3, TFLNodeImpl<TFLOpcode::TRANSPOSE_CONV>> +{ +public: + loco::Node *inputSizes(void) const { return at(0)->node(); } + void inputSizes(Node *node) { at(0)->node(node); } + + loco::Node *filter(void) const { return at(1)->node(); } + void filter(Node *node) { at(1)->node(node); } + + loco::Node *outBackprop(void) const { return at(2)->node(); } + void outBackprop(Node *node) { at(2)->node(node); } + +public: + const Padding &padding(void) const { return _padding; } + void padding(const Padding &padding) { _padding = padding; } + + const Stride *stride(void) const { return &_stride; } + Stride *stride(void) { return &_stride; } + +private: + Padding _padding; + Stride _stride; +}; + +// TODO define more children of TFLNode + +} // namespace locoex + +#endif // __LOCOEX_IR_TFLNODES_H__ diff --git a/compiler/exo/src/Dialect/IR/TFLNodes.lst b/compiler/exo/src/Dialect/IR/TFLNodes.lst new file mode 100644 index 000000000..225e2be3b --- /dev/null +++ b/compiler/exo/src/Dialect/IR/TFLNodes.lst @@ -0,0 +1,30 @@ +#ifndef TFL_NODE +#error "Define TFL_NODE" +#endif // TFL_NODE + +// +// PLEASE SORT NODE DECLS IN ALPHABETICAL ORDER +// +TFL_NODE(ADD, locoex::TFLAdd) +TFL_NODE(AVERAGE_POOL_2D, locoex::TFLAveragePool2D) +TFL_NODE(CONCATENATION, locoex::TFLConcatenation) +TFL_NODE(CONST, locoex::TFLConst) +TFL_NODE(CONV_2D, locoex::TFLConv2D) +TFL_NODE(DEPTHWISE_CONV_2D, locoex::TFLDepthwiseConv2D) +TFL_NODE(DIV, locoex::TFLDiv) +TFL_NODE(FULLY_CONNECTED, locoex::TFLFullyConnected) +TFL_NODE(MAXIMUM, locoex::TFLMaximum) +TFL_NODE(MAX_POOL_2D, locoex::TFLMaxPool2D) +TFL_NODE(MEAN, locoex::TFLMean) +TFL_NODE(MUL, locoex::TFLMul) +TFL_NODE(RELU, locoex::TFLRelu) +TFL_NODE(RELU6, locoex::TFLRelu6) +TFL_NODE(RESHAPE, locoex::TFLReshape) +TFL_NODE(RSQRT, locoex::TFLRsqrt) +// TODO TFLSoftmax +TFL_NODE(SQRT, locoex::TFLSqrt) +TFL_NODE(SQUARED_DIFFERENCE, locoex::TFLSquaredDifference) +TFL_NODE(SUB, locoex::TFLSub) +// TODO TFLTanh +TFL_NODE(TRANSPOSE, locoex::TFLTranspose) +TFL_NODE(TRANSPOSE_CONV, locoex::TFLTransposeConv) diff --git a/compiler/exo/src/Dialect/IR/TFLNodes.test.cpp b/compiler/exo/src/Dialect/IR/TFLNodes.test.cpp new file mode 100644 index 000000000..09c5c83a0 --- /dev/null +++ b/compiler/exo/src/Dialect/IR/TFLNodes.test.cpp @@ -0,0 +1,159 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "TFLNodes.h" + +#include "TFLDialect.h" +#include "TFLOpcode.h" + +#include <gtest/gtest.h> + +TEST(TFLAddTest, constructor) +{ + locoex::TFLAdd add_node; + + ASSERT_EQ(add_node.dialect(), locoex::TFLDialect::get()); + ASSERT_EQ(add_node.opcode(), locoex::TFLOpcode::ADD); + + ASSERT_EQ(add_node.x(), nullptr); + ASSERT_EQ(add_node.y(), nullptr); +} + +// TODO TFLAveragePool2D + +TEST(TFLConcatTest, constructor) +{ + locoex::TFLConcatenation concat_node(3); + + ASSERT_EQ(concat_node.dialect(), locoex::TFLDialect::get()); + ASSERT_EQ(concat_node.opcode(), locoex::TFLOpcode::CONCATENATION); + + ASSERT_EQ(concat_node.numValues(), 3); + ASSERT_EQ(concat_node.values(0), nullptr); + ASSERT_EQ(concat_node.values(1), nullptr); + ASSERT_EQ(concat_node.values(2), nullptr); + ASSERT_EQ(concat_node.fusedActivationFunction(), locoex::FusedActFunc::UNDEFINED); +} + +// TODO TFLConv2D + +TEST(TFLDepthwiseConv2DTest, constructor) +{ + locoex::TFLDepthwiseConv2D dw_conv2d_node; + + ASSERT_EQ(dw_conv2d_node.dialect(), locoex::TFLDialect::get()); + ASSERT_EQ(dw_conv2d_node.opcode(), locoex::TFLOpcode::DEPTHWISE_CONV_2D); + + ASSERT_EQ(dw_conv2d_node.input(), nullptr); + ASSERT_EQ(dw_conv2d_node.filter(), nullptr); + ASSERT_EQ(dw_conv2d_node.bias(), nullptr); + ASSERT_EQ(dw_conv2d_node.padding(), locoex::Padding::UNDEFINED); + ASSERT_EQ(dw_conv2d_node.stride()->h(), 1); + ASSERT_EQ(dw_conv2d_node.stride()->w(), 1); + ASSERT_EQ(dw_conv2d_node.depthMultiplier(), 0); + ASSERT_EQ(dw_conv2d_node.fusedActivationFunction(), locoex::FusedActFunc::UNDEFINED); +} + +TEST(TFLDivTest, constructor) +{ + locoex::TFLDiv div_node; + + ASSERT_EQ(div_node.dialect(), locoex::TFLDialect::get()); + ASSERT_EQ(div_node.opcode(), locoex::TFLOpcode::DIV); + + ASSERT_EQ(div_node.x(), nullptr); + ASSERT_EQ(div_node.y(), nullptr); +} + +// TODO TFLMaxPool2D + +TEST(TFLMulTest, constructor) +{ + locoex::TFLMul mul_node; + + ASSERT_EQ(mul_node.dialect(), locoex::TFLDialect::get()); + ASSERT_EQ(mul_node.opcode(), locoex::TFLOpcode::MUL); + + ASSERT_EQ(mul_node.x(), nullptr); + ASSERT_EQ(mul_node.y(), nullptr); +} + +TEST(TFLReluTest, constructor) +{ + locoex::TFLRelu relu_node; + + ASSERT_EQ(relu_node.dialect(), locoex::TFLDialect::get()); + ASSERT_EQ(relu_node.opcode(), locoex::TFLOpcode::RELU); + + ASSERT_EQ(relu_node.features(), nullptr); +} + +// TODO TFLRelu6 + +TEST(TFLReshapeTest, constructor) +{ + locoex::TFLReshape reshape; + + ASSERT_EQ(reshape.dialect(), locoex::TFLDialect::get()); + ASSERT_EQ(reshape.opcode(), locoex::TFLOpcode::RESHAPE); + + ASSERT_EQ(reshape.tensor(), nullptr); + ASSERT_EQ(reshape.shape(), nullptr); + ASSERT_EQ(reshape.newShape()->rank(), 0); +} + +TEST(TFLReshapeTest, alloc_new_shape) +{ + locoex::TFLReshape reshape; + + reshape.newShape()->rank(2); + ASSERT_EQ(reshape.newShape()->rank(), 2); + + reshape.newShape()->dim(0) = 0; + reshape.newShape()->dim(1) = 1; + + auto &const_reshape = const_cast<const locoex::TFLReshape &>(reshape); + ASSERT_EQ(const_reshape.newShape()->dim(0), 0); + ASSERT_EQ(const_reshape.newShape()->dim(1), 1); +} + +// TODO TFLSoftmax + +// TODO TFLSqrt 
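// One way the pending TFLSqrt case could look, following the constructor-test
// pattern used above (dialect/opcode checks plus default-null inputs):
//
//   TEST(TFLSqrtTest, constructor)
//   {
//     locoex::TFLSqrt sqrt_node;
//
//     ASSERT_EQ(sqrt_node.dialect(), locoex::TFLDialect::get());
//     ASSERT_EQ(sqrt_node.opcode(), locoex::TFLOpcode::SQRT);
//     ASSERT_EQ(sqrt_node.x(), nullptr);
//   }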
+ +TEST(TFLSubTest, constructor) +{ + locoex::TFLSub sub_node; + + ASSERT_EQ(sub_node.dialect(), locoex::TFLDialect::get()); + ASSERT_EQ(sub_node.opcode(), locoex::TFLOpcode::SUB); + + ASSERT_EQ(sub_node.x(), nullptr); + ASSERT_EQ(sub_node.y(), nullptr); +} + +// TODO TFLTanh + +TEST(TFLTransposeTest, constructor) +{ + locoex::TFLTranspose tr_node; + + ASSERT_EQ(tr_node.dialect(), locoex::TFLDialect::get()); + ASSERT_EQ(tr_node.opcode(), locoex::TFLOpcode::TRANSPOSE); + + ASSERT_EQ(tr_node.a(), nullptr); + ASSERT_EQ(tr_node.perm(), nullptr); +} diff --git a/compiler/exo/src/Dialect/IR/TFLOpcode.h b/compiler/exo/src/Dialect/IR/TFLOpcode.h new file mode 100644 index 000000000..0c0ab64bd --- /dev/null +++ b/compiler/exo/src/Dialect/IR/TFLOpcode.h @@ -0,0 +1,32 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LOCOEX_IR_TFLOPCODE_H__ +#define __LOCOEX_IR_TFLOPCODE_H__ + +namespace locoex +{ + +enum class TFLOpcode +{ +#define TFL_NODE(OPCODE, CLASS) OPCODE, +#include "TFLNodes.lst" +#undef TFL_NODE +}; + +} // namespace locoex + +#endif // __LOCOEX_IR_TFLOPCODE_H__ diff --git a/compiler/exo/src/Dialect/Service/CircleShapeInferenceRule.cpp b/compiler/exo/src/Dialect/Service/CircleShapeInferenceRule.cpp new file mode 100644 index 000000000..2e71aa000 --- /dev/null +++ b/compiler/exo/src/Dialect/Service/CircleShapeInferenceRule.cpp @@ -0,0 +1,67 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "CircleShapeInferenceRule.h" + +#include "Dialect/IR/CircleNodes.h" +#include "Dialect/IR/CircleDialect.h" +#include "Dialect/IR/CircleNodeVisitor.h" + +#include "Check.h" + +#include <cassert> + +namespace +{ + +/** + * @brief Class to infer the shape of CircleNode + * + * @note All CircleNode's inputs and outputs are always loco::Domain::Tensor + */ +class ShapeInferenceAlgorithm final : public locoex::CircleNodeVisitor<loco::NodeShape> +{ +public: + loco::NodeShape visit(const locoex::CircleInstanceNorm *node) final + { + auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>(); + + return loco::NodeShape{input_shape}; + } +}; + +} // namespace + +namespace locoex +{ + +bool CircleShapeInferenceRule::recognize(const loco::Dialect *d) const +{ + return CircleDialect::get() == d; +} + +bool CircleShapeInferenceRule::infer(const loco::Node *node, loco::NodeShape &shape) const +{ + assert(node->dialect() == CircleDialect::get()); + assert(dynamic_cast<const CircleNode *>(node) != nullptr); + + ShapeInferenceAlgorithm alg; + shape = dynamic_cast<const CircleNode *>(node)->accept(&alg); + + return true; +} + +} // namespace locoex diff --git a/compiler/exo/src/Dialect/Service/CircleShapeInferenceRule.h b/compiler/exo/src/Dialect/Service/CircleShapeInferenceRule.h new file mode 100644 index 000000000..92f23c9dd --- /dev/null +++ b/compiler/exo/src/Dialect/Service/CircleShapeInferenceRule.h @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LOCOEX_SERVICE_CIRCLESHAPE_INFERENCE_RULE_H__ +#define __LOCOEX_SERVICE_CIRCLESHAPE_INFERENCE_RULE_H__ + +#include <loco/Service/ShapeInference.h> + +namespace locoex +{ + +struct CircleShapeInferenceRule final : public loco::ShapeInferenceRule +{ + bool recognize(const loco::Dialect *) const final; + bool infer(const loco::Node *, loco::NodeShape &) const final; +}; + +} // namespace locoex + +#endif // __LOCOEX_SERVICE_CIRCLESHAPE_INFERENCE_RULE_H__ diff --git a/compiler/exo/src/Dialect/Service/CircleTypeInferenceRule.cpp b/compiler/exo/src/Dialect/Service/CircleTypeInferenceRule.cpp new file mode 100644 index 000000000..6bc95a1b5 --- /dev/null +++ b/compiler/exo/src/Dialect/Service/CircleTypeInferenceRule.cpp @@ -0,0 +1,58 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "CircleTypeInferenceRule.h" + +#include "Dialect/IR/CircleDialect.h" +#include "Dialect/IR/CircleNodeVisitor.h" +#include "Dialect/IR/CircleNodes.h" + +#include <cassert> + +namespace +{ + +struct TypeInferenceAlgorithm final : public locoex::CircleNodeVisitor<loco::DataType> +{ + loco::DataType visit(const locoex::CircleInstanceNorm *node) final + { + return loco::dtype_get(node->input()); + } +}; + +} // namespace + +namespace locoex +{ + +bool CircleTypeInferenceRule::recognize(const loco::Dialect *d) const +{ + return CircleDialect::get() == d; +} + +bool CircleTypeInferenceRule::infer(const loco::Node *node, loco::DataType &dtype) const +{ + assert(node->dialect() == CircleDialect::get()); + + TypeInferenceAlgorithm alg; + + dtype = dynamic_cast<const CircleNode *>(node)->accept(&alg); + assert(dtype != loco::DataType::Unknown); + + return true; +} + +} // namespace locoex diff --git a/compiler/exo/src/Dialect/Service/CircleTypeInferenceRule.h b/compiler/exo/src/Dialect/Service/CircleTypeInferenceRule.h new file mode 100644 index 000000000..c073dfc54 --- /dev/null +++ b/compiler/exo/src/Dialect/Service/CircleTypeInferenceRule.h @@ -0,0 +1,36 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LOCOEX_SERVICE_CIRCLETYPE_INFERENCE_RULE_H__ +#define __LOCOEX_SERVICE_CIRCLETYPE_INFERENCE_RULE_H__ + +#include <loco/Service/TypeInference.h> + +namespace locoex +{ + +/** + * @brief Type Inference Rule for CircleDialect + */ +struct CircleTypeInferenceRule final : public loco::TypeInferenceRule +{ + bool recognize(const loco::Dialect *) const final; + bool infer(const loco::Node *, loco::DataType &) const final; +}; + +} // namespace locoex + +#endif // __LOCOEX_SERVICE_CIRCLETYPE_INFERENCE_RULE_H__ diff --git a/compiler/exo/src/Dialect/Service/TFLShapeInferenceRule.cpp b/compiler/exo/src/Dialect/Service/TFLShapeInferenceRule.cpp new file mode 100644 index 000000000..f4bb10364 --- /dev/null +++ b/compiler/exo/src/Dialect/Service/TFLShapeInferenceRule.cpp @@ -0,0 +1,627 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "TFLShapeInferenceRule.h" + +#include "Dialect/IR/TFLNodes.h" +#include "Dialect/IR/TFLDialect.h" +#include "Dialect/IR/TFLNodeVisitor.h" + +#include "Check.h" + +#include <oops/InternalExn.h> + +#include <algorithm> +#include <cassert> +#include <stdexcept> + +namespace +{ + +// Call this for TFLAvgPool2D and TFLMaxPool2D only +template <class Pool2DType> loco::NodeShape infer_pool_2d_shape(const Pool2DType *node) +{ + EXO_ASSERT(loco::shape_known(node->value()), "Shape must be known"); + + auto ifm_shape = loco::shape_get(node->value()).template as<loco::TensorShape>(); + assert(ifm_shape.rank() == 4); + + uint32_t input_height = ifm_shape.dim(1).value(); + uint32_t input_width = ifm_shape.dim(2).value(); + uint32_t stride_height = node->stride()->h(); + uint32_t stride_width = node->stride()->w(); + uint32_t window_height = node->filter()->h(); + uint32_t window_width = node->filter()->w(); + uint32_t dilation_height = 1; // dilation for TFLAvgPool2D and TFLMaxPool2D is 1 + uint32_t dilation_width = 1; + uint32_t effective_window_height = dilation_height * (window_height - 1) + 1; + uint32_t effective_window_width = dilation_width * (window_width - 1) + 1; + + uint32_t output_height = 0; + uint32_t output_width = 0; + + if (node->padding() == locoex::Padding::VALID) + { + output_height = (input_height + stride_height - effective_window_height) / stride_height; + output_width = (input_width + stride_width - effective_window_width) / stride_width; + } + else if (node->padding() == locoex::Padding::SAME) + { + output_height = (input_height + stride_height - 1) / stride_height; + output_width = (input_width + stride_width - 1) / stride_width; + } + else + EXO_ASSERT(false, "Wrong padding type"); + + loco::TensorShape ofm_shape; + ofm_shape.rank(4); + ofm_shape.dim(0) = ifm_shape.dim(0); + ofm_shape.dim(1) = output_height; + ofm_shape.dim(2) = output_width; + ofm_shape.dim(3) = ifm_shape.dim(3); + + return loco::NodeShape{ofm_shape}; +} + +/** + * @brief Create a higher-rank TensorShape following NumPy broadcasting semantics + * + * HOW TO USE: + * + * auto expanded_tensor_shape = expand(tensor_shape).to(N); + */ +class TensorShapeExpander +{ +public: + TensorShapeExpander(const loco::TensorShape &shape) : _shape{shape} + { + // DO NOTHING + } + +public: + loco::TensorShape to(uint32_t output_rank) + { + auto const &input_shape = _shape; + uint32_t const input_rank = input_shape.rank(); + + assert(input_rank <= output_rank && "Cannot shrink rank"); + uint32_t const axis_shift = output_rank - input_rank; + + loco::TensorShape output_shape; + + output_shape.rank(output_rank); + for (uint32_t axis = 0; axis < output_rank; ++axis) + { + output_shape.dim(axis) = (axis < axis_shift) ? 1 : input_shape.dim(axis - axis_shift); + } + + return output_shape; + } + +private: + const loco::TensorShape _shape; +}; + +/** + * @breif Expand shape x and y to same rank by align right and filling with 1 + */ +void expand_rank(loco::TensorShape &x, loco::TensorShape &y) +{ + auto x_rank = x.rank(); + auto y_rank = y.rank(); + + if (x_rank == y_rank) + return; + + TensorShapeExpander x_exp(x); + TensorShapeExpander y_exp(y); + + auto xy_rank = std::max(x_rank, y_rank); + + x = x_rank > y_rank ? x : x_exp.to(xy_rank); + y = y_rank > x_rank ? 
y : y_exp.to(xy_rank); +} + +/** + * @breif Returns shape of expanded dimension of input x and y having same rank + */ +loco::TensorShape expand_dimension(const loco::TensorShape &x, const loco::TensorShape &y) +{ + assert(x.rank() == y.rank()); + + auto rank = x.rank(); + + loco::TensorShape output_shape; + + output_shape.rank(rank); + for (uint32_t axis = 0; axis < rank; ++axis) + { + assert(x.dim(axis).known() && y.dim(axis).known()); + + auto x_dim = x.dim(axis).value(); + auto y_dim = y.dim(axis).value(); + + // each dimension of x and y should be same or one must be 1 if different + if (!((x_dim == y_dim) || (x_dim == 1 || y_dim == 1))) + INTERNAL_EXN("Cannot produce expand_dimension of two shapes"); + + output_shape.dim(axis) = std::max(x_dim, y_dim); + } + + return output_shape; +} + +loco::TensorShape broadcast_shape(const loco::TensorShape &x, const loco::TensorShape &y) +{ + auto x_match = x; + auto y_match = y; + + expand_rank(x_match, y_match); + + auto output_shape = expand_dimension(x_match, y_match); + + return output_shape; +} + +/** + * @brief Class to infer the shape of TFLNode + * + * @note All TFLNode's inputs and outputs are always loco::Domain::Tensor + */ +class ShapeInferenceAlgorithm final : public locoex::TFLNodeVisitor<loco::NodeShape> +{ +public: + loco::NodeShape visit(const locoex::TFLAdd *node) final + { + auto x_shape = loco::shape_get(node->x()).as<loco::TensorShape>(); + auto y_shape = loco::shape_get(node->y()).as<loco::TensorShape>(); + + auto output_shape = broadcast_shape(x_shape, y_shape); + + return loco::NodeShape{output_shape}; + } + + loco::NodeShape visit(const locoex::TFLAveragePool2D *node) final + { + return infer_pool_2d_shape(node); + } + + loco::NodeShape visit(const locoex::TFLConcatenation *node) final + { + // TODO Support when TFLConcatenation has 0 input + assert(node->numValues() > 0); + + auto axis = node->axis(); + auto first_shape = loco::shape_get(node->values(0)).as<loco::TensorShape>(); + + loco::TensorShape output_shape; + + output_shape.rank(first_shape.rank()); + for (uint32_t i = 0; i < output_shape.rank(); ++i) + output_shape.dim(i) = first_shape.dim(i); + + for (uint32_t i = 1; i < node->numValues(); ++i) + { + auto input_shape = loco::shape_get(node->values(i)).as<loco::TensorShape>(); + + for (uint32_t j = 0; j < output_shape.rank(); ++j) + { + if (j == axis) + output_shape.dim(j) = output_shape.dim(j).value() + input_shape.dim(j).value(); + else + assert(output_shape.dim(j) == input_shape.dim(j)); + } + } + + return loco::NodeShape{output_shape}; + } + + loco::NodeShape visit(const locoex::TFLConst *node) final + { + loco::TensorShape shape; + + shape.rank(node->rank()); + for (uint32_t axis = 0; axis < node->rank(); axis++) + shape.dim(axis) = node->dim(axis); + + return loco::NodeShape{shape}; + } + + loco::NodeShape visit(const locoex::TFLConv2D *node) final + { + auto ifm_shape = loco::shape_get(node->input()).as<loco::TensorShape>(); // in NHWC + auto ker_shape = loco::shape_get(node->filter()).as<loco::TensorShape>(); // in OHWI + + assert(ifm_shape.rank() == 4); + assert(ker_shape.rank() == 4); + assert(ifm_shape.dim(3) == ker_shape.dim(3)); + + uint32_t input_height = ifm_shape.dim(1).value(); + uint32_t input_width = ifm_shape.dim(2).value(); + uint32_t stride_height = node->stride()->h(); + uint32_t stride_width = node->stride()->w(); + uint32_t ker_height = ker_shape.dim(1).value(); + uint32_t ker_width = ker_shape.dim(2).value(); + uint32_t dilation_height = 1; + uint32_t dilation_width = 1; + uint32_t 
effective_ker_height = dilation_height * (ker_height - 1) + 1; + uint32_t effective_ker_width = dilation_width * (ker_width - 1) + 1; + + uint32_t output_height = 0; + uint32_t output_width = 0; + + if (node->padding() == locoex::Padding::VALID) + { + output_height = (input_height + stride_height - effective_ker_height) / stride_height; + output_width = (input_width + stride_width - effective_ker_width) / stride_width; + } + else if (node->padding() == locoex::Padding::SAME) + { + output_height = (input_height + stride_height - 1) / stride_height; + output_width = (input_width + stride_width - 1) / stride_width; + } + else + EXO_ASSERT(false, "Wrong padding type"); + + loco::TensorShape ofm_shape; + ofm_shape.rank(4); + ofm_shape.dim(0) = ifm_shape.dim(0); + ofm_shape.dim(1) = output_height; + ofm_shape.dim(2) = output_width; + ofm_shape.dim(3) = ker_shape.dim(0); + + return loco::NodeShape{ofm_shape}; + } + + loco::NodeShape visit(const locoex::TFLDepthwiseConv2D *node) final + { + auto ifm_shape = loco::shape_get(node->input()).as<loco::TensorShape>(); // in NHWC + auto ker_shape = loco::shape_get(node->filter()).as<loco::TensorShape>(); // in 1 H W CM + + assert(ifm_shape.rank() == 4); + assert(ker_shape.rank() == 4); + assert(ker_shape.dim(0).value() == 1); + + uint32_t input_height = ifm_shape.dim(1).value(); + uint32_t input_width = ifm_shape.dim(2).value(); + uint32_t stride_height = node->stride()->h(); + uint32_t stride_width = node->stride()->w(); + uint32_t ker_height = ker_shape.dim(1).value(); + uint32_t ker_width = ker_shape.dim(2).value(); + uint32_t dilation_height = 1; + uint32_t dilation_width = 1; + uint32_t effective_ker_height = dilation_height * (ker_height - 1) + 1; + uint32_t effective_ker_width = dilation_width * (ker_width - 1) + 1; + + uint32_t output_height = 0; + uint32_t output_width = 0; + + if (node->padding() == locoex::Padding::VALID) + { + output_height = (input_height + stride_height - effective_ker_height) / stride_height; + output_width = (input_width + stride_width - effective_ker_width) / stride_width; + } + else if (node->padding() == locoex::Padding::SAME) + { + output_height = (input_height + stride_height - 1) / stride_height; + output_width = (input_width + stride_width - 1) / stride_width; + } + else + EXO_ASSERT(false, "Wrong padding type"); + + loco::TensorShape ofm_shape; + ofm_shape.rank(4); + ofm_shape.dim(0) = ifm_shape.dim(0); + ofm_shape.dim(1) = output_height; + ofm_shape.dim(2) = output_width; + ofm_shape.dim(3) = ker_shape.dim(3); + + return loco::NodeShape{ofm_shape}; + } + + loco::NodeShape visit(const locoex::TFLDiv *node) final + { + auto x_shape = loco::shape_get(node->x()).as<loco::TensorShape>(); + auto y_shape = loco::shape_get(node->y()).as<loco::TensorShape>(); + + auto output_shape = broadcast_shape(x_shape, y_shape); + + return loco::NodeShape{output_shape}; + } + + loco::NodeShape visit(const locoex::TFLFullyConnected *node) final + { + auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>(); + auto weights_shape = loco::shape_get(node->weights()).as<loco::TensorShape>(); + + // Checking shape capability for multiplication + EXO_ASSERT(input_shape.rank() == 2, "NYI for input shape rank > 2"); + EXO_ASSERT(weights_shape.rank() == 2, "Incompatible weights rank for fully connected"); + EXO_ASSERT(input_shape.dim(1) == weights_shape.dim(1), + "Incompatible shapes for fully connected"); + + loco::TensorShape out_shape; + out_shape.rank(2); + + out_shape.dim(0) = input_shape.dim(0); + out_shape.dim(1) = 
weights_shape.dim(0); + + return loco::NodeShape{out_shape}; + } + + loco::NodeShape visit(const locoex::TFLMaximum *node) final + { + auto x_shape = loco::shape_get(node->x()).as<loco::TensorShape>(); + auto y_shape = loco::shape_get(node->y()).as<loco::TensorShape>(); + + auto output_shape = broadcast_shape(x_shape, y_shape); + + return loco::NodeShape{output_shape}; + } + + loco::NodeShape visit(const locoex::TFLMaxPool2D *node) final + { + return infer_pool_2d_shape(node); + } + + loco::NodeShape visit(const locoex::TFLMean *node) final + { + const loco::DataType S32 = loco::DataType::S32; + + auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>(); + auto reduction_indices = dynamic_cast<locoex::TFLConst *>(node->reduction_indices()); + + { // Exceptions + // TODO support non-const case + EXO_ASSERT(reduction_indices, "Only support constant reduction_indices"); + // TODO support other data type + EXO_ASSERT(reduction_indices->dtype() == S32, "Only support int 32"); + } + + std::vector<int32_t> reduction_values; + + for (uint32_t i = 0; i < reduction_indices->size<S32>(); ++i) + { + int32_t axis = reduction_indices->at<S32>(i); + if (axis < 0) + axis += input_shape.rank(); + if (not(0 <= axis and axis < static_cast<int32_t>(input_shape.rank()))) + INTERNAL_EXN_V("Invalid reduction axis for MEAN", oops::to_uint32(axis)); + reduction_values.push_back(axis); + } + + loco::TensorShape output_shape; + + if (node->keep_dims()) + { + output_shape.rank(input_shape.rank()); + for (uint32_t i = 0; i < input_shape.rank(); ++i) + output_shape.dim(i) = input_shape.dim(i); + for (uint32_t i = 0; i < reduction_values.size(); ++i) + output_shape.dim(reduction_values.at(i)) = 1; + } + else + { + std::vector<bool> check_reduce(input_shape.rank(), false); + for (uint32_t i = 0; i < reduction_values.size(); ++i) + check_reduce.at(reduction_values.at(i)) = true; + + uint32_t reduce_cnt = 0; + for (uint32_t i = 0; i < check_reduce.size(); ++i) + if (check_reduce.at(i)) + ++reduce_cnt; + + output_shape.rank(input_shape.rank() - reduce_cnt); + for (uint32_t i = 0, j = 0; i < check_reduce.size(); ++i) + if (check_reduce.at(i) == false) + output_shape.dim(j++) = i; + } + + return loco::NodeShape{output_shape}; + } + + loco::NodeShape visit(const locoex::TFLMul *node) final + { + auto x_shape = loco::shape_get(node->x()).as<loco::TensorShape>(); + auto y_shape = loco::shape_get(node->y()).as<loco::TensorShape>(); + + auto output_shape = broadcast_shape(x_shape, y_shape); + + return loco::NodeShape{output_shape}; + } + + loco::NodeShape visit(const locoex::TFLRelu *node) final + { + auto input_shape = loco::shape_get(node->features()).as<loco::TensorShape>(); + + return loco::NodeShape{input_shape}; + } + + loco::NodeShape visit(const locoex::TFLRelu6 *node) final + { + auto input_shape = loco::shape_get(node->features()).as<loco::TensorShape>(); + + return loco::NodeShape{input_shape}; + } + + /** + * @note TFLReshape has new shape info in two places: 2nd input and attribute. + * This shape inference forces both to exist, and match each other. 
+ * When this condition satisfied, it return the inferred shape + * + * TODO Change this policy when not appropriate + */ + loco::NodeShape visit(const locoex::TFLReshape *node) final + { + const loco::DataType S32 = loco::DataType::S32; + + loco::TensorShape shape_by_input; + { + EXO_ASSERT(node->shape(), "2nd input shape() should not be nullptr"); + + // Only support node's shape() is TFLConst with S32 + // TODO support other node with other types + auto const_shape_node = dynamic_cast<locoex::TFLConst *>(node->shape()); + EXO_ASSERT(const_shape_node, "Only support TFLConst for shape of TFLReshape"); + EXO_ASSERT(const_shape_node->dtype() == S32, "Only support int32 TFLConst"); + + if (const_shape_node->rank() != 1) + INTERNAL_EXN_V("Only support rank 1 TFLConst", oops::to_uint32(const_shape_node->rank())); + + shape_by_input.rank(const_shape_node->dim(0).value()); + + for (uint32_t axis = 0; axis < shape_by_input.rank(); ++axis) + { + EXO_ASSERT(const_shape_node->at<S32>(axis) > 0, "Dimension should be > 0") + shape_by_input.dim(axis) = const_shape_node->at<S32>(axis); + } + } + + loco::TensorShape shape_by_attr; + { + shape_by_attr.rank(node->newShape()->rank()); + + for (uint32_t axis = 0; axis < shape_by_attr.rank(); ++axis) + { + EXO_ASSERT(node->newShape()->dim(axis) > 0, "Dimension should be > 0") + shape_by_attr.dim(axis) = node->newShape()->dim(axis); + } + } + + EXO_ASSERT(shape_by_input == shape_by_attr, + "Warning: Two new shape information mismatched for TFLReshape"); + + return loco::NodeShape{shape_by_input}; + } + + loco::NodeShape visit(const locoex::TFLRsqrt *node) final + { + auto input_shape = loco::shape_get(node->x()).as<loco::TensorShape>(); + + return loco::NodeShape{input_shape}; + } + + // TODO TFLSoftmax + + loco::NodeShape visit(const locoex::TFLSqrt *node) final + { + auto input_shape = loco::shape_get(node->x()).as<loco::TensorShape>(); + + return loco::NodeShape{input_shape}; + } + + loco::NodeShape visit(const locoex::TFLSquaredDifference *node) final + { + auto x_shape = loco::shape_get(node->x()).as<loco::TensorShape>(); + auto y_shape = loco::shape_get(node->y()).as<loco::TensorShape>(); + + auto output_shape = broadcast_shape(x_shape, y_shape); + + return loco::NodeShape{output_shape}; + } + + loco::NodeShape visit(const locoex::TFLSub *node) final + { + auto x_shape = loco::shape_get(node->x()).as<loco::TensorShape>(); + auto y_shape = loco::shape_get(node->y()).as<loco::TensorShape>(); + + auto output_shape = broadcast_shape(x_shape, y_shape); + + return loco::NodeShape{output_shape}; + } + + // TODO TFLTanh + + /// @brief Returns output shape of transpose. Use loco::ConstGen and locoex::TFLConst for ConstT. 
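/// @note Worked example taken from the TFLTranspose_simple test below: with
///       input_shape (10, 20, 30, 40) and perm = {2, 3, 0, 1}, output axis
///       perm[i] receives input dim i, so the inferred shape is (30, 40, 10, 20).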
+ template <class ConstT> + loco::TensorShape output_shape_of_transpose(loco::TensorShape input_shape, + const ConstT *perm_node) + { + loco::TensorShape output_shape; + output_shape.rank(input_shape.rank()); + + assert(perm_node->dtype() == loco::DataType::S32); + assert(input_shape.rank() == perm_node->template size<loco::DataType::S32>()); + + for (uint32_t out_axis = 0; out_axis < output_shape.rank(); out_axis++) + { + auto new_dim = perm_node->template at<loco::DataType::S32>(out_axis); + output_shape.dim(new_dim) = input_shape.dim(out_axis); + } + + return output_shape; + } + + loco::NodeShape visit(const locoex::TFLTranspose *node) final + { + auto input_shape = loco::shape_get(node->a()).as<loco::TensorShape>(); + + auto canon_perm = dynamic_cast<loco::ConstGen *>(node->perm()); + auto tfl_perm = dynamic_cast<locoex::TFLConst *>(node->perm()); + + if (canon_perm) + { + return loco::NodeShape{output_shape_of_transpose(input_shape, canon_perm)}; + } + else if (tfl_perm) + { + return loco::NodeShape{output_shape_of_transpose(input_shape, tfl_perm)}; + } + else + INTERNAL_EXN("perm of TFLTranspose should be either ConstGen or TFLConst"); + } + + loco::NodeShape visit(const locoex::TFLTransposeConv *node) final + { + // TransposeConv's output shape is written in its 'inputSizes' argument + auto input_sizes_const = dynamic_cast<locoex::TFLConst *>(node->inputSizes()); + EXO_ASSERT(input_sizes_const, "Only support when TFLTransposeConv's inputSizes is TFLConst") + EXO_ASSERT(input_sizes_const->dtype() == loco::DataType::S32, "Only support S32 dtype") + EXO_ASSERT(input_sizes_const->rank() == 1 && input_sizes_const->dim(0).value() == 4, + "Only support rank 1 with 4 entries") + + loco::TensorShape shape; + + shape.rank(4); + for (uint32_t axis = 0; axis < 4; ++axis) + shape.dim(axis) = input_sizes_const->at<loco::DataType::S32>(axis); + + return loco::NodeShape{shape}; + } +}; + +} // namespace + +namespace locoex +{ + +bool TFLShapeInferenceRule::recognize(const loco::Dialect *d) const +{ + return TFLDialect::get() == d; +} + +bool TFLShapeInferenceRule::infer(const loco::Node *node, loco::NodeShape &shape) const +{ + assert(node->dialect() == TFLDialect::get()); + assert(dynamic_cast<const TFLNode *>(node) != nullptr); + + ShapeInferenceAlgorithm alg; + shape = dynamic_cast<const TFLNode *>(node)->accept(&alg); + + return true; +} + +} // namespace locoex diff --git a/compiler/exo/src/Dialect/Service/TFLShapeInferenceRule.h b/compiler/exo/src/Dialect/Service/TFLShapeInferenceRule.h new file mode 100644 index 000000000..434a145cc --- /dev/null +++ b/compiler/exo/src/Dialect/Service/TFLShapeInferenceRule.h @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef __LOCOEX_SERVICE_TFLSHAPE_INFERENCE_RULE_H__ +#define __LOCOEX_SERVICE_TFLSHAPE_INFERENCE_RULE_H__ + +#include <loco/Service/ShapeInference.h> + +namespace locoex +{ + +struct TFLShapeInferenceRule final : public loco::ShapeInferenceRule +{ + bool recognize(const loco::Dialect *) const final; + bool infer(const loco::Node *, loco::NodeShape &) const final; +}; + +} // namespace locoex + +#endif // __LOCOEX_SERVICE_TFLSHAPE_INFERENCE_RULE_H__ diff --git a/compiler/exo/src/Dialect/Service/TFLShapeInferenceRule.test.cpp b/compiler/exo/src/Dialect/Service/TFLShapeInferenceRule.test.cpp new file mode 100644 index 000000000..35c8f0b2a --- /dev/null +++ b/compiler/exo/src/Dialect/Service/TFLShapeInferenceRule.test.cpp @@ -0,0 +1,277 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "TestGraph.h" + +#include "Dialect/IR/TFLNodes.h" +#include "Dialect/IR/TFLDialect.h" +#include "Dialect/Service/TFLShapeInferenceRule.h" + +#include <loco.h> +#include <loco/IR/CanonicalDialect.h> +#include <loco/Service/ShapeInference.h> +#include <loco/Service/CanonicalShapeInferenceRule.h> +#include <loco/Service/MultiDialectShapeInferenceRule.h> + +#include <stdex/Memory.h> + +#include <gtest/gtest.h> + +TEST(TFLShapeInferenceRuleTest, minimal_with_TFLRelu) +{ + // Create a simple network + exo::test::TestGraph graph; + auto tfl_node = graph.append<locoex::TFLRelu>(graph.pull); + graph.complete(tfl_node); + + // set shape + { + graph.pull->rank(2); + graph.pull->dim(0) = 3; + graph.pull->dim(1) = 4; + } + + // pre-check + ASSERT_FALSE(loco::shape_known(tfl_node)); + + // shape inference + locoex::TFLShapeInferenceRule tfl_rule; + loco::CanonicalShapeInferenceRule canonical_rule; + loco::MultiDialectShapeInferenceRule rules; + + rules.bind(loco::CanonicalDialect::get(), &canonical_rule) + .bind(locoex::TFLDialect::get(), &tfl_rule); + + loco::apply(&rules).to(graph.g.get()); + + // Verify + { + ASSERT_TRUE(loco::shape_known(tfl_node)); + ASSERT_EQ(loco::shape_get(tfl_node).domain(), loco::Domain::Tensor); + + auto shape = loco::shape_get(tfl_node).as<loco::TensorShape>(); + ASSERT_EQ(shape.rank(), 2); + ASSERT_EQ(shape.dim(0), 3); + ASSERT_EQ(shape.dim(1), 4); + } +} + +// based on the case shown in +// https://www.corvil.com/kb/what-is-the-difference-between-same-and-valid-padding-in-tf-nn-max-pool-of-tensorflow +TEST(TFLShapeInferenceRuleTest, avgpool2d_valid) +{ + exo::test::TestGraph graph; + auto tfl_node = graph.append<locoex::TFLAveragePool2D>(graph.pull); + graph.complete(); + + auto pull = graph.pull; + { + pull->shape({1, 4, 3, 1}); + } + // setting TFLAveragePool2D + { + tfl_node->filter()->h(2); + tfl_node->filter()->w(2); + tfl_node->stride()->h(2); + tfl_node->stride()->w(2); + tfl_node->fusedActivationFunction(locoex::FusedActFunc::NONE); + tfl_node->padding(locoex::Padding::VALID); + } + ASSERT_FALSE(loco::shape_known(tfl_node)); + + // shape inference + 
locoex::TFLShapeInferenceRule tfl_rule; + loco::CanonicalShapeInferenceRule canonical_rule; + loco::MultiDialectShapeInferenceRule rules; + + rules.bind(loco::CanonicalDialect::get(), &canonical_rule) + .bind(locoex::TFLDialect::get(), &tfl_rule); + + loco::apply(&rules).to(graph.g.get()); + + // Verify + { + ASSERT_TRUE(loco::shape_known(tfl_node)); + ASSERT_EQ(loco::shape_get(tfl_node).domain(), loco::Domain::Tensor); + + auto shape = loco::shape_get(tfl_node).as<loco::TensorShape>(); + ASSERT_EQ(shape.rank(), 4); + ASSERT_EQ(shape.dim(0).value(), 1); + ASSERT_EQ(shape.dim(1).value(), 2); + ASSERT_EQ(shape.dim(2).value(), 1); + ASSERT_EQ(shape.dim(3).value(), 1); + } +} + +TEST(TFLShapeInferenceRuleTest, avgpool2d_same) +{ + exo::test::TestGraph graph; + auto tfl_node = graph.append<locoex::TFLAveragePool2D>(graph.pull); + graph.complete(); + + auto pull = graph.pull; + { + pull->shape({1, 4, 3, 1}); + } + + // setting TFLAveragePool2D + { + tfl_node->filter()->h(2); + tfl_node->filter()->w(2); + tfl_node->stride()->h(2); + tfl_node->stride()->w(2); + tfl_node->fusedActivationFunction(locoex::FusedActFunc::NONE); + tfl_node->padding(locoex::Padding::SAME); + } + + ASSERT_FALSE(loco::shape_known(tfl_node)); + + // shape inference + locoex::TFLShapeInferenceRule tfl_rule; + loco::CanonicalShapeInferenceRule canonical_rule; + loco::MultiDialectShapeInferenceRule rules; + + rules.bind(loco::CanonicalDialect::get(), &canonical_rule) + .bind(locoex::TFLDialect::get(), &tfl_rule); + + loco::apply(&rules).to(graph.g.get()); + + // Verify + { + ASSERT_TRUE(loco::shape_known(tfl_node)); + ASSERT_EQ(loco::shape_get(tfl_node).domain(), loco::Domain::Tensor); + + auto shape = loco::shape_get(tfl_node).as<loco::TensorShape>(); + ASSERT_EQ(shape.rank(), 4); + ASSERT_EQ(shape.dim(0).value(), 1); + ASSERT_EQ(shape.dim(1).value(), 2); + ASSERT_EQ(shape.dim(2).value(), 2); + ASSERT_EQ(shape.dim(3).value(), 1); + } +} + +/** + * @note Function to test: Shape inference of two different input shapes + * + * Rank expansion to higher input side + * x(2,1,5) + y(3,5) --> x(2,1,5) + y(1,3,5) + * Do output shape inference like numpy + * x(2,1,5) + y(1,3,5) --> output(2,3,5) + * For each axis, dim value should be same OR one of them should be 1 + */ +TEST(TFLShapeInferenceRuleTest, TFAdd_shapeinf_different) +{ + auto g = loco::make_graph(); + + auto x_node = g->nodes()->create<loco::Pull>(); + { + x_node->rank(3); + x_node->dim(0) = 2; + x_node->dim(1) = 1; + x_node->dim(2) = 5; + } + auto y_node = g->nodes()->create<loco::Pull>(); + { + y_node->rank(2); + y_node->dim(0) = 3; + y_node->dim(1) = 5; + } + auto tfl_node = g->nodes()->create<locoex::TFLAdd>(); + { + tfl_node->x(x_node); + tfl_node->y(y_node); + } + auto push_node = g->nodes()->create<loco::Push>(); + { + push_node->from(tfl_node); + } + + auto x_input = g->inputs()->create(); + { + x_input->name("x"); + loco::link(x_input, x_node); + } + auto y_input = g->inputs()->create(); + { + y_input->name("y"); + loco::link(y_input, y_node); + } + auto output = g->outputs()->create(); + { + output->name("output"); + loco::link(output, push_node); + } + + // pre-check + ASSERT_FALSE(loco::shape_known(tfl_node)); + + exo::ShapeInferencePass pass; + while (pass.run(g.get()) == true) + { + ; + } + + // Verify + { + ASSERT_TRUE(loco::shape_known(tfl_node)); + ASSERT_EQ(loco::shape_get(tfl_node).domain(), loco::Domain::Tensor); + + auto shape = loco::shape_get(tfl_node).as<loco::TensorShape>(); + ASSERT_EQ(shape.rank(), 3); + ASSERT_EQ(shape.dim(0), 2); + 
ASSERT_EQ(shape.dim(1), 3); + ASSERT_EQ(shape.dim(2), 5); + } +} + +TEST(TFLShapeInferenceRuleTest, TFLTranspose_simple) +{ + exo::test::ExampleGraph<exo::test::ExampleGraphType::TFLTranspose> g; + + g.pull->rank(4); + g.pull->dim(0) = 10; + g.pull->dim(1) = 20; + g.pull->dim(2) = 30; + g.pull->dim(3) = 40; + + g.const_perm->dtype(loco::DataType::S32); + g.const_perm->rank(1); + g.const_perm->dim(0) = 4; + g.const_perm->size<loco::DataType::S32>(4); + g.const_perm->at<loco::DataType::S32>(0) = 2; + g.const_perm->at<loco::DataType::S32>(1) = 3; + g.const_perm->at<loco::DataType::S32>(2) = 0; + g.const_perm->at<loco::DataType::S32>(3) = 1; + + // pre-check + ASSERT_FALSE(loco::shape_known(g.tfl_transpose)); + + exo::ShapeInferencePass pass; + while (pass.run(g.graph()) == true) + ; + + // Verify + { + ASSERT_TRUE(loco::shape_known(g.tfl_transpose)); + + auto shape = loco::shape_get(g.tfl_transpose).as<loco::TensorShape>(); + ASSERT_EQ(shape.rank(), 4); + ASSERT_EQ(shape.dim(0), 30); + ASSERT_EQ(shape.dim(1), 40); + ASSERT_EQ(shape.dim(2), 10); + ASSERT_EQ(shape.dim(3), 20); + } +} diff --git a/compiler/exo/src/Dialect/Service/TFLTypeInferenceRule.cpp b/compiler/exo/src/Dialect/Service/TFLTypeInferenceRule.cpp new file mode 100644 index 000000000..3f123a6db --- /dev/null +++ b/compiler/exo/src/Dialect/Service/TFLTypeInferenceRule.cpp @@ -0,0 +1,141 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "TFLTypeInferenceRule.h" + +#include "Dialect/IR/TFLDialect.h" +#include "Dialect/IR/TFLNodeVisitor.h" +#include "Dialect/IR/TFLNodes.h" + +#include <cassert> + +namespace +{ + +struct TypeInferenceAlgorithm final : public locoex::TFLNodeVisitor<loco::DataType> +{ + loco::DataType visit(const locoex::TFLAdd *node) final { return loco::dtype_get(node->x()); } + + loco::DataType visit(const locoex::TFLAveragePool2D *node) final + { + return loco::dtype_get(node->value()); + } + + loco::DataType visit(const locoex::TFLConcatenation *node) final + { + // TODO Support when TFLConcatenation has 0 input + assert(node->numValues() > 0); + + for (uint32_t i = 1; i < node->numValues(); ++i) + assert(loco::dtype_get(node->values(i - 1)) == loco::dtype_get(node->values(i))); + + return loco::dtype_get(node->values(0)); + } + + loco::DataType visit(const locoex::TFLConst *node) final { return node->dtype(); } + + loco::DataType visit(const locoex::TFLConv2D *node) final + { + return loco::dtype_get(node->input()); + } + + loco::DataType visit(const locoex::TFLDepthwiseConv2D *node) final + { + return loco::dtype_get(node->input()); + } + + loco::DataType visit(const locoex::TFLDiv *node) final { return loco::dtype_get(node->x()); } + + loco::DataType visit(const locoex::TFLFullyConnected *node) final + { + return loco::dtype_get(node->input()); + } + + loco::DataType visit(const locoex::TFLMaximum *node) final { return loco::dtype_get(node->x()); } + + loco::DataType visit(const locoex::TFLMaxPool2D *node) final + { + return loco::dtype_get(node->value()); + } + + loco::DataType visit(const locoex::TFLMean *node) final { return loco::dtype_get(node->input()); } + + loco::DataType visit(const locoex::TFLMul *node) final { return loco::dtype_get(node->x()); } + + loco::DataType visit(const locoex::TFLRelu *node) final + { + return loco::dtype_get(node->features()); + } + + loco::DataType visit(const locoex::TFLRelu6 *node) final + { + return loco::dtype_get(node->features()); + } + + loco::DataType visit(const locoex::TFLReshape *node) final + { + return loco::dtype_get(node->tensor()); + } + + loco::DataType visit(const locoex::TFLRsqrt *node) final { return loco::dtype_get(node->x()); } + + // TODO TFLSoftmax + + loco::DataType visit(const locoex::TFLSqrt *node) final { return loco::dtype_get(node->x()); } + + loco::DataType visit(const locoex::TFLSquaredDifference *node) final + { + return loco::dtype_get(node->x()); + } + + loco::DataType visit(const locoex::TFLSub *node) final { return loco::dtype_get(node->x()); } + + // TODO TFLTanh + + loco::DataType visit(const locoex::TFLTranspose *node) final + { + return loco::dtype_get(node->a()); + } + + loco::DataType visit(const locoex::TFLTransposeConv *node) final + { + return loco::dtype_get(node->outBackprop()); + } +}; + +} // namespace + +namespace locoex +{ + +bool TFLTypeInferenceRule::recognize(const loco::Dialect *d) const +{ + return TFLDialect::get() == d; +} + +bool TFLTypeInferenceRule::infer(const loco::Node *node, loco::DataType &dtype) const +{ + assert(node->dialect() == TFLDialect::get()); + + TypeInferenceAlgorithm alg; + + dtype = dynamic_cast<const TFLNode *>(node)->accept(&alg); + assert(dtype != loco::DataType::Unknown); + + return true; +} + +} // namespace locoex diff --git a/compiler/exo/src/Dialect/Service/TFLTypeInferenceRule.h b/compiler/exo/src/Dialect/Service/TFLTypeInferenceRule.h new file mode 100644 index 000000000..31765dcba --- /dev/null +++ b/compiler/exo/src/Dialect/Service/TFLTypeInferenceRule.h @@ 
-0,0 +1,37 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LOCOEX_SERVICE_TFLTYPE_INFERENCE_RULE_H__ +#define __LOCOEX_SERVICE_TFLTYPE_INFERENCE_RULE_H__ + +#include <loco/Service/TypeInference.h> + +namespace locoex +{ + +/** + * @brief Type Inference Rule for TFLDialect + */ +struct TFLTypeInferenceRule final : public loco::TypeInferenceRule +{ + bool recognize(const loco::Dialect *) const final; + + bool infer(const loco::Node *, loco::DataType &) const final; +}; + +} // namespace locoex + +#endif // __LOCOEX_SERVICE_TFLTYPE_INFERENCE_RULE_H__ diff --git a/compiler/exo/src/Dialect/Service/TFLTypeInferenceRule.test.cpp b/compiler/exo/src/Dialect/Service/TFLTypeInferenceRule.test.cpp new file mode 100644 index 000000000..dd1f93c4d --- /dev/null +++ b/compiler/exo/src/Dialect/Service/TFLTypeInferenceRule.test.cpp @@ -0,0 +1,57 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "Dialect/IR/TFLNodes.h" +#include "Dialect/IR/TFLDialect.h" +#include "Dialect/Service/TFLTypeInferenceRule.h" + +#include "TestGraph.h" + +#include <loco.h> +#include <loco/IR/CanonicalDialect.h> +#include <loco/Service/TypeInference.h> + +#include <stdex/Memory.h> + +#include <gtest/gtest.h> + +TEST(TFLTypeInferenceRuleTest, minimal_with_TFLRelu) +{ + // Create a simple network + exo::test::TestGraph graph; + auto tfl_node = graph.append<locoex::TFLRelu>(graph.pull); + graph.complete(tfl_node); + + graph.pull->dtype(loco::DataType::S32); + + // pre-check + ASSERT_FALSE(loco::dtype_known(tfl_node)); + + // type inference + locoex::TFLTypeInferenceRule tfl_rule; + loco::CanonicalTypeInferenceRule canon_rule; + loco::MultiDialectTypeInferenceRule rules; + + rules.bind(loco::CanonicalDialect::get(), &canon_rule); + rules.bind(locoex::TFLDialect::get(), &tfl_rule); + + loco::apply(&rules).to(graph.g.get()); + + // Verify + ASSERT_TRUE(loco::dtype_known(tfl_node)); + auto type = loco::dtype_get(tfl_node); + ASSERT_EQ(type, loco::DataType::S32); +} diff --git a/compiler/exo/src/ExoFormattedGraph.cpp b/compiler/exo/src/ExoFormattedGraph.cpp new file mode 100644 index 000000000..5d3b18be1 --- /dev/null +++ b/compiler/exo/src/ExoFormattedGraph.cpp @@ -0,0 +1,525 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. 
All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ExoFormattedGraph.h" + +#include "Dialect/IR/TFLDialect.h" +#include "Dialect/IR/TFLNodes.h" + +#include "Dialect/IR/CircleDialect.h" +#include "Dialect/IR/CircleNodes.h" + +#include <locoex/Service/COpFormattedGraph.h> +#include <pepper/str.h> + +#include <sstream> +#include <cassert> + +// For TF lite +namespace +{ + +const char *to_str(locoex::FusedActFunc fused) +{ + switch (fused) + { + case locoex::FusedActFunc::NONE: + return "NONE"; + case locoex::FusedActFunc::RELU: + return "RELU"; + case locoex::FusedActFunc::RELU6: + return "RELU6"; + default: + return "Error"; + } +} + +const char *to_str(locoex::Padding padding) +{ + switch (padding) + { + case locoex::Padding::SAME: + return "SAME"; + case locoex::Padding::VALID: + return "VALID"; + default: + return "Error"; + } +} + +std::string to_str(const locoex::Stride *stride) +{ + return pepper::str(stride->h(), ",", stride->w()); +} + +std::string to_str(const locoex::Filter *filter) +{ + return pepper::str(filter->h(), ",", filter->w()); +} + +std::string tfl_opname(uint32_t opnum) +{ + static std::string prefix{"tfl."}; + + switch (static_cast<locoex::TFLOpcode>(opnum)) + { +#define TFL_NODE(OPCODE, CLASS) \ + case locoex::TFLOpcode::OPCODE: \ + return prefix + #OPCODE; +#include "Dialect/IR/TFLNodes.lst" +#undef TFL_NODE + default: + break; + }; + + return prefix + "Invalid"; +} + +// TFLNodeSummaryBuilder with default implementation +class TFLNodeSummaryBuilderBase : public locop::NodeSummaryBuilder +{ +public: + TFLNodeSummaryBuilderBase(const locop::SymbolTable *tbl) : _tbl{tbl} + { + // DO NOTHING + } + +public: + bool build(const loco::Node *, locop::NodeSummary &s) const final; + +protected: +#define TFL_NODE(OPCODE, CLASS) \ + virtual bool summary(const CLASS *, locop::NodeSummary &s) const \ + { \ + s.comments().append("Emitted by Default TFLNodeSummaryBuilder"); \ + s.state(locop::NodeSummary::State::PartiallyKnown); \ + return true; \ + } +#include "Dialect/IR/TFLNodes.lst" +#undef TFL_NODE + +protected: + const locop::SymbolTable *tbl(void) const { return _tbl; } + + // Please do not use _tbl directly and use tbl(). + // This will be changed to private in near future. 
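+  //
+  // For reference, the TFL_NODE macro above expands to one default summary()
+  // overload per entry in TFLNodes.lst; e.g. for an entry like
+  // TFL_NODE(ADD, locoex::TFLAdd) (an assumed .lst entry), it produces roughly:
+  //
+  //   virtual bool summary(const locoex::TFLAdd *, locop::NodeSummary &s) const
+  //   {
+  //     s.comments().append("Emitted by Default TFLNodeSummaryBuilder");
+  //     s.state(locop::NodeSummary::State::PartiallyKnown);
+  //     return true;
+  //   }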
+protected: + const locop::SymbolTable *_tbl; +}; + +class TFLNodeSummaryBuilder final : public TFLNodeSummaryBuilderBase +{ +public: + TFLNodeSummaryBuilder(const locop::SymbolTable *tbl) : TFLNodeSummaryBuilderBase(tbl) + { + // DO NOTHING + } + +private: +#define IMPLEMENT(CLASS) bool summary(const CLASS *, locop::NodeSummary &) const final; + IMPLEMENT(locoex::TFLAdd) + IMPLEMENT(locoex::TFLAveragePool2D) + IMPLEMENT(locoex::TFLConcatenation) + IMPLEMENT(locoex::TFLConst) + IMPLEMENT(locoex::TFLConv2D) + IMPLEMENT(locoex::TFLDepthwiseConv2D) + IMPLEMENT(locoex::TFLDiv) + IMPLEMENT(locoex::TFLMaximum) + IMPLEMENT(locoex::TFLMaxPool2D) + IMPLEMENT(locoex::TFLMean) + IMPLEMENT(locoex::TFLMul) + IMPLEMENT(locoex::TFLRelu) + IMPLEMENT(locoex::TFLRelu6) + IMPLEMENT(locoex::TFLReshape) + IMPLEMENT(locoex::TFLRsqrt) + IMPLEMENT(locoex::TFLSqrt) + IMPLEMENT(locoex::TFLSquaredDifference) + IMPLEMENT(locoex::TFLSub) + IMPLEMENT(locoex::TFLTranspose) + IMPLEMENT(locoex::TFLTransposeConv) +#undef IMPLEMENT +}; + +bool TFLNodeSummaryBuilderBase::build(const loco::Node *node, locop::NodeSummary &s) const +{ + if (node->dialect() != locoex::TFLDialect::get()) + return false; + +#define TFL_NODE(OPCODE, CLASS) \ + if (dynamic_cast<const CLASS *>(node)) \ + { \ + s.opname(tfl_opname(node->opnum())); \ + return summary(dynamic_cast<const CLASS *>(node), s); \ + } +#include "Dialect/IR/TFLNodes.lst" +#undef TFL_NODE + + return false; +} + +bool TFLNodeSummaryBuilder::summary(const locoex::TFLAdd *node, locop::NodeSummary &s) const +{ + assert(node->fusedActivationFunction() != locoex::FusedActFunc::UNDEFINED); + + s.args().append("x", tbl()->lookup(node->x())); + s.args().append("y", tbl()->lookup(node->y())); + s.args().append("fused_activation_function", to_str(node->fusedActivationFunction())); + s.state(locop::NodeSummary::State::Complete); + return true; +} + +bool TFLNodeSummaryBuilder::summary(const locoex::TFLAveragePool2D *node, + locop::NodeSummary &s) const +{ + assert(node->fusedActivationFunction() != locoex::FusedActFunc::UNDEFINED); + + s.args().append("value", tbl()->lookup(node->value())); + s.args().append("filter(h,w)", to_str(node->filter())); + s.args().append("stride(h,w)", to_str(node->stride())); + s.args().append("padding", to_str(node->padding())); + s.args().append("fused", to_str(node->fusedActivationFunction())); + + s.state(locop::NodeSummary::State::Complete); + + return true; +} + +bool TFLNodeSummaryBuilder::summary(const locoex::TFLConcatenation *node, + locop::NodeSummary &s) const +{ + assert(node->fusedActivationFunction() != locoex::FusedActFunc::UNDEFINED); + + for (uint32_t i = 0; i < node->numValues(); ++i) + s.args().append("values", tbl()->lookup(node->values(i))); + s.args().append("axis", pepper::str(node->axis())); + s.args().append("fused", to_str(node->fusedActivationFunction())); + s.state(locop::NodeSummary::State::Complete); + return true; +} + +bool TFLNodeSummaryBuilder::summary(const locoex::TFLConst *, locop::NodeSummary &s) const +{ + s.state(locop::NodeSummary::State::PartiallyKnown); + return true; +} + +bool TFLNodeSummaryBuilder::summary(const locoex::TFLConv2D *node, locop::NodeSummary &s) const +{ + assert(node->fusedActivationFunction() != locoex::FusedActFunc::UNDEFINED); + assert(node->padding() != locoex::Padding::UNDEFINED); + + s.args().append("input", tbl()->lookup(node->input())); + s.args().append("filter", tbl()->lookup(node->filter())); + s.args().append("bias", tbl()->lookup(node->bias())); + + s.args().append("stride(h,w)", 
to_str(node->stride())); + s.args().append("padding", to_str(node->padding())); + s.args().append("fused", to_str(node->fusedActivationFunction())); + + s.state(locop::NodeSummary::State::Complete); + + return true; +} + +bool TFLNodeSummaryBuilder::summary(const locoex::TFLDepthwiseConv2D *node, + locop::NodeSummary &s) const +{ + assert(node->fusedActivationFunction() != locoex::FusedActFunc::UNDEFINED); + assert(node->padding() != locoex::Padding::UNDEFINED); + + s.args().append("input", tbl()->lookup(node->input())); + s.args().append("filter", tbl()->lookup(node->filter())); + s.args().append("bias", tbl()->lookup(node->bias())); + + s.args().append("stride(h,w)", to_str(node->stride())); + s.args().append("padding", to_str(node->padding())); + s.args().append("depthMultiplier", std::to_string(node->depthMultiplier())); + s.args().append("fused", to_str(node->fusedActivationFunction())); + + s.state(locop::NodeSummary::State::Complete); + + return true; +} + +bool TFLNodeSummaryBuilder::summary(const locoex::TFLDiv *node, locop::NodeSummary &s) const +{ + s.args().append("x", tbl()->lookup(node->x())); + s.args().append("y", tbl()->lookup(node->y())); + s.state(locop::NodeSummary::State::Complete); + return true; +} + +bool TFLNodeSummaryBuilder::summary(const locoex::TFLMaximum *node, locop::NodeSummary &s) const +{ + s.args().append("x", tbl()->lookup(node->x())); + s.args().append("y", tbl()->lookup(node->y())); + s.state(locop::NodeSummary::State::Complete); + return true; +} + +bool TFLNodeSummaryBuilder::summary(const locoex::TFLMaxPool2D *node, locop::NodeSummary &s) const +{ + assert(node->fusedActivationFunction() != locoex::FusedActFunc::UNDEFINED); + + s.args().append("value", tbl()->lookup(node->value())); + s.args().append("filter(h,w)", to_str(node->filter())); + s.args().append("stride(h,w)", to_str(node->stride())); + s.args().append("padding", to_str(node->padding())); + s.args().append("fused", to_str(node->fusedActivationFunction())); + + s.state(locop::NodeSummary::State::Complete); + + return true; +} + +bool TFLNodeSummaryBuilder::summary(const locoex::TFLMean *node, locop::NodeSummary &s) const +{ + s.args().append("input", tbl()->lookup(node->input())); + s.args().append("reduction_indices", tbl()->lookup(node->reduction_indices())); + s.args().append("keep_dims", node->keep_dims() ? 
"true" : "false"); + s.state(locop::NodeSummary::State::Complete); + return true; +} + +bool TFLNodeSummaryBuilder::summary(const locoex::TFLMul *node, locop::NodeSummary &s) const +{ + assert(node->fusedActivationFunction() != locoex::FusedActFunc::UNDEFINED); + + s.args().append("x", tbl()->lookup(node->x())); + s.args().append("y", tbl()->lookup(node->y())); + s.args().append("fused_activation_function", to_str(node->fusedActivationFunction())); + s.state(locop::NodeSummary::State::Complete); + return true; +} + +bool TFLNodeSummaryBuilder::summary(const locoex::TFLRelu *node, locop::NodeSummary &s) const +{ + s.args().append("features", tbl()->lookup(node->features())); + s.state(locop::NodeSummary::State::Complete); + return true; +} + +bool TFLNodeSummaryBuilder::summary(const locoex::TFLRelu6 *node, locop::NodeSummary &s) const +{ + s.args().append("features", tbl()->lookup(node->features())); + s.state(locop::NodeSummary::State::Complete); + return true; +} + +bool TFLNodeSummaryBuilder::summary(const locoex::TFLReshape *node, locop::NodeSummary &s) const +{ + s.args().append("tensor", tbl()->lookup(node->tensor())); + s.args().append("shape", tbl()->lookup(node->shape())); + // TODO Show newShape info + s.state(locop::NodeSummary::State::PartiallyKnown); + return true; +} + +bool TFLNodeSummaryBuilder::summary(const locoex::TFLRsqrt *node, locop::NodeSummary &s) const +{ + s.args().append("x", tbl()->lookup(node->x())); + s.state(locop::NodeSummary::State::Complete); + return true; +} + +// TODO TFLSoftmax + +bool TFLNodeSummaryBuilder::summary(const locoex::TFLSqrt *node, locop::NodeSummary &s) const +{ + s.args().append("x", tbl()->lookup(node->x())); + s.state(locop::NodeSummary::State::Complete); + return true; +} + +bool TFLNodeSummaryBuilder::summary(const locoex::TFLSquaredDifference *node, + locop::NodeSummary &s) const +{ + s.args().append("x", tbl()->lookup(node->x())); + s.args().append("y", tbl()->lookup(node->y())); + s.state(locop::NodeSummary::State::Complete); + return true; +} + +bool TFLNodeSummaryBuilder::summary(const locoex::TFLSub *node, locop::NodeSummary &s) const +{ + s.args().append("x", tbl()->lookup(node->x())); + s.args().append("y", tbl()->lookup(node->y())); + s.state(locop::NodeSummary::State::Complete); + return true; +} + +// TODO TFLTanh + +bool TFLNodeSummaryBuilder::summary(const locoex::TFLTranspose *node, locop::NodeSummary &s) const +{ + s.args().append("a", tbl()->lookup(node->a())); + s.args().append("perm", tbl()->lookup(node->perm())); + s.state(locop::NodeSummary::State::Complete); + return true; +} + +bool TFLNodeSummaryBuilder::summary(const locoex::TFLTransposeConv *node, + locop::NodeSummary &s) const +{ + assert(node->padding() != locoex::Padding::UNDEFINED); + + s.args().append("inputSizes", tbl()->lookup(node->inputSizes())); + s.args().append("filter", tbl()->lookup(node->filter())); + s.args().append("outBackprop", tbl()->lookup(node->outBackprop())); + + s.args().append("stride(h,w)", to_str(node->stride())); + s.args().append("padding", to_str(node->padding())); + + s.state(locop::NodeSummary::State::Complete); + + return true; +} + +} // namespace + +// For Circle +namespace +{ + +std::string circle_opname(uint32_t opnum) +{ + static std::string prefix{"circle."}; + + switch (static_cast<locoex::CircleOpcode>(opnum)) + { +#define CIRCLE_NODE(OPCODE, CLASS) \ + case locoex::CircleOpcode::OPCODE: \ + return prefix + #OPCODE; +#include "Dialect/IR/CircleNodes.lst" +#undef CIRCLE_NODE + default: + break; + }; + + return prefix + 
"Invalid"; +} + +// CircleNodeSummaryBuilder with default implementation +class CircleNodeSummaryBuilderBase : public locop::NodeSummaryBuilder +{ +public: + CircleNodeSummaryBuilderBase(const locop::SymbolTable *tbl) : _tbl{tbl} + { + // DO NOTHING + } + +public: + bool build(const loco::Node *, locop::NodeSummary &s) const final; + +protected: +#define CIRCLE_NODE(OPCODE, CLASS) \ + virtual bool summary(const CLASS *, locop::NodeSummary &s) const \ + { \ + s.comments().append("Emitted by Default CircleNodeSummaryBuilder"); \ + s.state(locop::NodeSummary::State::PartiallyKnown); \ + return true; \ + } +#include "Dialect/IR/CircleNodes.lst" +#undef CIRCLE_NODE + +protected: + const locop::SymbolTable *tbl(void) const { return _tbl; } + + // Please do not use _tbl directly and use tbl(). + // This will be changed to private in near future. +protected: + const locop::SymbolTable *_tbl; +}; + +class CircleNodeSummaryBuilder final : public CircleNodeSummaryBuilderBase +{ +public: + CircleNodeSummaryBuilder(const locop::SymbolTable *tbl) : CircleNodeSummaryBuilderBase(tbl) + { + // DO NOTHING + } + +private: +#define IMPLEMENT(CLASS) bool summary(const CLASS *, locop::NodeSummary &) const final; + IMPLEMENT(locoex::CircleInstanceNorm) +#undef IMPLEMENT +}; + +bool CircleNodeSummaryBuilderBase::build(const loco::Node *node, locop::NodeSummary &s) const +{ + if (node->dialect() != locoex::CircleDialect::get()) + return false; + +#define CIRCLE_NODE(OPCODE, CLASS) \ + if (dynamic_cast<const CLASS *>(node)) \ + { \ + s.opname(circle_opname(node->opnum())); \ + return summary(dynamic_cast<const CLASS *>(node), s); \ + } +#include "Dialect/IR/CircleNodes.lst" +#undef CIRCLE_NODE + + return false; +} + +bool CircleNodeSummaryBuilder::summary(const locoex::CircleInstanceNorm *node, + locop::NodeSummary &s) const +{ + auto fused = node->fusedActivationFunction(); + assert(fused != locoex::FusedActFunc::UNDEFINED); + + s.args().append("input", tbl()->lookup(node->input())); + s.args().append("gamma", tbl()->lookup(node->gamma())); + s.args().append("beta", tbl()->lookup(node->beta())); + s.args().append("epsilon", pepper::str(node->epsilon())); + s.args().append("fused_activation_function", to_str(fused)); + + s.state(locop::NodeSummary::State::Complete); + + return true; +} + +} // namespace + +namespace exo +{ + +bool NodeSummaryBuilder::build(const loco::Node *node, locop::NodeSummary &s) const +{ + if (locop::CanonicalNodeSummaryBuilder(_tbl).build(node, s)) + { + return true; + } + + if (TFLNodeSummaryBuilder(_tbl).build(node, s)) + { + return true; + } + + if (CircleNodeSummaryBuilder(_tbl).build(node, s)) + { + return true; + } + + if (locoex::COpNodeSummaryBuilder(_tbl).build(node, s)) + { + return true; + } + + return false; +} + +} // namespace exo diff --git a/compiler/exo/src/ExoFormattedGraph.h b/compiler/exo/src/ExoFormattedGraph.h new file mode 100644 index 000000000..714e483b5 --- /dev/null +++ b/compiler/exo/src/ExoFormattedGraph.h @@ -0,0 +1,56 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __EXO_FORMATTED_GRAPH_H__
+#define __EXO_FORMATTED_GRAPH_H__
+
+#include <locop/FormattedGraph.h>
+
+#include <stdex/Memory.h>
+
+namespace exo
+{
+
+class NodeSummaryBuilder final : public locop::NodeSummaryBuilder
+{
+public:
+  NodeSummaryBuilder(const locop::SymbolTable *tbl) : _tbl{tbl}
+  {
+    // DO NOTHING
+  }
+
+public:
+  bool build(const loco::Node *node, locop::NodeSummary &s) const final;
+
+private:
+  const locop::SymbolTable *_tbl;
+};
+
+class NodeSummaryBuilderFactory final : public locop::NodeSummaryBuilderFactory
+{
+public:
+  NodeSummaryBuilderFactory() = default;
+
+public:
+  std::unique_ptr<locop::NodeSummaryBuilder> create(const locop::SymbolTable *tbl) const final
+  {
+    return stdex::make_unique<NodeSummaryBuilder>(tbl);
+  }
+};
+
+} // namespace exo
+
+#endif // __EXO_FORMATTED_GRAPH_H__
diff --git a/compiler/exo/src/ExoOptimize.cpp b/compiler/exo/src/ExoOptimize.cpp
new file mode 100644
index 000000000..d7278e900
--- /dev/null
+++ b/compiler/exo/src/ExoOptimize.cpp
@@ -0,0 +1,74 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ExoOptimize.h"
+
+#include "Knob.h"
+#include "Passes.h"
+#include "ProgressReporter.h"
+
+#include <logo/Phase.h>
+
+#include <stdex/Memory.h>
+
+namespace exo
+{
+
+void optimize(loco::Graph *g)
+{
+  logo::Phase phase;
+  {
+    // prepare type and shape before optimization
+    phase.emplace_back(stdex::make_unique<TypeInferencePass>());
+    phase.emplace_back(stdex::make_unique<ShapeInferencePass>());
+
+    phase.emplace_back(stdex::make_unique<FoldReshapeOfConstPass>());
+    phase.emplace_back(stdex::make_unique<FoldTransposeOfConstPass>());
+
+    if (get<Knob::UseFuseBiasAddPass>())
+    {
+      phase.emplace_back(stdex::make_unique<FuseBiasAddPass>());
+    }
+
+    if (get<Knob::UseFuseInstanceNormPass>())
+    {
+      phase.emplace_back(stdex::make_unique<FuseInstanceNormPass>());
+    }
+
+    if (get<Knob::UseFuseReluPass>())
+    {
+      phase.emplace_back(stdex::make_unique<FuseReluPass>());
+    }
+    phase.emplace_back(stdex::make_unique<FuseRsqrtPass>());
+
+    if (get<Knob::UseFuseSquaredDifferencePass>())
+    {
+      phase.emplace_back(stdex::make_unique<FuseSquaredDifferencePass>());
+    }
+
+    phase.emplace_back(stdex::make_unique<MergeConcatNodesPass>());
+
+    phase.emplace_back(stdex::make_unique<logo::RemoveDeadNodePass>());
+  }
+
+  logo::PhaseRunner<logo::PhaseStrategy::Restart> phase_runner{g};
+
+  ProgressReporter prog(g, logo::PhaseStrategy::Restart);
+  phase_runner.attach(&prog);
+  phase_runner.run(phase);
+}
+
+} // namespace exo
diff --git a/compiler/exo/src/ExoOptimize.h b/compiler/exo/src/ExoOptimize.h
new file mode 100644
index 000000000..4769c1193
--- /dev/null
+++ b/compiler/exo/src/ExoOptimize.h
@@ -0,0 +1,34 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd.
All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __OPTIMIZE_H__ +#define __OPTIMIZE_H__ + +#include <loco.h> + +namespace exo +{ + +/** + * @brief Run passes for a graph after completion of converting canonical nodes into TFL nodes. + * + * TODO Separate optimize pass dedicated to TFL and Circle dialect when necessary + */ +void optimize(loco::Graph *); + +} // namespace exo + +#endif // __OPTIMIZE_H__ diff --git a/compiler/exo/src/ExporterUtils.cpp b/compiler/exo/src/ExporterUtils.cpp new file mode 100644 index 000000000..41ccdcd71 --- /dev/null +++ b/compiler/exo/src/ExporterUtils.cpp @@ -0,0 +1,139 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "ExporterUtils.h" + +#include <oops/InternalExn.h> + +#include <cassert> + +namespace exo +{ + +ShapeDescription to_shape_description(const loco::TensorShape &shape) +{ + ShapeDescription res; + + res._rank_known = true; + + res._dims.resize(shape.rank()); + for (uint32_t axis = 0; axis < shape.rank(); ++axis) + { + // All the dimensions SHOULD be known + assert(shape.dim(axis).known()); + res._dims.at(axis) = shape.dim(axis).value(); + } + + return res; +} + +ShapeDescription to_shape_description(const loco::FeatureShape &shape) +{ + ShapeDescription res; + + res._rank_known = true; + + // T/F Lite encodes a feature map as a NHWC tensor + res._dims.resize(4); + res._dims.at(0) = shape.count().value(); + res._dims.at(1) = shape.height().value(); + res._dims.at(2) = shape.width().value(); + res._dims.at(3) = shape.depth().value(); + + return res; +} + +ShapeDescription to_shape_description(const loco::FilterShape &shape) +{ + ShapeDescription res; + + res._rank_known = true; + + // T/F Lite encodes a convolution filter as a NHWC tensor + res._dims.resize(4); + res._dims.at(0) = shape.count().value(); + res._dims.at(1) = shape.height().value(); + res._dims.at(2) = shape.width().value(); + res._dims.at(3) = shape.depth().value(); + + return res; +} + +ShapeDescription to_shape_description(const loco::DepthwiseFilterShape &shape) +{ + ShapeDescription res; + + res._rank_known = true; + + // T/F Lite encodes a depthwise convolution filter as a [1, H, W, C*M] tensor + res._dims.resize(4); + res._dims.at(0) = 1; + res._dims.at(1) = shape.height().value(); + res._dims.at(2) = shape.width().value(); + res._dims.at(3) = shape.depth().value() * shape.multiplier().value(); + + return res; +} + +ShapeDescription to_shape_description(const loco::BiasShape &shape) +{ + ShapeDescription res; + + res._rank_known = true; + + res._dims.resize(1); + res._dims.at(0) = shape.length().value(); + + return res; +} + +ShapeDescription to_shape_description(const loco::MatrixShape &shape) +{ + ShapeDescription res; + + res._rank_known = true; + + res._dims.resize(2); + res._dims.at(0) = shape.height().value(); + res._dims.at(1) = shape.width().value(); + + return res; +} + +ShapeDescription to_shape_description(const loco::NodeShape &shape) +{ + switch (shape.domain()) + { + case loco::Domain::Tensor: + return to_shape_description(shape.as<loco::TensorShape>()); + case loco::Domain::Feature: + return to_shape_description(shape.as<loco::FeatureShape>()); + case loco::Domain::Filter: + return to_shape_description(shape.as<loco::FilterShape>()); + case loco::Domain::DepthwiseFilter: + return to_shape_description(shape.as<loco::DepthwiseFilterShape>()); + case loco::Domain::Bias: + return to_shape_description(shape.as<loco::BiasShape>()); + case loco::Domain::Matrix: + return to_shape_description(shape.as<loco::MatrixShape>()); + default: + break; + } + + INTERNAL_EXN_V("Unsupported loco domain", oops::to_uint32(shape.domain())); +} + +} // namespace exo diff --git a/compiler/exo/src/ExporterUtils.h b/compiler/exo/src/ExporterUtils.h new file mode 100644 index 000000000..e1f1f66a8 --- /dev/null +++ b/compiler/exo/src/ExporterUtils.h @@ -0,0 +1,57 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __EXPORTER_UTILS_H__ +#define __EXPORTER_UTILS_H__ + +#include "loco.h" + +#include "loco/IR/PermutingCodec.h" +#include "loco/IR/NodeShape.h" + +namespace exo +{ + +struct ShapeDescription +{ + std::vector<int32_t> _dims; + bool _rank_known; +}; + +ShapeDescription to_shape_description(const loco::TensorShape &shape); +ShapeDescription to_shape_description(const loco::FeatureShape &shape); +ShapeDescription to_shape_description(const loco::FilterShape &shape); +ShapeDescription to_shape_description(const loco::BiasShape &shape); +ShapeDescription to_shape_description(const loco::MatrixShape &shape); +ShapeDescription to_shape_description(const loco::NodeShape &shape); + +template <typename Permutation> inline bool isNHWC(Permutation *perm); + +template <> inline bool isNHWC(loco::Permutation<loco::Domain::Feature> *perm) +{ + return perm->axis(loco::FeatureAxis::Count) == 0 && perm->axis(loco::FeatureAxis::Height) == 1 && + perm->axis(loco::FeatureAxis::Width) == 2 && perm->axis(loco::FeatureAxis::Depth) == 3; +} + +template <> inline bool isNHWC(loco::Permutation<loco::Domain::Filter> *perm) +{ + return perm->axis(loco::FilterAxis::Count) == 0 && perm->axis(loco::FilterAxis::Height) == 1 && + perm->axis(loco::FilterAxis::Width) == 2 && perm->axis(loco::FilterAxis::Depth) == 3; +} + +} // namespace exo + +#endif // __EXPORTER_UTILS_H__ diff --git a/compiler/exo/src/GraphBlock.cpp b/compiler/exo/src/GraphBlock.cpp new file mode 100644 index 000000000..0a45ce8ad --- /dev/null +++ b/compiler/exo/src/GraphBlock.cpp @@ -0,0 +1,243 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "GraphBlock.h" + +#include "Check.h" + +#include <loco.h> +#include <stdex/Memory.h> + +namespace +{ + +template <exo::FeatureLayout T> loco::Permutation<loco::Domain::Feature> perm(); + +template <> loco::Permutation<loco::Domain::Feature> perm<exo::FeatureLayout::NHWC>() +{ + // Make NHWC permutation for encoder and decoder + loco::Permutation<loco::Domain::Feature> NHWC; + + NHWC.axis(loco::FeatureAxis::Count) = 0; + NHWC.axis(loco::FeatureAxis::Height) = 1; + NHWC.axis(loco::FeatureAxis::Width) = 2; + NHWC.axis(loco::FeatureAxis::Depth) = 3; + + return NHWC; +} + +template <exo::FilterLayout T> loco::Permutation<loco::Domain::Filter> perm(); + +template <> loco::Permutation<loco::Domain::Filter> perm<exo::FilterLayout::HWIO>() +{ + loco::Permutation<loco::Domain::Filter> HWIO; // a.k.a., HWCN + + HWIO.axis(loco::FilterAxis::Height) = 0; + HWIO.axis(loco::FilterAxis::Width) = 1; + HWIO.axis(loco::FilterAxis::Depth) = 2; + HWIO.axis(loco::FilterAxis::Count) = 3; + + return HWIO; +} + +template <> loco::Permutation<loco::Domain::Filter> perm<exo::FilterLayout::OHWI>() +{ + + // Make NHWC permutation for encoder and decoder + loco::Permutation<loco::Domain::Filter> OHWI; // a.k.a., NHWC + + OHWI.axis(loco::FilterAxis::Count) = 0; + OHWI.axis(loco::FilterAxis::Height) = 1; + OHWI.axis(loco::FilterAxis::Width) = 2; + OHWI.axis(loco::FilterAxis::Depth) = 3; + + return OHWI; +} + +template <exo::DepthwiseFilterLayout T> loco::Permutation<loco::Domain::DepthwiseFilter> perm(); + +template <> +loco::Permutation<loco::Domain::DepthwiseFilter> perm<exo::DepthwiseFilterLayout::HWCM>() +{ + loco::Permutation<loco::Domain::DepthwiseFilter> HWCM; + + HWCM.axis(loco::DepthwiseFilterAxis::Height) = 0; + HWCM.axis(loco::DepthwiseFilterAxis::Width) = 1; + HWCM.axis(loco::DepthwiseFilterAxis::Depth) = 2; + HWCM.axis(loco::DepthwiseFilterAxis::Multiplier) = 3; + + return HWCM; +} + +template <exo::MatrixLayout T> loco::Permutation<loco::Domain::Matrix> perm(); + +template <> loco::Permutation<loco::Domain::Matrix> perm<exo::MatrixLayout::HW>() +{ + loco::Permutation<loco::Domain::Matrix> HW; + + HW.axis(loco::MatrixAxis::Height) = 0; + HW.axis(loco::MatrixAxis::Width) = 1; + + return HW; +} + +template <> loco::Permutation<loco::Domain::Matrix> perm<exo::MatrixLayout::WH>() +{ + loco::Permutation<loco::Domain::Matrix> WH; + + WH.axis(loco::MatrixAxis::Height) = 1; + WH.axis(loco::MatrixAxis::Width) = 0; + + return WH; +} + +} // namespace + +namespace exo +{ + +template <FeatureLayout T> loco::FeatureEncode *make_feature_encode(loco::Node *input_for_encode) +{ + EXO_ASSERT(input_for_encode != nullptr, "input should not be nullptr"); + loco::Graph *g = input_for_encode->graph(); + + auto encoder = stdex::make_unique<loco::PermutingEncoder<loco::Domain::Feature>>(); + + encoder->perm(perm<T>()); + + auto enc = g->nodes()->create<loco::FeatureEncode>(); + enc->input(input_for_encode); + enc->encoder(std::move(encoder)); + + return enc; +} + +template <FeatureLayout T> loco::FeatureDecode *make_feature_decode(loco::Node *input_for_decode) +{ + EXO_ASSERT(input_for_decode != nullptr, "input should not be nullptr"); + loco::Graph *g = input_for_decode->graph(); + + auto decoder = stdex::make_unique<loco::PermutingDecoder<loco::Domain::Feature>>(); + + decoder->perm(perm<T>()); + + auto dec = g->nodes()->create<loco::FeatureDecode>(); + dec->input(input_for_decode); + dec->decoder(std::move(decoder)); + + return dec; +} + +template <FilterLayout T> loco::FilterEncode *make_filter_encode(loco::Node 
*input_for_encode) +{ + EXO_ASSERT(input_for_encode != nullptr, "filter should not be nullptr"); + loco::Graph *g = input_for_encode->graph(); + + auto encoder = stdex::make_unique<loco::PermutingEncoder<loco::Domain::Filter>>(); + + encoder->perm(perm<T>()); + + auto enc = g->nodes()->create<loco::FilterEncode>(); + enc->input(input_for_encode); + enc->encoder(std::move(encoder)); + + return enc; +} + +template <FilterLayout T> loco::FilterDecode *make_filter_decode(loco::Node *input_for_decode) +{ + EXO_ASSERT(input_for_decode != nullptr, "filter should not be nullptr"); + loco::Graph *g = input_for_decode->graph(); + + auto decoder = stdex::make_unique<loco::PermutingDecoder<loco::Domain::Filter>>(); + + decoder->perm(perm<T>()); + + auto dec = g->nodes()->create<loco::FilterDecode>(); + dec->input(input_for_decode); + dec->decoder(std::move(decoder)); + + return dec; +} + +template <DepthwiseFilterLayout T> +loco::DepthwiseFilterDecode *make_dw_filter_decode(loco::Node *input_for_decode) +{ + EXO_ASSERT(input_for_decode != nullptr, "filter should not be nullptr"); + loco::Graph *g = input_for_decode->graph(); + + auto decoder = stdex::make_unique<loco::PermutingDecoder<loco::Domain::DepthwiseFilter>>(); + + decoder->perm(perm<T>()); + + auto dec = g->nodes()->create<loco::DepthwiseFilterDecode>(); + dec->input(input_for_decode); + dec->decoder(std::move(decoder)); + + return dec; +} + +template <MatrixLayout T> loco::MatrixEncode *make_matrix_encode(loco::Node *input_for_encode) +{ + EXO_ASSERT(input_for_encode != nullptr, "input should not be nullptr"); + loco::Graph *g = input_for_encode->graph(); + + auto encoder = stdex::make_unique<loco::PermutingEncoder<loco::Domain::Matrix>>(); + + encoder->perm(perm<T>()); + + auto enc = g->nodes()->create<loco::MatrixEncode>(); + enc->input(input_for_encode); + enc->encoder(std::move(encoder)); + + return enc; +} + +template <MatrixLayout T> loco::MatrixDecode *make_matrix_decode(loco::Node *input_for_decode) +{ + EXO_ASSERT(input_for_decode != nullptr, "input should not be nullptr"); + loco::Graph *g = input_for_decode->graph(); + + auto decoder = stdex::make_unique<loco::PermutingDecoder<loco::Domain::Matrix>>(); + + decoder->perm(perm<T>()); + + auto dec = g->nodes()->create<loco::MatrixDecode>(); + dec->input(input_for_decode); + dec->decoder(std::move(decoder)); + + return dec; +} + +// template instantiation +template loco::FeatureEncode * +make_feature_encode<FeatureLayout::NHWC>(loco::Node *input_for_encode); + +template loco::FeatureDecode * +make_feature_decode<FeatureLayout::NHWC>(loco::Node *input_for_encode); + +template loco::FilterEncode *make_filter_encode<FilterLayout::HWIO>(loco::Node *input_for_encode); +template loco::FilterDecode *make_filter_decode<FilterLayout::OHWI>(loco::Node *input_for_decode); + +template loco::DepthwiseFilterDecode * +make_dw_filter_decode<DepthwiseFilterLayout::HWCM>(loco::Node *input_for_decode); + +template loco::MatrixEncode *make_matrix_encode<MatrixLayout::HW>(loco::Node *input_for_encode); +template loco::MatrixEncode *make_matrix_encode<MatrixLayout::WH>(loco::Node *input_for_encode); +template loco::MatrixDecode *make_matrix_decode<MatrixLayout::HW>(loco::Node *input_for_decode); +template loco::MatrixDecode *make_matrix_decode<MatrixLayout::WH>(loco::Node *input_for_decode); + +} // namespace exo diff --git a/compiler/exo/src/GraphBlock.h b/compiler/exo/src/GraphBlock.h new file mode 100644 index 000000000..b771c821b --- /dev/null +++ b/compiler/exo/src/GraphBlock.h @@ -0,0 +1,199 @@ +/* + 
* Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __GRAPH_BLOCK_H__
+#define __GRAPH_BLOCK_H__
+
+#include <loco.h>
+#include <loco/Service/ShapeInference.h>
+
+#include <oops/InternalExn.h>
+
+#include <functional>
+
+namespace exo
+{
+
+/// @brief Feature layout of TFLITE file
+enum class FeatureLayout
+{
+  NHWC,
+};
+
+/// @brief Creates a loco::FeatureEncode with T layout (NHWC for tflite) and adds it to the graph.
+template <FeatureLayout T> loco::FeatureEncode *make_feature_encode(loco::Node *input_for_encode);
+
+/// @brief Creates a loco::FeatureDecode with T layout (NHWC for tflite) and adds it to the graph.
+template <FeatureLayout T> loco::FeatureDecode *make_feature_decode(loco::Node *input_for_decode);
+
+enum class FilterLayout
+{
+  OHWI, // a.k.a., NHWC, TensorFlow Lite uses this layout for filter
+  HWIO, // a.k.a., HWCN, TensorFlow uses this layout for filter
+};
+
+/// @brief Create a loco::FilterEncode of given layout
+template <FilterLayout T> loco::FilterEncode *make_filter_encode(loco::Node *input_for_encode);
+
+/// @brief Create a loco::FilterDecode of given layout
+template <FilterLayout T> loco::FilterDecode *make_filter_decode(loco::Node *input_for_decode);
+
+enum class DepthwiseFilterLayout
+{
+  HWCM,
+};
+
+/// @brief Create a loco::DepthwiseFilterDecode of given layout
+template <DepthwiseFilterLayout T>
+loco::DepthwiseFilterDecode *make_dw_filter_decode(loco::Node *input_for_decode);
+
+enum class MatrixLayout
+{
+  HW,
+  WH
+};
+
+/// @brief Create a loco::MatrixEncode of given layout
+template <MatrixLayout T> loco::MatrixEncode *make_matrix_encode(loco::Node *input_for_encode);
+
+/// @brief Create a loco::MatrixDecode of given layout
+template <MatrixLayout T> loco::MatrixDecode *make_matrix_decode(loco::Node *input_for_decode);
+
+} // namespace exo
+
+//
+// DomainConverter
+//
+
+/**
+ * Some canonical nodes can have inputs of various loco::Domain, e.g., loco::Domain::Tensor,
+ * loco::Domain::Feature, etc. However, a TFL node accepts only loco::Domain::Tensor.
+ * So, when converting such a canonical node to a TFL node, if the input(s) of the canonical node
+ * are not loco::Domain::Tensor, additional nodes need to be inserted.
+ *
+ * The following two classes help with this insertion.
+ *
+ * For example, in case of loco::ReLU conversion,
+ *
+ * Before:
+ *
+ * A (output: feature) -- loco::ReLU --- B (input:feature)
+ *
+ * After:
+ *
+ * A -- loco::FeatureDecode -- locoex::TFLRelu -- loco::FeatureEncode --- B
+ *
+ * loco::ReLU (dead node)
+ */
+
+namespace exo
+{
+
+/**
+ * @brief Handles input(s) while converting a canonical node to TFL node(s).
+ * This class informs DomainConverter how to handle inputs of a specific canonical node.
+ */
+template <class CanonicalT, class TFLT> class InputHandler
+{
+public:
+  /**
+   * @brief Assign origin's inputs to replacer's inputs.
+   * (This is called when origin belongs in Tensor domain.)
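+   *
+   * A minimal sketch (hypothetical names, assuming a binary op pair such as
+   * loco::EltwiseAdd converted to locoex::TFLAdd):
+   *
+   *   void handover(loco::EltwiseAdd *origin, locoex::TFLAdd *replacer) override
+   *   {
+   *     replacer->x(origin->lhs());
+   *     replacer->y(origin->rhs());
+   *   }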
+   */
+  virtual void handover(CanonicalT *origin, TFLT *replacer) = 0;
+
+  /**
+   * @brief Returns the list of inputs that need to have FeatureDecode as their input.
+   * (This is called when origin belongs in Feature domain.)
+   */
+  virtual std::vector<loco::Node *> getInputsToConvert(CanonicalT *origin) = 0;
+
+  /// @brief Set the inputs of replacer to new_inputs
+  virtual void set(TFLT *replacer, std::vector<loco::Node *> &new_inputs) = 0;
+
+  /// @brief Set the inputs to nullptr
+  virtual void nullify(CanonicalT *origin) = 0;
+};
+
+/**
+ * @brief Class to handle domain conversion while converting a canonical node to TFL node(s)
+ */
+template <class CanonicalT, class TFLT> class DomainConverter
+{
+public:
+  template <FeatureLayout FeatureLayoutT>
+  TFLT *convert(CanonicalT *origin, InputHandler<CanonicalT, TFLT> &input_handler);
+};
+
+/**
+ * @brief Performs domain conversion
+ *
+ * 1. If origin belongs to loco::Domain::Tensor, replace origin with a TFL node.
+ * 2. If origin belongs to loco::Domain::Feature, insert loco::FeatureDecode for the input(s) and
+ *    insert loco::FeatureEncode for the output. Then replace origin with a TFL node.
+ *
+ * @return the new TFL node; nullptr if the shape of origin is not known
+ */
+template <class CanonicalT, class TFLT>
+template <FeatureLayout FeatureLayoutT>
+TFLT *DomainConverter<CanonicalT, TFLT>::convert(CanonicalT *origin,
+                                                 InputHandler<CanonicalT, TFLT> &input_handler)
+{
+  static_assert(FeatureLayoutT == FeatureLayout::NHWC, "Feature layout should be NHWC");
+
+  if (!loco::shape_known(origin))
+  {
+    return nullptr;
+  }
+
+  auto tfl_node = origin->graph()->nodes()->template create<TFLT>();
+
+  // When the input is Tensor, just replace the canonical node with the TFL node.
+  if (loco::shape_get(origin).domain() == loco::Domain::Tensor)
+  {
+    input_handler.handover(origin, tfl_node);
+
+    loco::replace(origin).with(tfl_node);
+    input_handler.nullify(origin);
+
+    return tfl_node;
+  }
+  else if (loco::shape_get(origin).domain() == loco::Domain::Feature)
+  {
+    std::vector<loco::Node *> feature_decodes;
+
+    for (auto input : input_handler.getInputsToConvert(origin))
+    {
+      auto dec = make_feature_decode<FeatureLayoutT>(input);
+      feature_decodes.emplace_back(dec);
+    }
+
+    input_handler.set(tfl_node, feature_decodes);
+
+    auto enc = make_feature_encode<FeatureLayoutT>(tfl_node);
+
+    loco::replace(origin).with(enc);
+    input_handler.nullify(origin);
+
+    return tfl_node;
+  }
+  else
+    INTERNAL_EXN_V("Unsupported loco::Domain", oops::to_uint32(loco::shape_get(origin).domain()));
+}
+
+} // namespace exo
+
+#endif //__GRAPH_BLOCK_H__
diff --git a/compiler/exo/src/Knob.cpp b/compiler/exo/src/Knob.cpp
new file mode 100644
index 000000000..50d78f4b7
--- /dev/null
+++ b/compiler/exo/src/Knob.cpp
@@ -0,0 +1,122 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Knob.h"
+
+#include <pepper/strcast.h>
+
+#include <iostream>
+#include <string>
+#include <map>
+
+// Basic Infrastructure to declare and access Knob values
+namespace
+{
+
+using KnobName = std::string;
+
+/**
+ * @brief Load configuration (from somewhere)
+ */
+struct KnobLoader
+{
+  virtual ~KnobLoader() = default;
+
+  virtual bool load(const KnobName &name, bool default_value) const = 0;
+};
+
+/**
+ * @brief Load configuration from environment variables
+ *
+ * Given a prefix P, EnvKnobLoader reads a configuration K from concat(P, K).
+ *
+ * For example, let us assume that P is "MY_" and K is "CONFIG".
+ *
+ * Then, EnvKnobLoader reads configuration CONFIG from environment variable MY_CONFIG.
+ */
+class EnvKnobLoader final : public KnobLoader
+{
+public:
+  EnvKnobLoader() = default;
+
+public:
+  bool load(const KnobName &knob_name, bool default_value) const override
+  {
+    auto envvar = _prefix + knob_name;
+    auto s = std::getenv(envvar.c_str());
+
+    return pepper::safe_strcast<int>(s, default_value ? 1 : 0) != 0;
+  }
+  void knob_set(const KnobName &knob_name, bool value) { _knob[knob_name] = value; }
+  void dialect_set(const exo::Dialect &dialect_name) { _prefix = _label[dialect_name]; }
+  bool knob_get(const KnobName &knob_name) { return load(knob_name, _knob[knob_name]); }
+
+private:
+  /// @brief Environment variable prefix
+  std::string _prefix;
+  std::map<KnobName, bool> _knob;
+  std::map<exo::Dialect, KnobName> _label = {{exo::Dialect::TFLITE, "TFL_"},
+                                             {exo::Dialect::CIRCLE, "CIRCLE_"}};
+};
+
+} // namespace
+
+namespace
+{
+
+EnvKnobLoader &knob_loader(void)
+{
+  // TODO separate "EXOTFLITE_" and "EXOCIRCLE_" when necessary
+  static EnvKnobLoader loader;
+  return loader;
+}
+
+} // namespace
+
+namespace exo
+{
+
+#define KNOB_BOOL(NAME, TFL_DEFAULT, CIRCLE_DEFAULT, DESC) \
+  template <> typename KnobTrait<Knob::NAME>::ValueType get<Knob::NAME>(void) \
+  { \
+    return ::knob_loader().knob_get(#NAME); \
+  }
+#include "Knob.lst"
+#undef KNOB_BOOL
+
+void set(Dialect d)
+{
+  ::knob_loader().dialect_set(d);
+  switch (d)
+  {
+    case Dialect::TFLITE:
+#define KNOB_BOOL(NAME, TFL_DEFAULT, CIRCLE_DEFAULT, DESC) \
+  ::knob_loader().knob_set(#NAME, TFL_DEFAULT);
+#include "Knob.lst"
+#undef KNOB_BOOL
+      break;
+    case Dialect::CIRCLE:
+#define KNOB_BOOL(NAME, TFL_DEFAULT, CIRCLE_DEFAULT, DESC) \
+  ::knob_loader().knob_set(#NAME, CIRCLE_DEFAULT);
+#include "Knob.lst"
+#undef KNOB_BOOL
+      break;
+    default:
+      throw std::runtime_error("Unknown dialect");
+  }
+}
+
+} // namespace exo
diff --git a/compiler/exo/src/Knob.h b/compiler/exo/src/Knob.h
new file mode 100644
index 000000000..98613120c
--- /dev/null
+++ b/compiler/exo/src/Knob.h
@@ -0,0 +1,51 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __KNOB_H__
+#define __KNOB_H__
+
+namespace exo
+{
+
+enum class Dialect
+{
+  TFLITE,
+  CIRCLE
+};
+
+enum class Knob
+{
+#define KNOB_BOOL(NAME, TFL_DEFAULT, CIRCLE_DEFAULT, DESC) NAME,
+#include "Knob.lst"
+#undef KNOB_BOOL
+};
+
+template <Knob K> struct KnobTrait;
+
+#define KNOB_BOOL(NAME, TFL_DEFAULT, CIRCLE_DEFAULT, DESC) \
+  template <> struct KnobTrait<Knob::NAME> \
+  { \
+    using ValueType = bool; \
+  };
+#include "Knob.lst"
+#undef KNOB_BOOL
+
+template <Knob K> typename KnobTrait<K>::ValueType get(void);
+void set(Dialect);
+
+} // namespace exo
+
+#endif // __KNOB_H__
diff --git a/compiler/exo/src/Knob.lst b/compiler/exo/src/Knob.lst
new file mode 100644
index 000000000..7f59c93f3
--- /dev/null
+++ b/compiler/exo/src/Knob.lst
@@ -0,0 +1,11 @@
+#ifndef KNOB_BOOL
+#error "KNOB_BOOL is not defined"
+#endif // KNOB_BOOL
+
+// KNOB_BOOL(KNOB_NAME, TFL_DEFAULT, CIRCLE_DEFAULT, DESCRIPTION)
+
+// Optimization pass
+KNOB_BOOL(UseFuseBiasAddPass, true, true, Fuse TFLAdd or TFLSub into TFLConv2D)
+KNOB_BOOL(UseFuseInstanceNormPass, false, true, Fuse InstanceNorm pattern)
+KNOB_BOOL(UseFuseReluPass, true, true, Fuse TFLRelu or TFLRelu6 into the preceding TFL node as fused activation)
+KNOB_BOOL(UseFuseSquaredDifferencePass, false, true, Fuse SquaredDifference pattern)
diff --git a/compiler/exo/src/Log.cpp b/compiler/exo/src/Log.cpp
new file mode 100644
index 000000000..aa762968b
--- /dev/null
+++ b/compiler/exo/src/Log.cpp
@@ -0,0 +1,84 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Log.h"
+
+#include <hermes/ConsoleReporter.h>
+#include <stdex/Memory.h>
+
+#include <cstdlib>
+#include <iostream>
+
+// TODO Extract these lexical conversion routines as a library
+namespace
+{
+
+/**
+ * @brief Convert a C-string into a value of type T
+ *
+ * safecast(s, v) returns v if s is nullptr.
+ */
+template <typename T> T safecast(const char *, const T &);
+
+template <> bool safecast<bool>(const char *s, const bool &value)
+{
+  return (s == nullptr) ?
value : (std::stoi(s) != 0);
+}
+
+} // namespace
+
+namespace exo
+{
+
+//
+// Logger
+//
+Logger::Logger(hermes::Context *ctx) { activate(ctx->sources(), ctx->bus()); }
+Logger::~Logger() { deactivate(); }
+
+//
+// LoggerConfig
+//
+LoggerConfig::LoggerConfig()
+{
+  // Turn on logging if EXO_LOG is set as non-zero value
+  _enabled = safecast<bool>(std::getenv("EXO_LOG"), false);
+}
+
+void LoggerConfig::configure(const hermes::Source *source, hermes::Source::Setting &setting) const
+{
+  // Let's ignore hermes::Sources if it is not an exo logger
+  if (auto logger = dynamic_cast<const Logger *>(source))
+  {
+    configure(logger, setting);
+  }
+}
+
+void LoggerConfig::configure(const Logger *, hermes::Source::Setting &setting) const
+{
+  if (_enabled)
+  {
+    // Enable all categories
+    setting.accept_all();
+  }
+  else
+  {
+    // Disable all categories
+    setting.reject_all();
+  }
+}
+
+} // namespace exo
diff --git a/compiler/exo/src/Log.h b/compiler/exo/src/Log.h
new file mode 100644
index 000000000..8ca38c3ec
--- /dev/null
+++ b/compiler/exo/src/Log.h
@@ -0,0 +1,75 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LOG_H__
+#define __LOG_H__
+
+#include "exo/LoggingContext.h"
+
+#include <hermes.h>
+
+namespace exo
+{
+
+/**
+ * @brief Logger Implementation
+ */
+class Logger final : public hermes::Source
+{
+public:
+  Logger(hermes::Context *ctx);
+  ~Logger();
+};
+
+/**
+ * @brief Logger Configuration
+ *
+ * Users are able to turn logging on/off via EXO_LOG environment variable.
+ */
+class LoggerConfig final : public hermes::Config
+{
+public:
+  LoggerConfig();
+
+public:
+  void configure(const hermes::Source *, hermes::Source::Setting &) const final;
+  void configure(const Logger *, hermes::Source::Setting &) const;
+
+private:
+  bool _enabled;
+};
+
+} // namespace exo
+
+/**
+ * HOW TO USE:
+ *
+ *   LOGGER(l);
+ *
+ *   INFO(l) << "Hello, World" << std::endl;
+ *
+ */
+#define LOGGER(name) ::exo::Logger name{::exo::LoggingContext::get()};
+
+// TODO Support FATAL, ERROR, WARN, and VERBOSE
+#define INFO(name) HERMES_INFO(name)
+
+// WARNING!
+//
+// THE CURRENT IMPLEMENTATION IS NOT THREAD SAFE.
+//
+
+#endif // __LOG_H__
diff --git a/compiler/exo/src/LogHelper.cpp b/compiler/exo/src/LogHelper.cpp
new file mode 100644
index 000000000..7520b7ec8
--- /dev/null
+++ b/compiler/exo/src/LogHelper.cpp
@@ -0,0 +1,79 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "LogHelper.h" + +namespace loco +{ + +std::ostream &operator<<(std::ostream &os, const loco::FeatureShape &feature_shape) +{ + os << "[" << feature_shape.count().value() << "," << feature_shape.height().value() << "," + << feature_shape.width().value() << "," << feature_shape.depth().value() << "]"; + return os; +} + +std::ostream &operator<<(std::ostream &os, const loco::FilterShape &filter_shape) +{ + os << "[" << filter_shape.height().value() << "," << filter_shape.width().value() << "," + << filter_shape.depth().value() << "," << filter_shape.count().value() << "]"; + return os; +} + +std::ostream &operator<<(std::ostream &os, const loco::TensorShape &tensor_shape) +{ + os << "["; + for (uint32_t r = 0; r < tensor_shape.rank(); ++r) + { + if (r) + os << ","; + os << tensor_shape.dim(r).value(); + } + os << "]"; + return os; +} + +std::ostream &operator<<(std::ostream &os, const loco::Padding2D &pad) +{ + os << "[TLBR " << pad.top() << "," << pad.left() << "," << pad.bottom() << "," << pad.right() + << "]"; + + return os; +} + +} // namespace loco + +std::ostream &operator<<(std::ostream &os, const std::vector<int64_t> &vi64) +{ + for (auto vi : vi64) + { + os << vi << " "; + } + return os; +} + +#include "ExoFormattedGraph.h" + +namespace exo +{ + +FormattedGraph fmt(loco::Graph *g) +{ + auto node_summary_builder = stdex::make_unique<NodeSummaryBuilderFactory>(); + return std::move(locop::fmt<locop::LinearV1>(g).with(std::move(node_summary_builder))); +} + +} // namespace exo diff --git a/compiler/exo/src/LogHelper.h b/compiler/exo/src/LogHelper.h new file mode 100644 index 000000000..69d81af9e --- /dev/null +++ b/compiler/exo/src/LogHelper.h @@ -0,0 +1,70 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#ifndef __LOG_HELPER_H__
+#define __LOG_HELPER_H__
+
+#include <locop/FormattedGraph.h>
+
+#include <loco/IR/FeatureShape.h>
+#include <loco/IR/FilterShape.h>
+#include <loco/IR/TensorShape.h>
+
+#include <sstream>
+#include <vector>
+
+namespace loco
+{
+
+/**
+ * @brief dump FeatureShape values to stream
+ */
+std::ostream &operator<<(std::ostream &os, const loco::FeatureShape &feature_shape);
+
+/**
+ * @brief dump FilterShape values to stream
+ */
+std::ostream &operator<<(std::ostream &os, const loco::FilterShape &filter_shape);
+
+/**
+ * @brief dump TensorShape values to stream
+ */
+std::ostream &operator<<(std::ostream &os, const loco::TensorShape &tensor_shape);
+
+/**
+ * @brief dump Padding2D values to stream
+ */
+std::ostream &operator<<(std::ostream &os, const loco::Padding2D &pad);
+
+} // namespace loco
+
+/**
+ * @brief dump std::vector<int64_t> values to stream
+ */
+std::ostream &operator<<(std::ostream &os, const std::vector<int64_t> &vi64);
+
+namespace exo
+{
+
+using FormattedGraph = locop::FormattedGraphImpl<locop::Formatter::LinearV1>;
+
+FormattedGraph fmt(loco::Graph *g);
+
+static inline FormattedGraph fmt(const std::unique_ptr<loco::Graph> &g) { return fmt(g.get()); }
+
+} // namespace exo
+
+#endif // __LOG_HELPER_H__
diff --git a/compiler/exo/src/LoggingContext.cpp b/compiler/exo/src/LoggingContext.cpp
new file mode 100644
index 000000000..1c14d97b9
--- /dev/null
+++ b/compiler/exo/src/LoggingContext.cpp
@@ -0,0 +1,40 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "exo/LoggingContext.h"
+#include "Log.h" // To use LoggerConfig
+
+#include <hermes/ConsoleReporter.h>
+#include <stdex/Memory.h>
+
+namespace exo
+{
+
+hermes::Context *LoggingContext::get(void)
+{
+  static hermes::Context *ctx = nullptr;
+
+  if (ctx == nullptr)
+  {
+    ctx = new hermes::Context;
+    ctx->sinks()->append(stdex::make_unique<hermes::ConsoleReporter>());
+    ctx->config(stdex::make_unique<LoggerConfig>());
+  }
+
+  return ctx;
+}
+
+} // namespace exo
diff --git a/compiler/exo/src/Pass/FoldReshapeOfConstPass.cpp b/compiler/exo/src/Pass/FoldReshapeOfConstPass.cpp
new file mode 100644
index 000000000..0fdcea939
--- /dev/null
+++ b/compiler/exo/src/Pass/FoldReshapeOfConstPass.cpp
@@ -0,0 +1,116 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#include "FoldReshapeOfConstPass.h" + +#include "Check.h" + +#include "Dialect/IR/TFLNodes.h" +#include "Dialect/IR/TFLNodeVisitor.h" + +#include <loco/Service/ShapeInference.h> + +#include <oops/InternalExn.h> + +namespace +{ + +/** + * @brief Check if node is TFLReshape and its input is TFLConst + * @return Casted TFLReshape for foldable candidate, nullptr otherwise + */ +locoex::TFLReshape *as_candidate(loco::Node *node) +{ + auto reshape = dynamic_cast<locoex::TFLReshape *>(node); + if (not reshape) + return nullptr; + + // Only accept Constant input of Reshape + if (not dynamic_cast<locoex::TFLConst *>(reshape->tensor())) + return nullptr; + + return reshape; +} + +uint32_t volume(loco::Node *tensor_node) +{ + auto shape = loco::shape_get(tensor_node).as<loco::TensorShape>(); + + uint32_t vol = 1; + for (uint32_t axis = 0; axis < shape.rank(); ++axis) + vol *= shape.dim(axis).value(); + + return vol; +} + +void fold_reshape_of_const(locoex::TFLReshape *reshape) +{ + const loco::DataType FLOAT32 = loco::DataType::FLOAT32; + + auto const_orig = dynamic_cast<locoex::TFLConst *>(reshape->tensor()); + + // Exceptions + { + EXO_ASSERT(const_orig, "Only support for Reshape-Const pair"); + // TODO support other data types + if (const_orig->dtype() != FLOAT32) + INTERNAL_EXN_V("NYI for this data type", oops::to_uint32(const_orig->dtype())); + + if (volume(const_orig) != volume(reshape)) + INTERNAL_EXN("New shape of Reshape is not matched"); + } + + auto new_shape = loco::shape_get(reshape).as<loco::TensorShape>(); + + // TFLConst to replace + auto const_new = reshape->graph()->nodes()->create<locoex::TFLConst>(); + + const_new->dtype(FLOAT32); + const_new->rank(new_shape.rank()); + const_new->size<FLOAT32>(const_orig->size<FLOAT32>()); + for (uint32_t axis = 0; axis < new_shape.rank(); ++axis) + const_new->dim(axis) = new_shape.dim(axis); + + for (uint32_t i = 0; i < const_new->size<FLOAT32>(); ++i) + { + const_new->at<FLOAT32>(i) = const_orig->at<FLOAT32>(i); + } + + // replace + loco::replace(reshape).with(const_new); +} + +} // namespace + +namespace exo +{ + +bool FoldReshapeOfConstPass::run(loco::Graph *g) +{ + bool changed = false; + for (auto node : loco::active_nodes(loco::output_nodes(g))) + { + if (auto reshape = as_candidate(node)) + { + fold_reshape_of_const(reshape); + changed = true; + } + } + + return changed; +} + +} // namespace exo diff --git a/compiler/exo/src/Pass/FoldReshapeOfConstPass.h b/compiler/exo/src/Pass/FoldReshapeOfConstPass.h new file mode 100644 index 000000000..10f8004bf --- /dev/null +++ b/compiler/exo/src/Pass/FoldReshapeOfConstPass.h @@ -0,0 +1,46 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#ifndef __PASS_FOLD_RESHAPE_OF_CONST_PASS_H__
+#define __PASS_FOLD_RESHAPE_OF_CONST_PASS_H__
+
+#include <logo/Pass.h>
+
+namespace exo
+{
+
+/**
+ * @brief Class to fuse TFLReshape + TFLConst into one equivalent TFLConst
+ *
+ * <before>
+ * TFLConst --- TFLReshape --- Out
+ *
+ * <after>
+ * TFLConst --- TFLReshape ---
+ * TFLConst (new) ------------ Out
+ *
+ * TODO This pass is temporary. Deprecate this pass.
+ */
+struct FoldReshapeOfConstPass final : public logo::Pass
+{
+ const char *name(void) const final { return "exo::FoldReshapeOfConstPass"; }
+
+ bool run(loco::Graph *g) final;
+};
+
+} // namespace exo
+
+#endif // __PASS_FOLD_RESHAPE_OF_CONST_PASS_H__ diff --git a/compiler/exo/src/Pass/FoldTransposeOfConstPass.cpp b/compiler/exo/src/Pass/FoldTransposeOfConstPass.cpp new file mode 100644 index 000000000..005c42944 --- /dev/null +++ b/compiler/exo/src/Pass/FoldTransposeOfConstPass.cpp @@ -0,0 +1,154 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "FoldTransposeOfConstPass.h"
+
+#include "Check.h"
+
+#include "Dialect/IR/TFLNodes.h"
+#include "Dialect/IR/TFLNodeVisitor.h"
+
+// TODO remove dependency on angkor
+#include <nncc/core/ADT/tensor/IndexEnumerator.h>
+#include <nncc/core/ADT/tensor/LexicalLayout.h>
+
+#include <oops/InternalExn.h>
+
+namespace
+{
+
+/**
+ * @brief Check if node is TFLTranspose and its input is TFLConst
+ * @return Casted TFLTranspose for foldable candidate, nullptr otherwise
+ */
+locoex::TFLTranspose *as_candidate(loco::Node *node)
+{
+ auto transpose = dynamic_cast<locoex::TFLTranspose *>(node);
+ if (not transpose)
+ return nullptr;
+
+ // Only accept Constant input of Transpose
+ if (not dynamic_cast<locoex::TFLConst *>(transpose->a()))
+ return nullptr;
+
+ // Only accept Constant permutation of Transpose
+ if (not dynamic_cast<locoex::TFLConst *>(transpose->perm()))
+ return nullptr;
+
+ return transpose;
+}
+
+nncc::core::ADT::tensor::Shape angkor_shape(locoex::TFLConst *node)
+{
+ nncc::core::ADT::tensor::Shape ret;
+
+ ret.resize(node->rank());
+ for (uint32_t axis = 0; axis < node->rank(); ++axis)
+ {
+ ret.dim(axis) = node->dim(axis).value();
+ }
+
+ return ret;
+}
+
+void fold_transpose_of_const(locoex::TFLTranspose *transpose)
+{
+ const loco::DataType FLOAT32 = loco::DataType::FLOAT32;
+ const loco::DataType S32 = loco::DataType::S32;
+
+ auto const_orig = dynamic_cast<locoex::TFLConst *>(transpose->a());
+ auto perm = dynamic_cast<locoex::TFLConst *>(transpose->perm());
+
+ // Exceptions
+ {
+ EXO_ASSERT(const_orig, "Only support for Transpose-Const pair");
+ // TODO support other data types
+ if (const_orig->dtype() != FLOAT32)
+ INTERNAL_EXN_V("NYI for this data type", oops::to_uint32(const_orig->dtype()));
+
+ EXO_ASSERT(perm, "Only support for constant permutation for Transpose");
+ // TODO support other data types
+ if (perm->dtype() != S32)
+ INTERNAL_EXN_V("NYI for this data type", oops::to_uint32(perm->dtype()));
+
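+ // Sanity check: 'perm' must be a rank-1 tensor whose length equals the rank
+ // of the input constant; the 'okay' lambda just below enforces this.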
+ auto okay = [&]() {
+ if (perm->rank() != 1)
+ return false;
+ if (perm->dim(0).value() != const_orig->rank())
+ return false;
+ return true;
+ };
+ if (not okay())
+ INTERNAL_EXN("Input and permutation for Transpose are not congruent");
+ }
+
+ uint32_t rank = const_orig->rank();
+
+ // TFLConst to replace
+ auto const_new = transpose->graph()->nodes()->create<locoex::TFLConst>();
+
+ const_new->dtype(FLOAT32);
+ const_new->rank(rank);
+ const_new->size<FLOAT32>(const_orig->size<FLOAT32>());
+ for (uint32_t axis = 0; axis < rank; ++axis)
+ const_new->dim(axis) = const_orig->dim(perm->at<S32>(axis)).value();
+
+ // TODO remove dependency on angkor
+ auto shape_orig = angkor_shape(const_orig);
+ auto shape_new = angkor_shape(const_new);
+
+ nncc::core::ADT::tensor::LexicalLayout l;
+ nncc::core::ADT::tensor::IndexEnumerator e{shape_new};
+
+ for (; e.valid(); e.advance())
+ {
+ loco::TensorIndex index_new = e.current();
+ loco::TensorIndex index_orig;
+
+ // Set original index from matching new index
+ index_orig.resize(rank);
+ for (uint32_t axis = 0; axis < rank; ++axis)
+ index_orig.at(perm->at<S32>(axis)) = index_new.at(axis);
+
+ const_new->at<FLOAT32>(l.offset(shape_new, index_new)) =
+ const_orig->at<FLOAT32>(l.offset(shape_orig, index_orig));
+ }
+
+ // replace
+ loco::replace(transpose).with(const_new);
+}
+
+} // namespace
+
+namespace exo
+{
+
+bool FoldTransposeOfConstPass::run(loco::Graph *g)
+{
+ bool changed = false;
+ for (auto node : loco::active_nodes(loco::output_nodes(g)))
+ {
+ if (auto transpose = as_candidate(node))
+ {
+ fold_transpose_of_const(transpose);
+ changed = true;
+ }
+ }
+
+ return changed;
+}
+
+} // namespace exo diff --git a/compiler/exo/src/Pass/FoldTransposeOfConstPass.h b/compiler/exo/src/Pass/FoldTransposeOfConstPass.h new file mode 100644 index 000000000..26656a118 --- /dev/null +++ b/compiler/exo/src/Pass/FoldTransposeOfConstPass.h @@ -0,0 +1,46 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __PASS_FOLD_TRANSPOSE_OF_CONST_PASS_H__
+#define __PASS_FOLD_TRANSPOSE_OF_CONST_PASS_H__
+
+#include <logo/Pass.h>
+
+namespace exo
+{
+
+/**
+ * @brief Class to fuse TFLTranspose + TFLConst into one equivalent TFLConst
+ *
+ * <before>
+ * TFLConst --- TFLTranspose --- Out
+ *
+ * <after>
+ * TFLConst --- TFLTranspose ---
+ * TFLConst (new) -------------- Out
+ *
+ * TODO This pass is temporary. Deprecate this pass.
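+ *
+ * @note When both the input and the permutation of the TFLTranspose are
+ * TFLConst, the permuted values are computed at compile time and stored
+ * in a new TFLConst; the original nodes become dead and are expected to
+ * be removed by dead-node cleanup.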
+ */
+struct FoldTransposeOfConstPass final : public logo::Pass
+{
+ const char *name(void) const final { return "exo::FoldTransposeOfConstPass"; }
+
+ bool run(loco::Graph *g) final;
+};
+
+} // namespace exo
+
+#endif // __PASS_FOLD_TRANSPOSE_OF_CONST_PASS_H__ diff --git a/compiler/exo/src/Pass/FuseBiasAddPass.cpp b/compiler/exo/src/Pass/FuseBiasAddPass.cpp new file mode 100644 index 000000000..aab820995 --- /dev/null +++ b/compiler/exo/src/Pass/FuseBiasAddPass.cpp @@ -0,0 +1,362 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "FuseBiasAddPass.h"
+
+#include "Dialect/IR/TFLNodes.h"
+#include "Dialect/IR/TFLDialect.h"
+#include "Dialect/IR/TFLNodeVisitor.h"
+
+#include <loco/Service/TypeInference.h>
+#include <loco/Service/ShapeInference.h>
+
+#include <oops/InternalExn.h>
+
+#include <set>
+
+/*
+ Note: Terms for variables in this implementation are as follows:
+
+ ex) subgraph handled: TFLConv2D -------- TFLAdd
+ (or TFLDepthwiseConv2D) (or TFLSub)
+ | |
+ \|/ \|/
+ variable name : former latter
+ Type : FormerT LatterT
+ (shortened name from Mixin) (template type)
+*/
+namespace
+{
+
+using FormerT = locoex::TFLNodeMixin<locoex::TFLNodeTrait::Bias>;
+
+loco::Node *as_loco_node(FormerT *former)
+{
+ auto loco_node = dynamic_cast<loco::Node *>(former);
+ assert(loco_node != nullptr);
+
+ return loco_node;
+}
+
+locoex::TFLConst *get_const(loco::Node *x, loco::Node *y)
+{
+ if (auto const_node = dynamic_cast<locoex::TFLConst *>(x))
+ return const_node;
+ else if (auto const_node = dynamic_cast<locoex::TFLConst *>(y))
+ return const_node;
+
+ return nullptr;
+}
+
+FormerT *get_former(loco::Node *x, loco::Node *y)
+{
+ if (auto node = dynamic_cast<FormerT *>(x))
+ return node;
+ else if (auto node = dynamic_cast<FormerT *>(y))
+ return node;
+
+ return nullptr;
+}
+
+/// @brief Finds the input that is TFLConst and sets it to new_input
+void set_const_input(locoex::TFLNode *node, locoex::TFLConst *new_input)
+{
+ if (auto add = dynamic_cast<locoex::TFLAdd *>(node))
+ {
+ if (dynamic_cast<locoex::TFLConst *>(add->x()))
+ add->x(new_input);
+ else if (dynamic_cast<locoex::TFLConst *>(add->y()))
+ add->y(new_input);
+ else
+ assert(false and "One node should be TFLConst");
+
+ return;
+ }
+
+ if (auto sub = dynamic_cast<locoex::TFLSub *>(node))
+ {
+ if (dynamic_cast<locoex::TFLConst *>(sub->x()))
+ sub->x(new_input);
+ else if (dynamic_cast<locoex::TFLConst *>(sub->y()))
+ sub->y(new_input);
+ else
+ assert(false and "One node should be TFLConst");
+
+ return;
+ }
+
+ assert(false and "Param should be TFLAdd or TFLSub");
+}
+
+/**
+ * @brief Creates a TFLConst whose shape is [to] and values are all const_node->at(0),
+ * where const_node has only one element (a scalar or a tensor of shape [1])
+ */
+locoex::TFLConst *create_widened(locoex::TFLConst *const_node, uint32_t to)
+{
+ auto const_shape = loco::shape_get(const_node).as<loco::TensorShape>();
+
+ assert(const_shape.rank() == 0 or
(const_shape.rank() == 1 and const_shape.dim(0) == 1));
+
+ auto g = const_node->graph();
+
+ auto widened_const = g->nodes()->create<locoex::TFLConst>();
+ {
+ widened_const->dtype(loco::DataType::FLOAT32);
+ widened_const->rank(1);
+ widened_const->dim(0) = to;
+ widened_const->size<loco::DataType::FLOAT32>(to);
+ for (uint32_t x = 0; x < to; x++)
+ widened_const->at<loco::DataType::FLOAT32>(x) = const_node->at<loco::DataType::FLOAT32>(0);
+ }
+ return widened_const;
+}
+
+template <typename TFLType> float calc(float, float);
+
+template <> float calc<locoex::TFLAdd>(float x, float y) { return x + y; }
+template <> float calc<locoex::TFLSub>(float x, float y) { return x - y; }
+
+template <class LatterT> class Fuser
+{
+public:
+ Fuser(LatterT *latter)
+ {
+ static_assert(std::is_same<LatterT, locoex::TFLAdd>::value ||
+ std::is_same<LatterT, locoex::TFLSub>::value,
+ "wrong template type");
+
+ _latter = latter;
+ _graph = _latter->graph();
+ _const_node = get_const(_latter->x(), _latter->y());
+ _former = get_former(_latter->x(), _latter->y());
+
+ assert(_const_node && _former);
+ }
+
+ void fuse(void);
+
+private:
+ loco::Graph *_graph;
+ LatterT *_latter;
+ locoex::TFLConst *_const_node;
+ FormerT *_former;
+
+ locoex::TFLConst *create_fused_bias_const();
+};
+
+// instantiation
+template class Fuser<locoex::TFLAdd>;
+template class Fuser<locoex::TFLSub>;
+
+template <class LatterT> locoex::TFLConst *Fuser<LatterT>::create_fused_bias_const()
+{
+ // we have to create a new bias const by adding/subtracting bias and const node (of TFLAdd or
+ // TFLSub)
+ auto bias = dynamic_cast<locoex::TFLConst *>(_former->bias());
+ assert(bias->dtype() == loco::DataType::FLOAT32 &&
+ _const_node->dtype() == loco::DataType::FLOAT32);
+
+ assert(bias->rank() == 1 && _const_node->rank() == 1);
+ assert(bias->dim(0) == _const_node->dim(0));
+
+ // build a new bias const
+ auto new_bias = _graph->nodes()->create<locoex::TFLConst>();
+ {
+ new_bias->dtype(loco::DataType::FLOAT32);
+
+ new_bias->rank(1);
+ new_bias->dim(0) = bias->dim(0);
+
+ new_bias->size<loco::DataType::FLOAT32>(bias->dim(0).value());
+
+ for (uint32_t x = 0; x < bias->dim(0).value(); x++)
+ new_bias->at<loco::DataType::FLOAT32>(x) = calc<LatterT>(
+ bias->at<loco::DataType::FLOAT32>(x), _const_node->at<loco::DataType::FLOAT32>(x));
+ }
+
+ return new_bias;
+}
+
+// FuseBiasAddPass works when former->fusedActivationFunction() == NONE
+bool check_act_func(FormerT *former)
+{
+ using FusedActFuncMixin = locoex::TFLNodeMixin<locoex::TFLNodeTrait::FusedActFunc>;
+
+ if (auto node = dynamic_cast<FusedActFuncMixin *>(former))
+ return node->fusedActivationFunction() == locoex::FusedActFunc::NONE;
+ else
+ return true;
+}
+
+template <class LatterT> void set_act_func(FormerT *former, LatterT *latter)
+{
+ using FusedActFuncMixin = locoex::TFLNodeMixin<locoex::TFLNodeTrait::FusedActFunc>;
+
+ if (auto node = dynamic_cast<FusedActFuncMixin *>(former))
+ node->fusedActivationFunction(latter->fusedActivationFunction());
+}
+
+// instantiation
+template void set_act_func(FormerT *, locoex::TFLAdd *);
+template void set_act_func(FormerT *, locoex::TFLSub *);
+
+/**
+ * @brief Fuse TFLAdd or TFLSub (latter) into TFLConv2D or TFLDepthwiseConv2D (former).
+ * All conditions should be checked before calling this.
+ *
+ * @note TFLAdd can have a fused activation function (let's call this FAF for simplicity).
+ * + * Conv2D's FAF | TFLAdd's FAF => FAF after fusing TFLAdd into TFLConv2D + * ----------------|--------------- -------------------------------------- + * NONE | NONE, RELU or RELU6 => TFLAdd's FAF + * other than NONE | anything => cannot be fused + */ +template <class LatterT> void Fuser<LatterT>::fuse(void) +{ + // check fused activation function + { + assert(check_act_func(_former)); + + set_act_func<LatterT>(_former, _latter); + } + + auto new_bias = create_fused_bias_const(); + + // replace node with new_bias + // note that loco::replace() is not used because bias could be input of other op just in case + _former->bias(new_bias); + + // remove TFLAdd or TFLSub node + loco::replace(_latter).with(as_loco_node(_former)); + _latter->x(nullptr); + _latter->y(nullptr); +} + +struct Collector final : public locoex::TFLNodeMutableVisitor<void> +{ + template <class LatterT> + void setCandidate(FormerT *former, LatterT *latter, locoex::TFLConst *const_node) + { + static_assert(std::is_same<LatterT, locoex::TFLAdd>::value || + std::is_same<LatterT, locoex::TFLSub>::value, + "wrong template type"); + + if (!check_act_func(former)) + return; + + auto depth = + loco::shape_get(as_loco_node(former)).template as<loco::TensorShape>().dim(3).value(); + auto const_shape = loco::shape_get(const_node).template as<loco::TensorShape>(); + + if (const_shape.rank() == 1 and const_shape.dim(0) == depth) + { + candidates.insert(latter); + } + // when Const has only one value, create a new const with shape [depth] + else if (const_shape.rank() == 0 or (const_shape.rank() == 1 and const_shape.dim(0) == 1)) + { + if (!(loco::dtype_get(as_loco_node(former)) == loco::DataType::FLOAT32)) + INTERNAL_EXN_V("Unsupported data type", + oops::to_uint32(loco::dtype_get(as_loco_node(former)))); + if (!(const_node->dtype() == loco::DataType::FLOAT32)) + INTERNAL_EXN_V("Unsupported data type", oops::to_uint32(const_node->dtype())); + + auto new_bias_node = create_widened(const_node, depth); + + // Replacing TFLConst input of TFLAdd or TFLSub. 
+ // Note that calling loco::replace(const_node).with(new_bias_node) could be dangerous + // because const_node could be the input of many nodes + set_const_input(latter, new_bias_node); + + candidates.insert(latter); + } + } + + void visit(locoex::TFLAdd *latter) final + { + auto former = get_former(latter->x(), latter->y()); + auto const_node = get_const(latter->x(), latter->y()); + + if (former && const_node) + setCandidate<locoex::TFLAdd>(former, latter, const_node); + } + + void visit(locoex::TFLSub *latter) final + { + // TFLSub, of which x() = TFLConv2D or TFLDepthwiseConv2D, y() = TFLConst, is fusing target + auto former = dynamic_cast<FormerT *>(latter->x()); + auto const_node = dynamic_cast<locoex::TFLConst *>(latter->y()); + + if (former && const_node) + setCandidate<locoex::TFLSub>(former, latter, const_node); + } + + void visit(locoex::TFLNode *) final { return; } + + std::set<locoex::TFLNode *> candidates; +}; + +struct Performer final : public locoex::TFLNodeMutableVisitor<void> +{ + void visit(locoex::TFLAdd *latter) final + { + assert(get_former(latter->x(), latter->y())); + + Fuser<locoex::TFLAdd> fuser(latter); + fuser.fuse(); + } + + void visit(locoex::TFLSub *latter) final + { + assert(get_former(latter->x(), latter->y())); + + Fuser<locoex::TFLSub> fuser(latter); + fuser.fuse(); + } + + void visit(locoex::TFLNode *) final { assert(false && "should not be called"); } +}; + +} // namespace + +namespace exo +{ + +bool FuseBiasAddPass::run(loco::Graph *g) +{ + Collector collector; + + for (auto node : loco::active_nodes(loco::output_nodes(g))) + { + if (node->dialect() == locoex::TFLDialect::get()) + { + auto tfl_node = dynamic_cast<locoex::TFLNode *>(node); + tfl_node->accept(&collector); + } + } + + Performer performer; + + for (auto node : collector.candidates) + { + node->accept(&performer); + } + + return collector.candidates.size() > 0; +} + +} // namespace exo diff --git a/compiler/exo/src/Pass/FuseBiasAddPass.h b/compiler/exo/src/Pass/FuseBiasAddPass.h new file mode 100644 index 000000000..68e624c6b --- /dev/null +++ b/compiler/exo/src/Pass/FuseBiasAddPass.h @@ -0,0 +1,61 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __PASS_FUSE_BIASADD_PASS_H__ +#define __PASS_FUSE_BIASADD_PASS_H__ + +#include <logo/Pass.h> + +namespace exo +{ + +/** + * @brief Class to fuse TFLAdd or TFLSub into Bias input of the following ops: + * - TFLConv2D, TFLDepthwiseConv2D + * - TODO Consider to add FullyConnected, etc. + * + * Case 1. Conv2D and TFLAdd + * + * BEFORE: + * + * TFLConst A (a scalar or a tensor of shape [1] or [depth of TFLConv2D]) + * | + * Foo -- TFLConv2D -- TFLAdd (or TFLSub) -- Bar + * | + * TFLConst B --+ (bias) + * + * AFTER: + * Foo ----- TFLConv2D ----- Bar + * | + * TFLConst A' --+ (bias) + * + * TFLConst B (dead node) + * + * TFLAdd (or TFLSub) (dead node) + * + * @note TFLSub, of which x() == TFLConv2D and y() == TFLConst, will be fused. 
+ * If x() == TFLConst and y() == TFLConv2D, it won't be fused.
+ */
+struct FuseBiasAddPass final : public logo::Pass
+{
+ const char *name(void) const final { return "exo::FuseBiasAddPass"; }
+
+ bool run(loco::Graph *g) final;
+};
+
+} // namespace exo
+
+#endif // __PASS_FUSE_BIASADD_PASS_H__ diff --git a/compiler/exo/src/Pass/FuseBiasAddPass.test.cpp b/compiler/exo/src/Pass/FuseBiasAddPass.test.cpp new file mode 100644 index 000000000..6ba728de0 --- /dev/null +++ b/compiler/exo/src/Pass/FuseBiasAddPass.test.cpp @@ -0,0 +1,361 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "FuseBiasAddPass.h"
+
+#include "Dialect/IR/TFLNodes.h"
+#include "TestGraph.h"
+#include "TestHelper.h"
+
+#include <loco.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+void init(loco::Pull *pull)
+{
+ pull->dtype(loco::DataType::FLOAT32);
+ pull->shape({2, 3, 3, 2});
+}
+
+/// @brief Initializes TFLConv2D and related filter and bias
+void init(locoex::TFLConv2D *conv2d, locoex::TFLConst *filter, locoex::TFLConst *bias)
+{
+ // set conv2d
+ {
+ conv2d->fusedActivationFunction(locoex::FusedActFunc::NONE);
+ conv2d->padding(locoex::Padding::VALID);
+ }
+
+ // set filter
+ {
+ filter->dtype(loco::DataType::FLOAT32);
+ filter->shape({2, 3, 3, 2});
+ filter->size<loco::DataType::FLOAT32>(2 * 3 * 3 * 2);
+
+ for (uint32_t x = 0; x < 2 * 3 * 3 * 2; x++)
+ filter->at<loco::DataType::FLOAT32>(x) = 0.0;
+ }
+
+ // set bias
+ {
+ bias->dtype(loco::DataType::FLOAT32);
+ bias->shape({2});
+ bias->size<loco::DataType::FLOAT32>(2);
+
+ for (uint32_t x = 0; x < 2; x++)
+ bias->at<loco::DataType::FLOAT32>(x) = 0.0;
+ }
+}
+
+template <class T> void init(T *node, locoex::FusedActFunc f)
+{
+ static_assert(std::is_same<T, locoex::TFLAdd>::value || std::is_same<T, locoex::TFLSub>::value,
+ "wrong template type");
+
+ node->fusedActivationFunction(f);
+}
+
+/// @brief Initializes one param of TFLAdd or TFLSub
+void init(locoex::TFLConst *addsub_param)
+{
+ // set addsub_param : y() value of TFLAdd or TFLSub
+ addsub_param->dtype(loco::DataType::FLOAT32);
+ addsub_param->shape({2});
+ addsub_param->size<loco::DataType::FLOAT32>(2);
+
+ for (uint32_t x = 0; x < 2; x++)
+ addsub_param->at<loco::DataType::FLOAT32>(x) = (x + 1) * 1.5; // 1.5, 3
+}
+
+} // namespace
+
+// A case when
+// - TFLConv2D has bias (0, 0)
+// - TFLAdd, of which x() or y() == TFLConv2D
+// - Another param of TFLAdd is TFLConst, (1.5, 3)
+//
+// After fusion, bias should be (1.5, 3)
+TEST(FuseBiasAddPassTest, Conv2D_Add_01_basic)
+{
+ exo::test::TestGraph g;
+ auto filter = g.append<locoex::TFLConst>();
+ auto bias = g.append<locoex::TFLConst>();
+ auto conv2d = g.append<locoex::TFLConv2D>(g.pull, filter, bias);
+
+ auto add_y = g.append<locoex::TFLConst>();
+ auto add = g.append<locoex::TFLAdd>(conv2d, add_y);
+
+ g.complete(add);
+
+ init(g.pull);
+ init(conv2d, filter, bias);
+ init(add, locoex::FusedActFunc::NONE);
+ init(add_y);
+
+ // let's run fusion
+ {
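+ // TypeShapeReadyPhase is a test helper that runs type and shape inference
+ // before the pass, since FuseBiasAddPass queries loco::shape_get() and
+ // loco::dtype_get() on the graph.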
+ exo::test::TypeShapeReadyPhase test_phase;
+
+ test_phase.add_pass<exo::FuseBiasAddPass>();
+ test_phase.run(g.graph());
+ }
+
+ auto a_conv2d = exo::test::find_first_node_bytype<locoex::TFLConv2D>(g.graph());
+ ASSERT_TRUE(a_conv2d != nullptr);
+
+ auto a_bias = dynamic_cast<locoex::TFLConst *>(a_conv2d->bias());
+ ASSERT_TRUE(a_bias != nullptr);
+
+ ASSERT_TRUE(a_bias->dim(0) == 2);
+ ASSERT_FLOAT_EQ(a_bias->at<loco::DataType::FLOAT32>(0),
+ bias->at<loco::DataType::FLOAT32>(0) + add_y->at<loco::DataType::FLOAT32>(0));
+ ASSERT_FLOAT_EQ(a_bias->at<loco::DataType::FLOAT32>(1),
+ bias->at<loco::DataType::FLOAT32>(1) + add_y->at<loco::DataType::FLOAT32>(1));
+}
+
+// A case when
+// - TFLConv2D has bias (0, 0)
+// - TFLAdd, of which x() or y() == TFLConv2D
+// - Another param of TFLAdd is TFLConst, (1.5) <-- scalar
+//
+// After fusion, bias should be (1.5, 1.5)
+TEST(FuseBiasAddPassTest, Conv2D_Add_02_TFLAdd_y_is_scalar)
+{
+ exo::test::TestGraph g;
+ auto filter = g.append<locoex::TFLConst>();
+ auto bias = g.append<locoex::TFLConst>();
+ auto conv2d = g.append<locoex::TFLConv2D>(g.pull, filter, bias);
+
+ auto add_y = g.append<locoex::TFLConst>();
+ auto add = g.append<locoex::TFLAdd>(conv2d, add_y);
+
+ g.complete(add);
+
+ init(g.pull);
+ init(conv2d, filter, bias); // channel of conv2d is 2
+
+ {
+ // Size of this TFLConst is 1.
+ // Note that this should be widened later to the shape of [channel of Conv2D], which is [2]
+ add_y->dtype(loco::DataType::FLOAT32);
+ add_y->shape({1});
+ add_y->size<loco::DataType::FLOAT32>(1);
+ add_y->at<loco::DataType::FLOAT32>(0) = 1.5;
+ }
+ init(add, locoex::FusedActFunc::NONE);
+
+ // let's run fusion
+ {
+ exo::test::TypeShapeReadyPhase test_phase;
+
+ test_phase.add_pass<exo::FuseBiasAddPass>();
+ test_phase.run(g.graph());
+ }
+
+ auto a_conv2d = exo::test::find_first_node_bytype<locoex::TFLConv2D>(g.graph());
+ ASSERT_TRUE(a_conv2d != nullptr);
+
+ auto a_bias = dynamic_cast<locoex::TFLConst *>(a_conv2d->bias());
+ ASSERT_TRUE(a_bias != nullptr);
+
+ ASSERT_TRUE(a_bias->dim(0) == 2);
+ ASSERT_FLOAT_EQ(a_bias->at<loco::DataType::FLOAT32>(0),
+ bias->at<loco::DataType::FLOAT32>(0) + 1.5);
+ ASSERT_FLOAT_EQ(a_bias->at<loco::DataType::FLOAT32>(1),
+ bias->at<loco::DataType::FLOAT32>(1) + 1.5);
+}
+
+// A case when
+// - TFLConv2D has bias (0, 0)
+// - TFLSub.x() == TFLConv2D
+// - TFLSub.y() == TFLConst, (1.5, 3)
+//
+// After fusion, bias should be (-1.5, -3)
+TEST(FuseBiasAddPassTest, Conv2D_Sub_01_basic)
+{
+ exo::test::TestGraph g;
+ auto filter = g.append<locoex::TFLConst>();
+ auto bias = g.append<locoex::TFLConst>();
+ auto conv2d = g.append<locoex::TFLConv2D>(g.pull, filter, bias);
+
+ auto sub_y = g.append<locoex::TFLConst>();
+ auto sub = g.append<locoex::TFLSub>(conv2d, sub_y);
+
+ g.complete(sub);
+
+ init(g.pull);
+ init(conv2d, filter, bias);
+ init(sub, locoex::FusedActFunc::NONE);
+ init(sub_y);
+
+ // let's run fusion
+ {
+ exo::test::TypeShapeReadyPhase test_phase;
+
+ test_phase.add_pass<exo::FuseBiasAddPass>();
+ test_phase.run(g.graph());
+ }
+
+ auto a_conv2d = exo::test::find_first_node_bytype<locoex::TFLConv2D>(g.graph());
+ ASSERT_TRUE(a_conv2d != nullptr);
+
+ auto a_bias = dynamic_cast<locoex::TFLConst *>(a_conv2d->bias());
+ ASSERT_TRUE(a_bias != nullptr);
+
+ ASSERT_TRUE(a_bias->dim(0) == 2);
+ ASSERT_FLOAT_EQ(a_bias->at<loco::DataType::FLOAT32>(0),
+ bias->at<loco::DataType::FLOAT32>(0) - sub_y->at<loco::DataType::FLOAT32>(0));
+ ASSERT_FLOAT_EQ(a_bias->at<loco::DataType::FLOAT32>(1),
bias->at<loco::DataType::FLOAT32>(1) - sub_y->at<loco::DataType::FLOAT32>(1)); +} + +// A case when TFLConv2D is input of TFLSub but fusion cannot be performed. +// - TFLSub.x() == TFLConst +// - TFLSub.y() == TFLConv2D +// +// Here, TFLSub cannot be fused into TFLConst. To be fused, TFLSub.x() should be TFLConv2D and +// TFLSub.y() should be TFLConst. So fusion will NOT happen. +TEST(FuseBiasAddPassTest, Conv2D_Sub_02_fusing_will_not_performed) +{ + exo::test::TestGraph g; + auto filter = g.append<locoex::TFLConst>(); + auto bias = g.append<locoex::TFLConst>(); + auto conv2d = g.append<locoex::TFLConv2D>(g.pull, filter, bias); + + auto sub_y = g.append<locoex::TFLConst>(); + auto sub = g.append<locoex::TFLSub>(sub_y, conv2d); // This WON'T be fused + + g.complete(sub); + + init(g.pull); + init(conv2d, filter, bias); + init(sub, locoex::FusedActFunc::NONE); + init(sub_y); + + // let's run fusion + { + exo::test::TypeShapeReadyPhase test_phase; + + test_phase.add_pass<exo::FuseBiasAddPass>(); + test_phase.run(g.graph()); + } + + auto a_conv2d = exo::test::find_first_node_bytype<locoex::TFLConv2D>(g.graph()); + ASSERT_TRUE(a_conv2d != nullptr); + + auto a_bias = dynamic_cast<locoex::TFLConst *>(a_conv2d->bias()); + ASSERT_TRUE(a_bias != nullptr); + + ASSERT_TRUE(a_bias->dim(0) == 2); + ASSERT_FLOAT_EQ(a_bias->at<loco::DataType::FLOAT32>(0), 0); + ASSERT_FLOAT_EQ(a_bias->at<loco::DataType::FLOAT32>(1), 0); + + auto a_sub = exo::test::find_first_node_bytype<locoex::TFLSub>(g.graph()); + ASSERT_TRUE(a_sub != nullptr); + ASSERT_TRUE(a_sub->y() == a_conv2d); // Checking 'not-fused' state +} + +// A case when +// - TFLConv2D has an activation function with Relu +// - TFLAdd, has no activation function +// +// No fusion should happen +TEST(FuseBiasAddPassTest, Regression_Conv2D_Add_fused_action_00) +{ + exo::test::TestGraph g; + auto filter = g.append<locoex::TFLConst>(); + auto bias = g.append<locoex::TFLConst>(); + auto conv2d = g.append<locoex::TFLConv2D>(g.pull, filter, bias); + + auto add_y = g.append<locoex::TFLConst>(); + auto add = g.append<locoex::TFLAdd>(conv2d, add_y); + + g.complete(add); + + init(g.pull); + init(conv2d, filter, bias); + init(add, locoex::FusedActFunc::NONE); + init(add_y); + + // Updating Fused Activation for this test + conv2d->fusedActivationFunction(locoex::FusedActFunc::RELU); + + // let's run fusion + { + exo::test::TypeShapeReadyPhase test_phase; + + test_phase.add_pass<exo::FuseBiasAddPass>(); + test_phase.run(g.graph()); + } + + auto a_conv2d = exo::test::find_first_node_bytype<locoex::TFLConv2D>(g.graph()); + ASSERT_TRUE(a_conv2d != nullptr); + ASSERT_TRUE(a_conv2d->fusedActivationFunction() == locoex::FusedActFunc::RELU); + + auto an_add = exo::test::find_first_node_bytype<locoex::TFLAdd>(g.graph()); + ASSERT_TRUE(an_add != nullptr); + ASSERT_TRUE(an_add->fusedActivationFunction() == locoex::FusedActFunc::NONE); + + ASSERT_TRUE(an_add->x() == a_conv2d or an_add->y() == a_conv2d); +} + +// A case when +// - TFLConv2D has NONE activation function +// - TFLAdd has Relu activation function +// +// TFLConv2D should have Relu activation function, TFLAdd is fused into bias input +TEST(FuseBiasAddPassTest, Regression_Conv2D_Add_fused_action_01) +{ + exo::test::TestGraph g; + auto filter = g.append<locoex::TFLConst>(); + auto bias = g.append<locoex::TFLConst>(); + auto conv2d = g.append<locoex::TFLConv2D>(g.pull, filter, bias); + + auto add_y = g.append<locoex::TFLConst>(); + auto add = g.append<locoex::TFLAdd>(conv2d, add_y); + + g.complete(add); + + init(g.pull); + 
init(conv2d, filter, bias);
+ init(add, locoex::FusedActFunc::RELU);
+ init(add_y);
+
+ // let's run fusion
+ {
+ exo::test::TypeShapeReadyPhase test_phase;
+
+ test_phase.add_pass<exo::FuseBiasAddPass>();
+ test_phase.run(g.graph());
+ }
+
+ auto a_conv2d = exo::test::find_first_node_bytype<locoex::TFLConv2D>(g.graph());
+ ASSERT_TRUE(a_conv2d != nullptr);
+
+ auto a_bias = dynamic_cast<locoex::TFLConst *>(a_conv2d->bias());
+ ASSERT_TRUE(a_bias != nullptr);
+
+ ASSERT_TRUE(a_bias->dim(0) == 2);
+ ASSERT_FLOAT_EQ(a_bias->at<loco::DataType::FLOAT32>(0),
+ bias->at<loco::DataType::FLOAT32>(0) + add_y->at<loco::DataType::FLOAT32>(0));
+ ASSERT_FLOAT_EQ(a_bias->at<loco::DataType::FLOAT32>(1),
+ bias->at<loco::DataType::FLOAT32>(1) + add_y->at<loco::DataType::FLOAT32>(1));
+
+ ASSERT_TRUE(a_conv2d->fusedActivationFunction() == locoex::FusedActFunc::RELU);
+} diff --git a/compiler/exo/src/Pass/FuseInstanceNormPass.cpp b/compiler/exo/src/Pass/FuseInstanceNormPass.cpp new file mode 100644 index 000000000..04d4a62cd --- /dev/null +++ b/compiler/exo/src/Pass/FuseInstanceNormPass.cpp @@ -0,0 +1,402 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "FuseInstanceNormPass.h"
+
+#include "Dialect/IR/TFLNodes.h"
+#include "Dialect/IR/CircleNodes.h"
+
+#include <loco/Service/ShapeInference.h>
+
+#include <cassert>
+#include <set>
+
+// Helper to find commutative node's arguments
+namespace
+{
+
+/**
+ * INTRODUCTION
+ * Binary operation f(x,y) is 'commutative' when
+ * f(x,y) == f(y,x) holds for all x, y.
+ * For example, ADD, MUL and SQUARED_DIFFERENCE are commutative.
+ * These helpers make it easy to find the commutative arguments of a commutative node.
+ *
+ * HOW TO USE
+ * COMM_NODE *node;
+ * ARG_TYPE_1 *arg1;
+ * ARG_TYPE_2 *arg2;
+ *
+ * bool ok = fill(&arg1, &arg2).with_commutative_args_of(node);
+ *
+ * Result
+ * If 'node's commutative argument types are actually {ARG_TYPE_1, ARG_TYPE_2}
+ * (as a set), 'arg1' and 'arg2' are set to 'node's actual arguments with matching
+ * types, and the return value 'ok' is true.
+ * Otherwise, 'arg1' and 'arg2' are not changed and 'ok' is false.
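+ *
+ * EXAMPLE (an illustrative sketch; 'mul' is assumed to be some locoex::TFLMul *
+ * at hand whose arguments are a TFLConst and a TFLRelu, in either order)
+ *
+ *   locoex::TFLConst *alpha = nullptr;
+ *   locoex::TFLRelu *relu = nullptr;
+ *
+ *   if (fill(&alpha, &relu).with_commutative_args_of(mul))
+ *   {
+ *     // 'alpha' and 'relu' now point to mul's arguments, whichever side each was on
+ *   }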
+ */
+
+template <class ARG_TYPE_1, class ARG_TYPE_2> class NodeFiller final
+{
+public:
+ NodeFiller(ARG_TYPE_1 **arg_1, ARG_TYPE_2 **arg_2) : _arg_1(arg_1), _arg_2(arg_2)
+ {
+ // DO NOTHING
+ }
+
+ /**
+ * @return true When 'node's argument types are 'ARG_TYPE_1' and 'ARG_TYPE_2'
+ * In that case, it assigns '_arg_1' and '_arg_2' to the actual arguments
+ *
+ * @return false When 'node's argument types are NOT matched with 'ARG_TYPE_*'
+ * In that case, it does not modify '_arg_1' and '_arg_2'
+ *
+ * @require COMM_NODE has members x() and y()
+ */
+ template <class COMM_NODE> bool with_commutative_args_of(const COMM_NODE *node);
+
+private:
+ ARG_TYPE_1 **_arg_1;
+ ARG_TYPE_2 **_arg_2;
+};
+
+template <class ARG_TYPE_1, class ARG_TYPE_2>
+inline NodeFiller<ARG_TYPE_1, ARG_TYPE_2> fill(ARG_TYPE_1 **arg_1, ARG_TYPE_2 **arg_2)
+{
+ return NodeFiller<ARG_TYPE_1, ARG_TYPE_2>{arg_1, arg_2};
+}
+
+template <class ARG_TYPE_1, class ARG_TYPE_2>
+template <class COMM_NODE>
+bool NodeFiller<ARG_TYPE_1, ARG_TYPE_2>::with_commutative_args_of(const COMM_NODE *node)
+{
+ // Case 1) X == ARG_TYPE_1 / Y == ARG_TYPE_2
+ {
+ auto x = dynamic_cast<ARG_TYPE_1 *>(node->x());
+ auto y = dynamic_cast<ARG_TYPE_2 *>(node->y());
+
+ if (x && y)
+ {
+ *_arg_1 = x;
+ *_arg_2 = y;
+ return true;
+ }
+ }
+
+ // Case 2) X == ARG_TYPE_2 / Y == ARG_TYPE_1
+ {
+ auto x = dynamic_cast<ARG_TYPE_2 *>(node->x());
+ auto y = dynamic_cast<ARG_TYPE_1 *>(node->y());
+
+ if (x && y)
+ {
+ *_arg_1 = y;
+ *_arg_2 = x;
+ return true;
+ }
+ }
+
+ return false;
+}
+
+} // namespace
+
+// Helper to check detail
+namespace
+{
+
+/// @return true When node has shape of '1 x .. x 1 x depth'
+bool is_1D_with_dummy_dim(locoex::TFLConst *node, uint32_t depth)
+{
+ auto rank = node->rank();
+ uint32_t axis;
+ for (axis = 0; axis < rank - 1; ++axis)
+ {
+ if (node->dim(axis).value() != 1)
+ return false;
+ }
+ return node->dim(axis).value() == depth;
+}
+
+bool is_instance_mean(locoex::TFLMean *mean)
+{
+ //
+ // CHECK 1) input is rank 4
+ //
+ auto input = mean->input();
+ if (not loco::shape_known(input))
+ return false;
+ auto input_shape = loco::shape_get(input).as<loco::TensorShape>();
+ if (input_shape.rank() != 4)
+ return false;
+
+ //
+ // CHECK 2) 'reduction indices' is TFLConst of value [1,2], that is HW of NHWC
+ //
+ // TODO Support equivalent case, like [-3,-2]
+ // TODO Support non-Const case?
+ // TODO What if input is NCHW format in Circle?
+ auto red_indices = dynamic_cast<locoex::TFLConst *>(mean->reduction_indices());
+ if (not red_indices)
+ return false;
+ if (red_indices->rank() != 1)
+ return false;
+ std::set<int32_t> red_indices_set;
+ {
+ // TODO Currently only support S32, support other types
+ assert(red_indices->dtype() == loco::DataType::S32);
+ for (uint32_t i = 0; i < red_indices->dim(0).value(); ++i)
+ red_indices_set.insert(red_indices->at<loco::DataType::S32>(i));
+ }
+ if (red_indices_set.size() != 2)
+ return false;
+ if (red_indices_set.find(1) == red_indices_set.end())
+ return false;
+ if (red_indices_set.find(2) == red_indices_set.end())
+ return false;
+
+ //
+ // CHECK 3) keep_dims == true (?)
+ //
+ // We only have case of 'keep_dims == true' so far, but it might be okay with 'keep_dims == false'
+ // TODO Check this fact, and if true, return true regardless of keep_dims
+ return mean->keep_dims();
+}
+
+} // namespace
+
+// Helper to fuse Instance Norm
+namespace
+{
+
+/**
+ * SUBGRAPH PATTERN
+ *
+ * - The diagram below shows the Instance Norm pattern to fuse.
+ * - Execution dependency order is top to bottom.
+ * - Node names match the variable names of the InstanceNormPattern class.
+ * - Usually, the first word of a node name (variable name) is the node type,
+ * e.g. variable 'mean_as_variance' is a pointer to TFLMean.
+ * - (Items in parentheses) actually exist, but have no name and are not
+ * variables of the InstanceNormPattern class.
+ *
+ * TODO support other semantically same patterns for instance norm
+ *
+ * [In]
+ * |
+ * V
+ * +----------- ifm -----+ (reduction indices)
+ * | | | |
+ * | | V V
+ * | | mean_of_ifm ----------------+
+ * | V | |
+ * | sqdiff <--+ (reduction indices) |
+ * | | | |
+ * | V | |
+ * | mean_as_variance <---+ const_as_epsilon |
+ * | | | |
+ * | V | |
+ * | add_as_variance <--------+ |
+ * | | |
+ * | V |
+ * | rsqrt const_as_gamma |
+ * | | | |
+ * | V | |
+ * | mul_gamma <--+ |
+ * | | | |
+ * V V V |
+ * mul_as_scaled_ifm mul_as_scaled_mean <-------------+
+ * | |
+ * | const_as_beta |
+ * | | V
+ * | +------> sub
+ * V |
+ * add_as_terminal <----------+
+ * |
+ * V
+ * [Out]
+ */
+class InstanceNormPattern final
+{
+public:
+ InstanceNormPattern(locoex::TFLAdd *candidate)
+ {
+ assert(candidate);
+ add_as_terminal = candidate;
+ }
+
+public:
+ bool matched();
+ bool matched() const { return _matched; }
+
+public:
+ // Context
+ loco::Node *ifm = nullptr;
+ locoex::TFLMean *mean_of_ifm = nullptr;
+ locoex::TFLSquaredDifference *sqdiff = nullptr;
+ locoex::TFLMean *mean_as_variance = nullptr;
+ locoex::TFLConst *const_as_epsilon = nullptr;
+ locoex::TFLAdd *add_as_variance = nullptr;
+ locoex::TFLRsqrt *rsqrt = nullptr;
+ locoex::TFLConst *const_as_gamma = nullptr;
+ locoex::TFLMul *mul_gamma = nullptr;
+ locoex::TFLMul *mul_as_scaled_ifm = nullptr;
+ locoex::TFLMul *mul_as_scaled_mean = nullptr;
+ locoex::TFLConst *const_as_beta = nullptr;
+ locoex::TFLSub *sub = nullptr;
+ locoex::TFLAdd *add_as_terminal = nullptr;
+
+private:
+ bool _matched = false;
+};
+
+bool InstanceNormPattern::matched()
+{
+ if (_matched)
+ return true;
+
+#define CHECK_OR_FALSE(condition) \
+ if (not(condition)) \
+ return false;
+
+ // Check order is DFS
+
+ CHECK_OR_FALSE(fill(&mul_as_scaled_ifm, &sub).with_commutative_args_of(add_as_terminal));
+ CHECK_OR_FALSE(fill(&ifm, &mul_gamma).with_commutative_args_of(mul_as_scaled_ifm));
+
+ CHECK_OR_FALSE(loco::shape_known(ifm));
+ auto ifm_shape = loco::shape_get(ifm);
+ CHECK_OR_FALSE(ifm_shape.domain() == loco::Domain::Tensor);
+ auto ifm_tensor_shape = ifm_shape.as<loco::TensorShape>();
+ CHECK_OR_FALSE(ifm_tensor_shape.rank() == 4);
+ uint32_t ifm_channel_depth = ifm_tensor_shape.dim(3).value();
+
+ CHECK_OR_FALSE(fill(&rsqrt, &const_as_gamma).with_commutative_args_of(mul_gamma));
+ CHECK_OR_FALSE(is_1D_with_dummy_dim(const_as_gamma, ifm_channel_depth));
+
+ add_as_variance = dynamic_cast<locoex::TFLAdd *>(rsqrt->x());
+ CHECK_OR_FALSE(add_as_variance);
+
+ CHECK_OR_FALSE(
+ fill(&mean_as_variance, &const_as_epsilon).with_commutative_args_of(add_as_variance));
+
+ CHECK_OR_FALSE(const_as_epsilon->dtype() == loco::DataType::FLOAT32);
+ // TODO Support regarding broadcast
+ CHECK_OR_FALSE(const_as_epsilon->size<loco::DataType::FLOAT32>() == 1);
+
+ CHECK_OR_FALSE(is_instance_mean(mean_as_variance));
+ sqdiff = dynamic_cast<locoex::TFLSquaredDifference *>(mean_as_variance->input());
+ CHECK_OR_FALSE(sqdiff);
+
+ loco::Node *ifm_should_be = nullptr;
+ CHECK_OR_FALSE(fill(&ifm_should_be, &mean_of_ifm).with_commutative_args_of(sqdiff));
+ CHECK_OR_FALSE(ifm == ifm_should_be);
+
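+ // the 'ifm' found above must be the same node that feeds both sqdiff and
+ // mean_of_ifm, checked just below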
CHECK_OR_FALSE(is_instance_mean(mean_of_ifm));
+ CHECK_OR_FALSE(ifm == mean_of_ifm->input());
+
+ const_as_beta = dynamic_cast<locoex::TFLConst *>(sub->x());
+ CHECK_OR_FALSE(const_as_beta);
+ CHECK_OR_FALSE(is_1D_with_dummy_dim(const_as_beta, ifm_channel_depth));
+
+ mul_as_scaled_mean = dynamic_cast<locoex::TFLMul *>(sub->y());
+ CHECK_OR_FALSE(mul_as_scaled_mean);
+
+ locoex::TFLMul *mul_gamma_should_be = nullptr;
+ locoex::TFLMean *mean_of_ifm_should_be = nullptr;
+ CHECK_OR_FALSE(fill(&mul_gamma_should_be, &mean_of_ifm_should_be)
+ .with_commutative_args_of(mul_as_scaled_mean));
+ CHECK_OR_FALSE(mul_gamma == mul_gamma_should_be);
+ CHECK_OR_FALSE(mean_of_ifm == mean_of_ifm_should_be);
+#undef CHECK_OR_FALSE
+ _matched = true;
+ return true;
+}
+
+/**
+ * The instance norm pattern is fused as in the following diagram:
+ *
+ * [In] --------------------------- CircleInstanceNorm --- [Out]
+ * / /
+ * const_as_gamma --- TFLReshape --- /
+ * /
+ * const_as_beta ---- TFLReshape ---
+ *
+ * Note
+ * - 'const_as_gamma' and 'const_as_beta' are from the original graph
+ * - The value of 'const_as_epsilon' is copied to CircleInstanceNorm's attribute
+ * - TFLReshape is added because CircleInstanceNorm only accepts 1D tensors
+ * - 'TFLConst --- TFLReshape' is expected to be fused in constant folding for Reshape
+ */
+void fuse_instance_norm(const InstanceNormPattern &p)
+{
+ assert(p.matched());
+
+ auto graph = p.add_as_terminal->graph();
+
+ // Make reshape for gamma & beta
+ auto reshape_gamma = graph->nodes()->create<locoex::TFLReshape>();
+ auto reshape_beta = graph->nodes()->create<locoex::TFLReshape>();
+ {
+ auto ifm_shape = loco::shape_get(p.ifm).as<loco::TensorShape>();
+ uint32_t ifm_channel_depth = ifm_shape.dim(3).value();
+
+ int32_t new_shape[1] = {static_cast<int32_t>(ifm_channel_depth)};
+
+ reshape_gamma->tensor(p.const_as_gamma);
+ reshape_beta->tensor(p.const_as_beta);
+
+ locoex::set_new_shape(reshape_gamma, new_shape, 1);
+ locoex::set_new_shape(reshape_beta, new_shape, 1);
+ }
+
+ // Make Instance Norm to replace
+ auto instance_norm = graph->nodes()->create<locoex::CircleInstanceNorm>();
+ instance_norm->input(p.ifm);
+ instance_norm->gamma(reshape_gamma);
+ instance_norm->beta(reshape_beta);
+ float epsilon = p.const_as_epsilon->at<loco::DataType::FLOAT32>(0);
+ instance_norm->epsilon(epsilon);
+ instance_norm->fusedActivationFunction(p.add_as_terminal->fusedActivationFunction());
+
+ replace(p.add_as_terminal).with(instance_norm);
+}
+
+} // namespace
+
+namespace exo
+{
+
+bool FuseInstanceNormPass::run(loco::Graph *g)
+{
+ bool changed = false;
+ for (auto node : loco::active_nodes(loco::output_nodes(g)))
+ {
+ auto add = dynamic_cast<locoex::TFLAdd *>(node);
+ if (not add)
+ continue;
+
+ InstanceNormPattern pattern(add);
+ if (not pattern.matched())
+ continue;
+
+ fuse_instance_norm(pattern);
+ changed = true;
+ }
+
+ return changed;
+}
+
+} // namespace exo diff --git a/compiler/exo/src/Pass/FuseInstanceNormPass.h b/compiler/exo/src/Pass/FuseInstanceNormPass.h new file mode 100644 index 000000000..e6361021c --- /dev/null +++ b/compiler/exo/src/Pass/FuseInstanceNormPass.h @@ -0,0 +1,40 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __FUSE_INSTANCE_NORM_PASS_H__ +#define __FUSE_INSTANCE_NORM_PASS_H__ + +#include <logo/Pass.h> + +namespace exo +{ + +/** + * @brief Class to fuse certain pattern of subgraph into CircleInstanceNorm + * with auxiliary nodes + * + * For detailed subgraph pattern to be fused, please check its implementation. + */ +struct FuseInstanceNormPass final : public logo::Pass +{ + const char *name(void) const final { return "exo::FuseInstanceNormPass"; } + + bool run(loco::Graph *g) final; +}; + +} // namespace exo + +#endif // __FUSE_INSTANCE_NORM_PASS_H__ diff --git a/compiler/exo/src/Pass/FuseReluPass.cpp b/compiler/exo/src/Pass/FuseReluPass.cpp new file mode 100644 index 000000000..d7af0c506 --- /dev/null +++ b/compiler/exo/src/Pass/FuseReluPass.cpp @@ -0,0 +1,115 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "FuseReluPass.h" + +#include "Dialect/IR/TFLNodes.h" +#include "Dialect/IR/TFLDialect.h" +#include "Dialect/IR/TFLNodeVisitor.h" + +#include <set> + +namespace +{ + +bool is_pred_fusable(loco::Node *node) +{ + using namespace locoex; + + auto fusable_node = dynamic_cast<TFLNodeMixin<TFLNodeTrait::FusedActFunc> *>(node); + + return (fusable_node and fusable_node->fusedActivationFunction() == FusedActFunc::NONE); +}; + +struct Collector final : public locoex::TFLNodeMutableVisitor<void> +{ + void visit(locoex::TFLRelu *node) final + { + if (is_pred_fusable(node->features())) + candidates.insert(node); + } + + void visit(locoex::TFLRelu6 *node) final + { + if (is_pred_fusable(node->features())) + candidates.insert(node); + } + + void visit(locoex::TFLNode *) final { return; } + + std::set<locoex::TFLNode *> candidates; +}; + +void set_activation_fusion(loco::Node *node, locoex::FusedActFunc f) +{ + using namespace locoex; + + if (auto fusable_node = dynamic_cast<TFLNodeMixin<TFLNodeTrait::FusedActFunc> *>(node)) + fusable_node->fusedActivationFunction(f); + else + assert(false); +} + +struct Performer final : public locoex::TFLNodeMutableVisitor<void> +{ + void visit(locoex::TFLRelu *the_relu) final + { + set_activation_fusion(the_relu->features(), locoex::FusedActFunc::RELU); + + loco::replace(the_relu).with(the_relu->features()); + the_relu->features(nullptr); + } + + void visit(locoex::TFLRelu6 *the_relu6) final + { + set_activation_fusion(the_relu6->features(), locoex::FusedActFunc::RELU6); + + loco::replace(the_relu6).with(the_relu6->features()); + the_relu6->features(nullptr); + } + + void visit(locoex::TFLNode *) final { assert(false && "should not be called"); } +}; + +} // namespace + +namespace exo +{ + +bool FuseReluPass::run(loco::Graph *g) +{ + Collector collector; + + for (auto node : loco::active_nodes(loco::output_nodes(g))) + { + if (node->dialect() == locoex::TFLDialect::get()) + { + auto tfl_node = dynamic_cast<locoex::TFLNode *>(node); + tfl_node->accept(&collector); + } + } + + Performer performer; + + for (auto node : collector.candidates) + { + node->accept(&performer); + } + + return collector.candidates.size() > 0; +} + +} // namespace exo diff --git a/compiler/exo/src/Pass/FuseReluPass.h b/compiler/exo/src/Pass/FuseReluPass.h new file mode 100644 index 000000000..1cd276b29 --- /dev/null +++ b/compiler/exo/src/Pass/FuseReluPass.h @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef __PASS_FUSE_RELU_PASS_H__ +#define __PASS_FUSE_RELU_PASS_H__ + +#include <logo/Pass.h> + +namespace exo +{ + +/** + * @brief Class to fuse TFLRelu or TFLRelu6 into the TensorFlow Lite ops below: + * + * ADD, AVERAGE_POOL_2D, CONCATENATION, CONV_2D, DEPTHWISE_CONV_2D, + * FULLY_CONNECTED, L2_NORMALIZATION, L2_POOL_2D, MAX_POOL_2D, MUL + */ +struct FuseReluPass final : public logo::Pass +{ + const char *name(void) const final { return "exo::FuseReluPass"; } + + bool run(loco::Graph *g) final; +}; + +} // namespace exo + +#endif // __PASS_FUSE_RELU_PASS_H__ diff --git a/compiler/exo/src/Pass/FuseReluPass.test.cpp b/compiler/exo/src/Pass/FuseReluPass.test.cpp new file mode 100644 index 000000000..6f83d4dd0 --- /dev/null +++ b/compiler/exo/src/Pass/FuseReluPass.test.cpp @@ -0,0 +1,115 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "FuseReluPass.h" + +#include "Dialect/IR/TFLNodes.h" +#include "TestGraph.h" + +#include <loco.h> +#include <logo/RemoveDeadNodePass.h> + +#include <gtest/gtest.h> + +#include <type_traits> // for std::is_same + +namespace +{ + +void init(loco::Pull *pull) +{ + pull->dtype(loco::DataType::FLOAT32); + pull->shape({2, 3, 3, 2}); +} + +/// @brief Initializes TFLConv2D and related filter and bias +void init(locoex::TFLConv2D *conv2d, locoex::TFLConst *filter, locoex::TFLConst *bias) +{ + // set conv2d + { + conv2d->fusedActivationFunction(locoex::FusedActFunc::NONE); + conv2d->padding(locoex::Padding::VALID); + } + + // set filter + { + filter->dtype(loco::DataType::FLOAT32); + filter->shape({2, 3, 3, 2}); + filter->size<loco::DataType::FLOAT32>(2 * 3 * 3 * 2); + + for (uint32_t x = 0; x < 2 * 3 * 3 * 2; x++) + filter->at<loco::DataType::FLOAT32>(x) = 0.0; + } + + // set bias + { + bias->dtype(loco::DataType::FLOAT32); + bias->shape({2}); + bias->size<loco::DataType::FLOAT32>(2); + + for (uint32_t x = 0; x < 2; x++) + bias->at<loco::DataType::FLOAT32>(x) = 0.0; + } +} + +} // namespace + +/// Test code called by TEST(..) +/// This tests whether Conv2D - FusedTFLType is fused. 
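+/// FusedTFLType must be locoex::TFLRelu or locoex::TFLRelu6, paired with the matching
+/// locoex::FusedActFunc value; the static_assert below enforces that pairing.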
+template <class FusedTFLType, locoex::FusedActFunc FusedActFunc> void test() +{ + static_assert((std::is_same<FusedTFLType, locoex::TFLRelu>::value && + FusedActFunc == locoex::FusedActFunc::RELU) || + (std::is_same<FusedTFLType, locoex::TFLRelu6>::value && + FusedActFunc == locoex::FusedActFunc::RELU6), + "wrong template type"); + + exo::test::TestGraph g; + { + auto filter = g.append<locoex::TFLConst>(); + auto bias = g.append<locoex::TFLConst>(); + auto conv2d = g.append<locoex::TFLConv2D>(g.pull, filter, bias); + + auto fusable_node = g.append<FusedTFLType>(conv2d); + + g.complete(fusable_node); + + init(g.pull); + init(conv2d, filter, bias); + } + + // let's run fusion + { + exo::test::TypeShapeReadyPhase test_phase; + + test_phase.add_pass<exo::FuseReluPass>(); + test_phase.add_pass<logo::RemoveDeadNodePass>(); // to remove TFLRelu + test_phase.run(g.graph()); + } + + auto a_conv2d = exo::test::find_first_node_bytype<locoex::TFLConv2D>(g.graph()); + ASSERT_TRUE(a_conv2d != nullptr); + ASSERT_TRUE(a_conv2d->fusedActivationFunction() == FusedActFunc); + + auto removed_fusable_node = exo::test::find_first_node_bytype<FusedTFLType>(g.graph()); + ASSERT_TRUE(removed_fusable_node == nullptr); +} + +// A case with Conv2D-Relu +TEST(FuseReluTest, Conv2D_Relu_basic) { test<locoex::TFLRelu, locoex::FusedActFunc::RELU>(); } + +// A case with Conv2D-Relu6 +TEST(FuseReluTest, Conv2D_Relu6_basic) { test<locoex::TFLRelu6, locoex::FusedActFunc::RELU6>(); } diff --git a/compiler/exo/src/Pass/FuseRsqrtPass.cpp b/compiler/exo/src/Pass/FuseRsqrtPass.cpp new file mode 100644 index 000000000..08d704139 --- /dev/null +++ b/compiler/exo/src/Pass/FuseRsqrtPass.cpp @@ -0,0 +1,95 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#include "FuseRsqrtPass.h"
+
+#include "Check.h"
+
+#include "Dialect/IR/TFLNodes.h"
+
+namespace
+{
+
+/**
+ * @return Casted TFLDiv for fusable candidate, nullptr otherwise
+ *
+ * This helper checks fusability with the following conditions:
+ * - TFLDiv has no activation
+ * - TFLDiv's first argument is a TFLConst whose values are all 1
+ * - TFLDiv's second argument is TFLSqrt
+ */
+locoex::TFLDiv *as_candidate(loco::Node *node)
+{
+ auto div = dynamic_cast<locoex::TFLDiv *>(node);
+ if (not div)
+ return nullptr;
+
+ // Cannot fuse Div with activation function
+ if (div->fusedActivationFunction() != locoex::FusedActFunc::NONE)
+ return nullptr;
+
+ auto const_one = dynamic_cast<locoex::TFLConst *>(div->x());
+ if (not const_one)
+ return nullptr;
+
+ const loco::DataType FLOAT32 = loco::DataType::FLOAT32;
+ // TODO Support other dtype
+ EXO_ASSERT(const_one->dtype() == FLOAT32, "Only support FLOAT32 now");
+ for (uint32_t i = 0; i < const_one->size<FLOAT32>(); ++i)
+ if (const_one->at<FLOAT32>(i) != 1.0f)
+ return nullptr;
+
+ auto sqrt = dynamic_cast<locoex::TFLSqrt *>(div->y());
+ if (not sqrt)
+ return nullptr;
+
+ return div;
+}
+
+void fuse_rsqrt(locoex::TFLDiv *div)
+{
+ auto sqrt = dynamic_cast<locoex::TFLSqrt *>(div->y());
+ EXO_ASSERT(sqrt, "sqrt should be valid at this point");
+
+ // TFLRsqrt to replace
+ auto rsqrt = div->graph()->nodes()->create<locoex::TFLRsqrt>();
+ rsqrt->x(sqrt->x());
+
+ // replace
+ loco::replace(div).with(rsqrt);
+}
+
+} // namespace
+
+namespace exo
+{
+
+bool FuseRsqrtPass::run(loco::Graph *g)
+{
+ bool changed = false;
+ for (auto node : loco::active_nodes(loco::output_nodes(g)))
+ {
+ if (auto div = as_candidate(node))
+ {
+ fuse_rsqrt(div);
+ changed = true;
+ }
+ }
+
+ return changed;
+}
+
+} // namespace exo diff --git a/compiler/exo/src/Pass/FuseRsqrtPass.h b/compiler/exo/src/Pass/FuseRsqrtPass.h new file mode 100644 index 000000000..1e60e4a49 --- /dev/null +++ b/compiler/exo/src/Pass/FuseRsqrtPass.h @@ -0,0 +1,47 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __FUSE_RSQRT_PASS_H__
+#define __FUSE_RSQRT_PASS_H__
+
+#include <logo/Pass.h>
+
+namespace exo
+{
+
+/**
+ * @brief Class to fuse '1 / TFLSqrt' (a TFLDiv whose dividend is a constant 1) into TFLRsqrt
+ *
+ * <BEFORE>
+ *
+ * TFLConst(1) ------
+ * \
+ * A --- TFLSqrt --- TFLDiv --- B
+ *
+ * <AFTER>
+ *
+ * A --- TFLRsqrt --- B
+ */
+struct FuseRsqrtPass final : public logo::Pass
+{
+ const char *name(void) const final { return "exo::FuseRsqrtPass"; }
+
+ bool run(loco::Graph *g) final;
+};
+
+} // namespace exo
+
+#endif // __FUSE_RSQRT_PASS_H__ diff --git a/compiler/exo/src/Pass/FuseSquaredDifferencePass.cpp b/compiler/exo/src/Pass/FuseSquaredDifferencePass.cpp new file mode 100644 index 000000000..3f985a505 --- /dev/null +++ b/compiler/exo/src/Pass/FuseSquaredDifferencePass.cpp @@ -0,0 +1,86 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd.
All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "FuseSquaredDifferencePass.h"
+
+#include "Check.h"
+
+#include "Dialect/IR/TFLNodes.h"
+
+namespace
+{
+
+/**
+ * @return Casted TFLMul for fusable candidate, nullptr otherwise
+ *
+ * This helper checks fusability with the following conditions:
+ * - TFLMul has no activation
+ * - TFLMul's first and second arguments are the same node, and that node is a TFLSub
+ */
+locoex::TFLMul *as_candidate(loco::Node *node)
+{
+ auto mul = dynamic_cast<locoex::TFLMul *>(node);
+ if (not mul)
+ return nullptr;
+
+ // Cannot fuse mul with activation function
+ if (mul->fusedActivationFunction() != locoex::FusedActFunc::NONE)
+ return nullptr;
+
+ if (mul->x() != mul->y())
+ return nullptr;
+
+ if (not dynamic_cast<locoex::TFLSub *>(mul->x()))
+ return nullptr;
+
+ return mul;
+}
+
+void fuse_squared_difference(locoex::TFLMul *mul)
+{
+ auto sub = dynamic_cast<locoex::TFLSub *>(mul->x());
+ EXO_ASSERT(sub, "sub should be valid at this point");
+
+ // TFLSquaredDifference to replace
+ auto sq_diff = mul->graph()->nodes()->create<locoex::TFLSquaredDifference>();
+ sq_diff->x(sub->x());
+ sq_diff->y(sub->y());
+
+ // replace
+ loco::replace(mul).with(sq_diff);
+}
+
+} // namespace
+
+namespace exo
+{
+
+bool FuseSquaredDifferencePass::run(loco::Graph *g)
+{
+ bool changed = false;
+ for (auto node : loco::active_nodes(loco::output_nodes(g)))
+ {
+ if (auto mul = as_candidate(node))
+ {
+ fuse_squared_difference(mul);
+ changed = true;
+ }
+ }
+
+ return changed;
+}
+
+} // namespace exo diff --git a/compiler/exo/src/Pass/FuseSquaredDifferencePass.h b/compiler/exo/src/Pass/FuseSquaredDifferencePass.h new file mode 100644 index 000000000..dbc15149f --- /dev/null +++ b/compiler/exo/src/Pass/FuseSquaredDifferencePass.h @@ -0,0 +1,49 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
diff --git a/compiler/exo/src/Pass/FuseSquaredDifferencePass.h b/compiler/exo/src/Pass/FuseSquaredDifferencePass.h
new file mode 100644
index 000000000..dbc15149f
--- /dev/null
+++ b/compiler/exo/src/Pass/FuseSquaredDifferencePass.h
@@ -0,0 +1,49 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __FUSE_SQUARED_DIFFERENCE_PASS_H__
+#define __FUSE_SQUARED_DIFFERENCE_PASS_H__
+
+#include <logo/Pass.h>
+
+namespace exo
+{
+
+/**
+ * @brief Class to fuse the SquaredDifference pattern
+ *
+ * <BEFORE>
+ *
+ *    A --- TFLSub --- TFLMul --- C
+ *         /       \       /
+ *    B ---         -------
+ *
+ * <AFTER>
+ *
+ *    A --- TFLSquaredDifference --- C
+ *         /
+ *    B ---
+ */
+struct FuseSquaredDifferencePass final : public logo::Pass
+{
+  const char *name(void) const final { return "exo::FuseSquaredDifferencePass"; }
+
+  bool run(loco::Graph *g) final;
+};
+
+} // namespace exo
+
+#endif // __FUSE_SQUARED_DIFFERENCE_PASS_H__
diff --git a/compiler/exo/src/Pass/MergeConcatNodesPass.cpp b/compiler/exo/src/Pass/MergeConcatNodesPass.cpp
new file mode 100644
index 000000000..8945fcfce
--- /dev/null
+++ b/compiler/exo/src/Pass/MergeConcatNodesPass.cpp
@@ -0,0 +1,191 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "MergeConcatNodesPass.h"
+#include "Dialect/IR/TFLNodes.h"
+
+#include <oops/InternalExn.h>
+
+#include <vector>
+
+namespace
+{
+
+bool canMerge(locoex::TFLConcatenation *node1, locoex::TFLConcatenation *node2)
+{
+  if (node1->fusedActivationFunction() != node2->fusedActivationFunction())
+    return false;
+
+  if (node1->axis() != node2->axis())
+    return false;
+
+  switch (node1->fusedActivationFunction())
+  {
+    case locoex::FusedActFunc::NONE:
+    case locoex::FusedActFunc::RELU:
+    case locoex::FusedActFunc::RELU6:
+      return true;
+
+    // case locoex::FusedActFunc::TANH:
+    //   return false;
+
+    default:
+      INTERNAL_EXN_V("Unknown FusedActFunc", oops::to_uint32(node1->fusedActivationFunction()));
+  }
+}
+
+/**
+ * @brief Collect all the inputs of newly created TFLConcatenation nodes
+ *
+ *     in:0 -------------------------------\
+ *     in:1 ---- TFLConcatenation:0 -------- TFLConcatenation:3 --- C
+ *               (axis = 0, NONE)          / (axis = 0, NONE)
+ *     in:2 ---/                          /
+ *     in:3 ---- TFLConcatenation:1 -----/
+ *               (axis = 1, NONE)       /
+ *     in:4 ---/                       /
+ *     in:5 ---- TFLConcatenation:2 --/
+ *               (axis = 0, RELU)
+ *     in:6 ---/
+ *
+ * For example, if the graph is like above, dfs(TFLConcatenation:3) will
+ * return [in:0, in:1, in:2, TFLConcatenation:1, TFLConcatenation:2]
+ *
+ * TFLConcatenation:0 can be merged into TFLConcatenation:3,
+ * because their axis and fusedActivationFunction are the same.
+ * It means that [in:1, in:2] will be linked as inputs of the new TFLConcatenation.
+ *
+ * However, TFLConcatenation:1 and TFLConcatenation:2 cannot be merged into
+ * TFLConcatenation:3 because their axis or fusedActivationFunction differs.
+ * So [in:3, in:4, in:5, in:6] will not be linked as inputs of the new
+ * TFLConcatenation; [TFLConcatenation:1, TFLConcatenation:2] will be linked instead.
+ *
+ * Therefore, the inputs of the newly created TFLConcatenation node that replaces
+ * TFLConcatenation:3 will be [in:0, in:1, in:2, TFLConcatenation:1, TFLConcatenation:2],
+ * which is exactly what dfs(TFLConcatenation:3) returns.
+ *
+ * @note The input nodes are traversed from left to right (input:0 --> input:N),
+ *       so the original input order is preserved after merging
+ */
+std::vector<loco::Node *> dfs(locoex::TFLConcatenation *root)
+{
+  std::vector<loco::Node *> res;
+
+  for (uint32_t i = 0; i < root->numValues(); ++i)
+  {
+    auto input = dynamic_cast<locoex::TFLConcatenation *>(root->values(i));
+    if (input != nullptr && canMerge(input, root))
+    {
+      auto children = dfs(input);
+      for (auto child : children)
+        res.push_back(child);
+    }
+    else
+    {
+      res.push_back(root->values(i));
+    }
+  }
+
+  return res;
+}
+
+} // namespace
+
+namespace exo
+{
+
+/**
+ * @brief Merge TFLConcatenation nodes whose axis and fusedActivationFunction are the same
+ *
+ * [Before]
+ *     in:0 -------------------------------\
+ *     in:1 ---- TFLConcatenation:0 -------- TFLConcatenation:3 --- C
+ *               (axis = 0, NONE)          / (axis = 0, NONE)
+ *     in:2 ---/                          /
+ *     in:3 ---- TFLConcatenation:1 -----/
+ *               (axis = 1, NONE)       /
+ *     in:4 ---/                       /
+ *     in:5 ---- TFLConcatenation:2 --/
+ *               (axis = 0, RELU)
+ *     in:6 ---/
+ *
+ * [After]
+ *     in:0 -------------------------------\
+ *     in:1 -------------------------------- TFLConcatenation:4 --- C
+ *                                          / (axis = 0, NONE)
+ *     in:2 -------------------------------/
+ *     in:3 ---- TFLConcatenation:1 ------/
+ *               (axis = 1, NONE)        /
+ *     in:4 ---/                        /
+ *     in:5 ---- TFLConcatenation:2 ---/
+ *               (axis = 0, RELU)
+ *     in:6 ---/
+ *
+ *
+ *     in:1 ---- TFLConcatenation:0 ----
+ *               (axis = 0, NONE)
+ *     in:2 ---/
+ *
+ *
+ *          ---- TFLConcatenation:3 ----
+ *               (axis = 0, NONE)
+ */
+bool MergeConcatNodesPass::run(loco::Graph *graph)
+{
+  // Let's enumerate nodes required to compute output nodes
+  auto active_nodes = loco::active_nodes(loco::output_nodes(graph));
+
+  // Find TFLConcatenation nodes which have other TFLConcatenation nodes
+  // as inputs, with the same axis and the same fusedActivationFunction
+  std::vector<locoex::TFLConcatenation *> candidates;
+  for (auto node : active_nodes)
+  {
+    if (auto concat = dynamic_cast<locoex::TFLConcatenation *>(node))
+    {
+      for (uint32_t i = 0; i < concat->numValues(); ++i)
+      {
+        auto input = dynamic_cast<locoex::TFLConcatenation *>(concat->values(i));
+        if (input != nullptr && canMerge(input, concat))
+        {
+          candidates.push_back(concat);
+          break;
+        }
+      }
+    }
+  }
+
+  // Merge multiple TFLConcatenation nodes as one TFLConcatenation node
+  for (auto node : candidates)
+  {
+    auto inputs = dfs(node);
+
+    auto new_concat = graph->nodes()->create<locoex::TFLConcatenation>(inputs.size());
+    new_concat->axis(node->axis());
+    new_concat->fusedActivationFunction(node->fusedActivationFunction());
+
+    for (uint32_t i = 0; i < inputs.size(); ++i)
+      new_concat->values(i, inputs.at(i));
+
+    loco::replace(node).with(new_concat);
+    for (uint32_t i = 0; i < node->numValues(); ++i)
+      node->values(i, nullptr);
+  }
+
+  return candidates.size() > 0;
+}
+
+} // namespace exo
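A sketch of the smallest graph this pass rewrites, under the same assumptions as the sketches above (the function name and node wiring are illustrative, not from the tree):

// in:0 ---- TFLConcatenation(inner) ---- TFLConcatenation(outer) --- output
// in:1 ---/                             /
//                        in:2 ---------/
// Both concats use axis 0 and NONE, so canMerge() accepts them and the pass
// replaces the pair with one TFLConcatenation over [in:0, in:1, in:2]
void sketch_merge_concats(loco::Graph *g, loco::Node *in0, loco::Node *in1, loco::Node *in2)
{
  auto inner = g->nodes()->create<locoex::TFLConcatenation>(2);
  inner->axis(0);
  inner->fusedActivationFunction(locoex::FusedActFunc::NONE);
  inner->values(0, in0);
  inner->values(1, in1);

  auto outer = g->nodes()->create<locoex::TFLConcatenation>(2);
  outer->axis(0); // same axis and same activation as 'inner', hence mergeable
  outer->fusedActivationFunction(locoex::FusedActFunc::NONE);
  outer->values(0, inner);
  outer->values(1, in2);

  auto push = g->nodes()->create<loco::Push>();
  push->from(outer);
  loco::link(g->outputs()->create(), push);

  exo::MergeConcatNodesPass pass;
  while (pass.run(g))
    ; // dfs() flattens a whole chain per run; iterating to a fixed point is still the safe driver
}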
diff --git a/compiler/exo/src/Pass/MergeConcatNodesPass.h b/compiler/exo/src/Pass/MergeConcatNodesPass.h
new file mode 100644
index 000000000..823214f43
--- /dev/null
+++ b/compiler/exo/src/Pass/MergeConcatNodesPass.h
@@ -0,0 +1,41 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __PASS_MERGE_CONCAT_NODES_H__ +#define __PASS_MERGE_CONCAT_NODES_H__ + +#include <loco.h> +#include <logo/Pass.h> + +namespace exo +{ + +/** + * @brief Merge concat nodes whose axis and fusedActivationFunction are same + * + */ +class MergeConcatNodesPass : public logo::Pass +{ +public: + virtual const char *name(void) const { return "exo::MergeConcatNodesPass"; } + +public: + bool run(loco::Graph *graph); +}; + +} // namespace exo + +#endif // __PASS_MERGE_CONCAT_NODES_H__ diff --git a/compiler/exo/src/Pass/ShapeInferencePass.cpp b/compiler/exo/src/Pass/ShapeInferencePass.cpp new file mode 100644 index 000000000..bc60f91c4 --- /dev/null +++ b/compiler/exo/src/Pass/ShapeInferencePass.cpp @@ -0,0 +1,59 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ShapeInferencePass.h" + +#include "Dialect/IR/TFLDialect.h" +#include "Dialect/Service/TFLShapeInferenceRule.h" + +#include "Dialect/IR/CircleDialect.h" +#include "Dialect/Service/CircleShapeInferenceRule.h" + +#include <loco.h> +#include <loco/IR/CanonicalDialect.h> +#include <loco/Service/CanonicalShapeInferenceRule.h> +#include <loco/Service/ShapeInference.h> +#include <loco/Service/MultiDialectShapeInferenceRule.h> + +#include <locoex/COpDialect.h> +#include <locoex/Service/COpShapeInferenceRule.h> + +namespace exo +{ + +/** + * @note Currently, TFL and Circle backend share this inference. However, TFL + * backend does not require rule for Circle dialect. + * TODO Make dedicated inference pass for Circle Dialect. + */ +bool ShapeInferencePass::run(loco::Graph *g) +{ + loco::CanonicalShapeInferenceRule canonical_rule; + locoex::TFLShapeInferenceRule tfl_rule; + locoex::CircleShapeInferenceRule circle_rule; + locoex::COpShapeInferenceRule cop_rule; + + loco::MultiDialectShapeInferenceRule rules; + + rules.bind(loco::CanonicalDialect::get(), &canonical_rule) + .bind(locoex::TFLDialect::get(), &tfl_rule) + .bind(locoex::CircleDialect::get(), &circle_rule) + .bind(locoex::COpDialect::get(), &cop_rule); + + return loco::apply(&rules).to(g); +} + +} // namespace exo diff --git a/compiler/exo/src/Pass/ShapeInferencePass.h b/compiler/exo/src/Pass/ShapeInferencePass.h new file mode 100644 index 000000000..518c87403 --- /dev/null +++ b/compiler/exo/src/Pass/ShapeInferencePass.h @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __PASS_SHAPE_INFERENCE_PASS_H__ +#define __PASS_SHAPE_INFERENCE_PASS_H__ + +#include <loco.h> +#include <logo/Pass.h> + +namespace exo +{ + +/** + * @brief Pass to infer shape of nodes + */ +class ShapeInferencePass : public logo::Pass +{ +public: + virtual const char *name(void) const { return "exo::ShapeInferencePass"; } + +public: + bool run(loco::Graph *graph); +}; + +} // namespace exo + +#endif //__PASS_SHAPE_INFERENCE_PASS_H__ diff --git a/compiler/exo/src/Pass/TypeInferencePass.cpp b/compiler/exo/src/Pass/TypeInferencePass.cpp new file mode 100644 index 000000000..31d4f13b6 --- /dev/null +++ b/compiler/exo/src/Pass/TypeInferencePass.cpp @@ -0,0 +1,57 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "TypeInferencePass.h" + +#include "Dialect/IR/TFLDialect.h" +#include "Dialect/Service/TFLTypeInferenceRule.h" + +#include "Dialect/IR/CircleDialect.h" +#include "Dialect/Service/CircleTypeInferenceRule.h" + +#include <loco.h> +#include <loco/IR/CanonicalDialect.h> +#include <loco/Service/TypeInference.h> + +#include <locoex/COpDialect.h> +#include <locoex/Service/COpTypeInference.h> + +namespace exo +{ + +/** + * @note Currently, TFL and Circle backend share this inference. However, TFL + * backend does not require rule for Circle dialect. + * TODO Make dedicated inference pass for Circle Dialect. + */ +bool TypeInferencePass::run(loco::Graph *g) +{ + loco::CanonicalTypeInferenceRule canonical_rule; + locoex::TFLTypeInferenceRule tfl_rule; + locoex::CircleTypeInferenceRule circle_rule; + locoex::COpTypeInferenceRule cop_rule; + + loco::MultiDialectTypeInferenceRule rules; + + rules.bind(loco::CanonicalDialect::get(), &canonical_rule) + .bind(locoex::TFLDialect::get(), &tfl_rule) + .bind(locoex::CircleDialect::get(), &circle_rule) + .bind(locoex::COpDialect::get(), &cop_rule); + + return loco::apply(&rules).to(g); +} + +} // namespace exo diff --git a/compiler/exo/src/Pass/TypeInferencePass.h b/compiler/exo/src/Pass/TypeInferencePass.h new file mode 100644 index 000000000..3ede587a0 --- /dev/null +++ b/compiler/exo/src/Pass/TypeInferencePass.h @@ -0,0 +1,42 @@ + +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __PASS_TYPE_INFERENCE_PASS_H__ +#define __PASS_TYPE_INFERENCE_PASS_H__ + +#include <loco.h> + +#include <logo/Pass.h> + +namespace exo +{ + +/** + * @brief Pass to infer type of nodes + */ +class TypeInferencePass : public logo::Pass +{ +public: + virtual const char *name(void) const { return "exo::TypeInferencePass"; } + +public: + bool run(loco::Graph *graph); +}; + +} // namespace exo + +#endif //__PASS_TYPE_INFERENCE_PASS_H__ diff --git a/compiler/exo/src/Passes.cpp b/compiler/exo/src/Passes.cpp new file mode 100644 index 000000000..99d229c9c --- /dev/null +++ b/compiler/exo/src/Passes.cpp @@ -0,0 +1,19 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "Passes.h" + +// This file is to make sure that Passes.h be compiled diff --git a/compiler/exo/src/Passes.h b/compiler/exo/src/Passes.h new file mode 100644 index 000000000..2a702d01d --- /dev/null +++ b/compiler/exo/src/Passes.h @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef __PASSES_H__ +#define __PASSES_H__ + +// Please add in alphabetical order +// Please append 'Pass' suffix to Pass class and file names + +#include "Pass/FoldReshapeOfConstPass.h" +#include "Pass/FoldTransposeOfConstPass.h" +#include "Pass/FuseBiasAddPass.h" +#include "Pass/FuseInstanceNormPass.h" +#include "Pass/FuseReluPass.h" +#include "Pass/FuseRsqrtPass.h" +#include "Pass/FuseSquaredDifferencePass.h" +#include "Pass/MergeConcatNodesPass.h" +#include "Pass/ShapeInferencePass.h" +#include "Pass/TypeInferencePass.h" + +#include <logo/RemoveDeadNodePass.h> +#include <logo/RemoveForwardNodePass.h> +#include <logo/SimplifyDomainConversionPass.h> + +#endif // __PASSES_H__ diff --git a/compiler/exo/src/ProgressReporter.cpp b/compiler/exo/src/ProgressReporter.cpp new file mode 100644 index 000000000..ff919dae8 --- /dev/null +++ b/compiler/exo/src/ProgressReporter.cpp @@ -0,0 +1,84 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ProgressReporter.h" + +#include "Log.h" +#include "LogHelper.h" + +#include <logo/Phase.h> +#include <logo/Pass.h> + +#include <cassert> + +namespace +{ + +char to_char(bool b) { return b ? 'Y' : 'N'; } + +const char *to_str(logo::PhaseStrategy s) +{ + switch (s) + { + case logo::PhaseStrategy::Saturate: + return "Saturate"; + case logo::PhaseStrategy::Restart: + return "Restart"; + } + assert(false); + return ""; +} + +} // namespace + +namespace exo +{ + +void ProgressReporter::notify(const logo::PhaseEventInfo<logo::PhaseEvent::PhaseBegin> *) +{ + LOGGER(prime); + + INFO(prime) << "=============================================================="; + INFO(prime) << "exo::PhaseRunner<" << to_str(strategy()) << ">"; + INFO(prime) << "Initial graph"; + INFO(prime) << fmt(graph()); +} + +void ProgressReporter::notify(const logo::PhaseEventInfo<logo::PhaseEvent::PhaseEnd> *) +{ + LOGGER(prime); + + INFO(prime) << "exo::PhaseRunner<" << to_str(strategy()) << "> - done"; +} + +void ProgressReporter::notify(const logo::PhaseEventInfo<logo::PhaseEvent::PassBegin> *info) +{ + LOGGER(prime); + + INFO(prime) << "--------------------------------------------------------------"; + INFO(prime) << "Before " << logo::pass_name(info->pass()); +} + +void ProgressReporter::notify(const logo::PhaseEventInfo<logo::PhaseEvent::PassEnd> *info) +{ + LOGGER(prime); + + INFO(prime) << "After " << logo::pass_name(info->pass()) + << " (changed: " << to_char(info->changed()) << ")"; + INFO(prime) << fmt(graph()); +} + +} // namespace exo diff --git a/compiler/exo/src/ProgressReporter.h b/compiler/exo/src/ProgressReporter.h new file mode 100644 index 000000000..b0f420df9 --- /dev/null +++ b/compiler/exo/src/ProgressReporter.h @@ -0,0 +1,53 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __PROGRESSREPORTER_H__ +#define __PROGRESSREPORTER_H__ + +#include <logo/Phase.h> + +#include <loco.h> + +namespace exo +{ + +class ProgressReporter : public logo::PhaseEventListener +{ +public: + ProgressReporter(loco::Graph *graph, logo::PhaseStrategy strategy) + : _graph{graph}, _strategy{strategy} + { + // DO NOTHING + } + +public: + void notify(const logo::PhaseEventInfo<logo::PhaseEvent::PhaseBegin> *) override; + void notify(const logo::PhaseEventInfo<logo::PhaseEvent::PhaseEnd> *) override; + void notify(const logo::PhaseEventInfo<logo::PhaseEvent::PassBegin> *) override; + void notify(const logo::PhaseEventInfo<logo::PhaseEvent::PassEnd> *) override; + +public: + loco::Graph *graph(void) const { return _graph; } + logo::PhaseStrategy strategy(void) const { return _strategy; } + +private: + loco::Graph *_graph; + logo::PhaseStrategy _strategy; +}; + +} // namespace exo + +#endif // __PROGRESSREPORTER_H__ diff --git a/compiler/exo/src/ShapeInference.cpp b/compiler/exo/src/ShapeInference.cpp new file mode 100644 index 000000000..bceb1495f --- /dev/null +++ b/compiler/exo/src/ShapeInference.cpp @@ -0,0 +1,44 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ShapeInference.h" +#include "Dialect/IR/TFLDialect.h" +#include "Dialect/Service/TFLShapeInferenceRule.h" + +#include <loco/IR/CanonicalNode.h> +#include <loco/IR/CanonicalDialect.h> +#include <loco/IR/CanonicalNodeVisitor.h> +#include <loco/Service/ShapeInference.h> +#include <loco/Service/CanonicalShapeInferenceRule.h> +#include <loco/Service/MultiDialectShapeInferenceRule.h> + +#include <locoex/COpCall.h> +#include <locoex/COpDialect.h> +#include <locoex/Service/COpShapeInferenceRule.h> + +namespace exo +{ + +ShapeDescription ShapeInference::get(loco::Node *node) +{ + // TODO Adjust indentation level + { + assert(loco::shape_known(node)); + return to_shape_description(loco::shape_get(node)); + } +} + +} // namespace exo diff --git a/compiler/exo/src/ShapeInference.h b/compiler/exo/src/ShapeInference.h new file mode 100644 index 000000000..ec141ccfc --- /dev/null +++ b/compiler/exo/src/ShapeInference.h @@ -0,0 +1,41 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __SHAPE_INFERENCE_H__ +#define __SHAPE_INFERENCE_H__ + +#include "ExporterUtils.h" + +#include <loco/IR/Nodes.h> + +namespace exo +{ + +/** + * @brief Get the shape of each node as a node annotation + * + * HOW TO USE + * + * ShapeInference::get(g->nodes()->at(..)); + */ +struct ShapeInference +{ + static ShapeDescription get(loco::Node *node); +}; + +} // namespace exo + +#endif // __SHAPE_INFERENCE_H__ diff --git a/compiler/exo/src/TFLite/TFLExporter.cpp b/compiler/exo/src/TFLite/TFLExporter.cpp new file mode 100644 index 000000000..cf002b3e1 --- /dev/null +++ b/compiler/exo/src/TFLite/TFLExporter.cpp @@ -0,0 +1,49 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "exo/TFLExporter.h" + +#include "TFLExporterImpl.h" + +#include <stdex/Memory.h> + +#include <oops/InternalExn.h> + +#include <fstream> + +namespace exo +{ + +TFLExporter::TFLExporter(loco::Graph *graph) : _impl(stdex::make_unique<Impl>(graph)) +{ + // NOTHING TO DO +} + +TFLExporter::~TFLExporter() = default; + +void TFLExporter::dumpToFile(const char *path) const +{ + const char *ptr = _impl->getBufferPointer(); + const size_t size = _impl->getBufferSize(); + + if (!ptr) + INTERNAL_EXN("Graph was not serialized by FlatBuffer for some reason"); + + std::ofstream file(path, std::ofstream::binary); + file.write(ptr, size); +} + +} // namespace exo diff --git a/compiler/exo/src/TFLite/TFLExporterImpl.cpp b/compiler/exo/src/TFLite/TFLExporterImpl.cpp new file mode 100644 index 000000000..07adbfb9d --- /dev/null +++ b/compiler/exo/src/TFLite/TFLExporterImpl.cpp @@ -0,0 +1,179 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "TFLExporterImpl.h" + +#include "Convert.h" +#include "ExoOptimize.h" + +#include "TFLTensorExporter.h" +#include "TFLOperationExporter.h" +#include "TFLExporterUtils.h" + +#include "Log.h" +#include "Knob.h" + +#include <oops/InternalExn.h> + +#include <cassert> +#include <unordered_map> +#include <string> +#include <stdexcept> + +namespace +{ + +using namespace exo; +using namespace exo::tflite_detail; + +void registerGraphInputTensors(loco::Graph *graph, SubGraphContext &ctx) +{ + for (uint32_t n = 0; n < graph->inputs()->size(); ++n) + { + auto node = loco::pull_node(graph, n); + assert(node != nullptr); + ctx._inputs.push_back(get_tensor_index(node)); + } +} + +void registerGraphOutputTensors(loco::Graph *graph, SubGraphContext &ctx) +{ + for (uint32_t n = 0; n < graph->outputs()->size(); ++n) + { + auto push = loco::push_node(graph, n); + assert(push != nullptr); + auto node = push->from(); + assert(node != nullptr); + ctx._outputs.push_back(get_tensor_index(node)); + } +} + +} // namespace + +namespace +{ +using namespace tflite; +using namespace flatbuffers; + +Offset<Vector<Offset<OperatorCode>>> +encodeOperatorCodes(FlatBufferBuilder &builder, std::unordered_map<OpCode, uint32_t> &opcodes, + std::unordered_map<OpCode, std::string> &custom_opcodes) +{ + std::vector<Offset<OperatorCode>> operator_codes_vec(opcodes.size()); + for (auto it : opcodes) + { + uint32_t idx = it.second; + if (it.first.opcode != BuiltinOperator_CUSTOM) + { + operator_codes_vec[idx] = CreateOperatorCode(builder, it.first.opcode); + } + else // custom op + { + auto opCode = it.first; + auto custom_code = custom_opcodes.find(opCode); + if (custom_code == custom_opcodes.end()) + INTERNAL_EXN("Cannot find code for custom op"); + + operator_codes_vec[idx] = + CreateOperatorCode(builder, it.first.opcode, builder.CreateString(custom_code->second)); + } + } + return builder.CreateVector(operator_codes_vec); +} + +} // namespace + +namespace exo +{ + +using namespace exo::tflite_detail; +using namespace tflite; +using namespace flatbuffers; + +TFLExporter::Impl::Impl(loco::Graph *graph) { exportGraph(graph); } + +::flatbuffers::Offset<::tflite::SubGraph> TFLExporter::Impl::exportSubgraph(SerializedModelData &gd) +{ + auto tensors = _builder.CreateVector(gd._tensors); + auto inputs = _builder.CreateVector(gd._inputs); + auto outputs = _builder.CreateVector(gd._outputs); + auto operators = _builder.CreateVector(gd._operators); + auto subgraph = CreateSubGraph(_builder, tensors, inputs, outputs, operators); + return subgraph; +} + +void TFLExporter::Impl::exportGraph(loco::Graph *graph) +{ + LOGGER(l); + + // IR-level conversion and optimization + { + convert_to_TFLNodes(graph); + set(Dialect::TFLITE); + optimize(graph); + } + + _builder.Clear(); + + SerializedModelData gd; + + // This version is taken from comment in fbs + constexpr uint32_t version = 3; + + registerGraphIOName(graph, gd); + + // parse graph into SerializedModelData structure + exportOpDefinedTensors(graph, _builder, gd); + + // NOTE Invoke these register functions only after each node is annotated with its tensor_index + registerGraphInputTensors(graph, gd); + registerGraphOutputTensors(graph, gd); + + exportNodes(graph, _builder, gd); + + // encode operator codes + auto operator_codes = + encodeOperatorCodes(_builder, gd._operator_codes, gd._custom_operator_codes); + + // Subgraphs + Offset<SubGraph> subgraph = exportSubgraph(gd); + auto subgraphs = _builder.CreateVector(std::vector<Offset<SubGraph>>{subgraph}); + + // Description + 
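  // The steps below assemble the top-level tflite::Model table,
+  //   Model(version, operator_codes, subgraphs, description, buffers, metadata_buffer),
+  // so everything gathered into SerializedModelData above is referenced from here
+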
std::string description_str = "nnpackage"; + auto description = _builder.CreateString(description_str); + + // create array of buffers + auto buffers = _builder.CreateVector(gd._buffers); + + // empty metadata + std::vector<int> metadata_buffer_vec; + auto metadata_buffer = _builder.CreateVector(metadata_buffer_vec); + + // Model + auto model_offset = CreateModel(_builder, version, operator_codes, subgraphs, description, + buffers, metadata_buffer); + FinishModelBuffer(_builder, model_offset); +} + +const char *TFLExporter::Impl::getBufferPointer() const +{ + return reinterpret_cast<const char *>(_builder.GetBufferPointer()); +} + +size_t TFLExporter::Impl::getBufferSize() const { return _builder.GetSize(); } + +} // namespace exo diff --git a/compiler/exo/src/TFLite/TFLExporterImpl.h b/compiler/exo/src/TFLite/TFLExporterImpl.h new file mode 100644 index 000000000..01c549a43 --- /dev/null +++ b/compiler/exo/src/TFLite/TFLExporterImpl.h @@ -0,0 +1,78 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __TFL_EXPORTER_IMPL_H__ +#define __TFL_EXPORTER_IMPL_H__ + +#include "exo/TFLExporter.h" +#include "schema_generated.h" + +#include <loco.h> + +namespace exo +{ + +namespace tflite_detail +{ + +struct SerializedModelData; + +} // namespace tflite_detail + +using namespace tflite_detail; + +/** + * internal implementation of interface exporter class + */ +class TFLExporter::Impl +{ +public: + Impl() = delete; + ~Impl() = default; + + explicit Impl(loco::Graph *graph); + + /** + * @return pointer to buffer with serialized graph + */ + const char *getBufferPointer() const; + + /** + * @return size of buffer with serialized graph + */ + size_t getBufferSize() const; + +private: + /** + * @brief create Subgraph using data stored in SerializedModelData + * @param gd information about serializer parts of model + * @return offset in buffer corresponding to serialized subgraph + */ + flatbuffers::Offset<tflite::SubGraph> exportSubgraph(SerializedModelData &gd); + + /** + * @brief root function that writes graph into internal buffer + * @param graph + */ + void exportGraph(loco::Graph *graph); + +private: + flatbuffers::FlatBufferBuilder _builder; +}; + +} // namespace exo + +#endif // __TFL_EXPORTER_IMPL_H__ diff --git a/compiler/exo/src/TFLite/TFLExporterImpl.test.cpp b/compiler/exo/src/TFLite/TFLExporterImpl.test.cpp new file mode 100644 index 000000000..7d74223c5 --- /dev/null +++ b/compiler/exo/src/TFLite/TFLExporterImpl.test.cpp @@ -0,0 +1,413 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "TFLExporterImpl.h" + +#include "schema_generated.h" + +#include "TestGraph.h" +#include "GraphBlock.h" +#include "Knob.h" + +#include <loco/IR/PermutingCodec.h> +#include <stdex/Memory.h> + +#include <gtest/gtest.h> + +namespace +{ + +class TFLExporterImplTests : public ::testing::Test +{ +public: + TFLExporterImplTests() { _graph = loco::make_graph(); } + +public: + virtual ~TFLExporterImplTests() = default; + +protected: + loco::Graph *graph(void) { return _graph.get(); } + + template <typename NodeT> NodeT *make_node(void); + +private: + std::unique_ptr<loco::Graph> _graph; +}; + +template <typename NodeT> NodeT *TFLExporterImplTests::make_node(void) +{ + return graph()->nodes()->create<NodeT>(); +} + +template <> loco::FeatureEncode *TFLExporterImplTests::make_node(void) +{ + loco::FeatureEncode *encode_layer = graph()->nodes()->create<loco::FeatureEncode>(); + + auto encoder = stdex::make_unique<loco::PermutingEncoder<loco::Domain::Feature>>(); + (*encoder->perm())[loco::FeatureAxis::Count] = 0; + (*encoder->perm())[loco::FeatureAxis::Depth] = 1; + (*encoder->perm())[loco::FeatureAxis::Height] = 2; + (*encoder->perm())[loco::FeatureAxis::Width] = 3; + encode_layer->encoder(std::move(encoder)); + + return encode_layer; +} + +template <> loco::FeatureDecode *TFLExporterImplTests::make_node(void) +{ + loco::FeatureDecode *decode_layer = graph()->nodes()->create<loco::FeatureDecode>(); + + auto decoder = stdex::make_unique<loco::PermutingDecoder<loco::Domain::Feature>>(); + (*decoder->perm())[loco::FeatureAxis::Count] = 0; + (*decoder->perm())[loco::FeatureAxis::Depth] = 1; + (*decoder->perm())[loco::FeatureAxis::Height] = 2; + (*decoder->perm())[loco::FeatureAxis::Width] = 3; + decode_layer->decoder(std::move(decoder)); + + return decode_layer; +} + +} // namespace + +// TODO TFLAdd + +// TODO TFLAveragePool2D + +TEST_F(TFLExporterImplTests, Concatenate) +{ + auto pull1 = make_node<loco::Pull>(); + { + pull1->dtype(loco::DataType::FLOAT32); + pull1->shape({1, 2, 3, 4}); + } + auto pull2 = make_node<loco::Pull>(); + { + pull2->dtype(loco::DataType::FLOAT32); + pull2->shape({1, 2, 3, 4}); + } + auto concat = make_node<loco::TensorConcat>(); + { + concat->lhs(pull1); + concat->rhs(pull2); + } + auto push = make_node<loco::Push>(); + { + push->from(concat); + } + + auto input1 = graph()->inputs()->create(); + { + input1->name("input1"); + loco::link(input1, pull1); + } + auto input2 = graph()->inputs()->create(); + { + input2->name("input2"); + loco::link(input2, pull2); + } + auto output = graph()->outputs()->create(); + { + output->name("output"); + loco::link(output, push); + } + + exo::TFLExporter::Impl exporter{graph()}; + + // TODO Add more checks + SUCCEED(); +} + +// TODO TFLConv2D + +// TODO TFLDepthwiseConv2D + +// TODO TFLDiv + +// TODO TFLMaxPool2D + +// TODO TFLMul + +TEST_F(TFLExporterImplTests, Relu6) +{ + auto pull = make_node<loco::Pull>(); + { + pull->dtype(loco::DataType::FLOAT32); + pull->shape({1, 8, 8, 3}); + } + auto relu6 = make_node<loco::ReLU6>(); + { + relu6->input(pull); + } + auto push = make_node<loco::Push>(); + { + 
push->from(relu6);
+  }
+
+  auto input = graph()->inputs()->create();
+  {
+    input->name("input");
+    loco::link(input, pull);
+  }
+  auto output = graph()->outputs()->create();
+  {
+    output->name("output");
+    loco::link(output, push);
+  }
+
+  exo::TFLExporter::Impl exporter{graph()};
+
+  // TODO Add more checks
+  SUCCEED();
+}
+
+// TODO TFLRelu6
+
+// TODO TFLReshape
+
+// TODO TFLSoftmax
+
+// TODO TFLSqrt
+
+// TODO TFLSub
+
+// TODO TFLTanh
+
+TEST(TFLExporterImplTest, Transpose_simple)
+{
+  exo::test::ExampleGraph<exo::test::ExampleGraphType::Transpose> g;
+
+  // pull attribute
+  {
+    g.pull->dtype(loco::DataType::FLOAT32);
+    g.pull->shape({1, 2, 2, 3});
+  }
+
+  // transpose attribute
+  {
+    g.transpose->perm()->size(4);
+    g.transpose->perm()->axis(0) = 1;
+    g.transpose->perm()->axis(1) = 2;
+    g.transpose->perm()->axis(2) = 3;
+    g.transpose->perm()->axis(3) = 0;
+  }
+
+  exo::TFLExporter::Impl exporter{g.graph()};
+  {
+    auto model = tflite::GetModel(exporter.getBufferPointer());
+    auto operators = model->subgraphs()->Get(0)->operators();
+
+    assert(operators->Length() == 1);
+
+    int n = 0; // op index of Transpose in tflite file
+
+    auto opcode_index = operators->Get(n)->opcode_index();
+
+    ASSERT_EQ(model->operator_codes()->Get(opcode_index)->builtin_code(),
+              tflite::BuiltinOperator_TRANSPOSE);
+
+    auto perm = operators->Get(n)->inputs()->Get(1);
+
+    auto perm_tensor = model->subgraphs()->Get(0)->tensors()->Get(perm);
+    ASSERT_EQ(perm_tensor->type(), tflite::TensorType::TensorType_INT32);
+    ASSERT_EQ(perm_tensor->shape()->size(), 1);
+    ASSERT_EQ(perm_tensor->shape()->Get(0), 4);
+
+    auto bufs = (model->buffers());
+    auto *perm_buf =
+        reinterpret_cast<const int32_t *>(bufs->Get(perm_tensor->buffer())->data()->data());
+
+    ASSERT_EQ(perm_buf[0], 1);
+    ASSERT_EQ(perm_buf[1], 2);
+    ASSERT_EQ(perm_buf[2], 3);
+    ASSERT_EQ(perm_buf[3], 0);
+  }
+}
+
+/*
+  test case:
+          Pull ----- FilterEncode ---- FilterDecode --- Push
+            0 -----------> H            O -----------> 0
+            1 -----------> W            H -----------> 1
+            2 -----------> I(depth)     W -----------> 2
+            3 -----------> O(count)     I -----------> 3
+
+        axis 0 ----------> H --------------> H -----------> 1
+        axis 1 ----------> W --------------> W -----------> 2
+        axis 2 ----------> I --------------> I -----------> 3
+        axis 3 ----------> O --------------> O -----------> 0
+
+  So, the perm vector of Transpose = [3, 0, 1, 2].
+  Please refer to loco::TensorTranspose about the definition of the perm vector.
+*/
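+// A quick way to read the perm vector: output.dim(i) == input.dim(perm[i]).
+// With the pull shape {1, 2, 3, 4} and perm = [3, 0, 1, 2] checked below,
+// the transposed output would therefore have shape (4, 1, 2, 3).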
+TEST(TFLExporterImplTest, Transpose_from_FilterEncode_FilterDecode)
+{
+  exo::test::ExampleGraph<exo::test::ExampleGraphType::FilterEncode_FilterDecode> g;
+
+  // pull attribute
+  {
+    g.pull->dtype(loco::DataType::FLOAT32);
+    g.pull->shape({1, 2, 3, 4}); // whatever value of rank 4
+  }
+
+  exo::TFLExporter::Impl exporter{g.graph()};
+  {
+    auto model = tflite::GetModel(exporter.getBufferPointer());
+    auto operators = model->subgraphs()->Get(0)->operators();
+
+    assert(operators->Length() == 1);
+
+    int n = 0; // op index of Transpose in tflite file
+
+    auto opcode_index = operators->Get(n)->opcode_index();
+
+    ASSERT_EQ(model->operator_codes()->Get(opcode_index)->builtin_code(),
+              tflite::BuiltinOperator_TRANSPOSE);
+
+    auto perm = operators->Get(n)->inputs()->Get(1);
+
+    auto perm_tensor = model->subgraphs()->Get(0)->tensors()->Get(perm);
+    ASSERT_EQ(perm_tensor->type(), tflite::TensorType::TensorType_INT32);
+    ASSERT_EQ(perm_tensor->shape()->size(), 1);
+    ASSERT_EQ(perm_tensor->shape()->Get(0), 4);
+
+    auto bufs = (model->buffers());
+    auto *perm_buf =
+        reinterpret_cast<const int32_t *>(bufs->Get(perm_tensor->buffer())->data()->data());
+    ASSERT_EQ(perm_buf[0], 3);
+    ASSERT_EQ(perm_buf[1], 0);
+    ASSERT_EQ(perm_buf[2], 1);
+    ASSERT_EQ(perm_buf[3], 2);
+  }
+}
+
+/**
+ * What happens when there is a mismatch between generation and execution order!?
+ */
+TEST_F(TFLExporterImplTests, Regression_0000)
+{
+  // This test was written without considering fusion.
+  // For this reason, this check is needed.
+  // TODO Rewrite this test
+  if (exo::get<exo::Knob::UseFuseReluPass>())
+    return;
+
+  // Execution Order: MaxPool2D -> ReLU
+  // Generation Order: ReLU -> MaxPool2D
+  auto pull = make_node<loco::Pull>();
+  {
+    pull->dtype(loco::DataType::FLOAT32);
+    pull->shape({1, 8, 8, 3});
+  }
+  auto relu = make_node<loco::ReLU>();
+  auto encode = exo::make_feature_encode<exo::FeatureLayout::NHWC>(pull);
+  auto maxpool = make_node<loco::MaxPool2D>();
+  auto decode = exo::make_feature_decode<exo::FeatureLayout::NHWC>(relu);
+  auto push = make_node<loco::Push>();
+
+  ASSERT_EQ(maxpool->window()->vertical(), 1);
+  ASSERT_EQ(maxpool->window()->horizontal(), 1);
+
+  maxpool->ifm(encode);
+  relu->input(maxpool);
+  push->from(decode);
+
+  auto input = graph()->inputs()->create();
+  {
+    input->name("input");
+    loco::link(input, pull);
+  }
+  auto output = graph()->outputs()->create();
+  {
+    output->name("output");
+    loco::link(output, push);
+  }
+
+  exo::TFLExporter::Impl exporter{graph()};
+  {
+    int64_t maxpool_execution_index = -1;
+    int64_t relu_execution_index = -1;
+
+    auto model = tflite::GetModel(exporter.getBufferPointer());
+    auto operators = model->subgraphs()->Get(0)->operators();
+
+    for (uint32_t n = 0; n < operators->Length(); ++n)
+    {
+      auto opcode_index = operators->Get(n)->opcode_index();
+
+      switch (model->operator_codes()->Get(opcode_index)->builtin_code())
+      {
+        case tflite::BuiltinOperator_RELU:
+          ASSERT_EQ(relu_execution_index, -1);
+          relu_execution_index = static_cast<int64_t>(n);
+          break;
+        case tflite::BuiltinOperator_MAX_POOL_2D:
+          ASSERT_EQ(maxpool_execution_index, -1);
+          maxpool_execution_index = static_cast<int64_t>(n);
+          break;
+        default:
+          break;
+      }
+    }
+
+    ASSERT_NE(maxpool_execution_index, -1);
+    ASSERT_NE(relu_execution_index, -1);
+    // maxpool SHOULD precede ReLU
+    ASSERT_LT(maxpool_execution_index, relu_execution_index);
+  }
+}
+
+/**
+ * @brief Test exporter buffer generation
+ */
+TEST_F(TFLExporterImplTests, Regression_0001)
+{
+  auto cgen = make_node<loco::ConstGen>();
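+  // shape it as a rank-1 tensor holding two FLOAT32 values; the exporter
+  // must emit a real data buffer for this constant (verified below)
+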
cgen->rank(1); + cgen->dim(0) = 2; + cgen->dtype(loco::DataType::FLOAT32); + cgen->size<loco::DataType::FLOAT32>(2); + cgen->at<loco::DataType::FLOAT32>(0) = 3.3f; + cgen->at<loco::DataType::FLOAT32>(1) = 1.1f; + + auto push = make_node<loco::Push>(); + push->from(cgen); + + auto output = graph()->outputs()->create(); + { + output->name("output"); + loco::link(output, push); + } + + exo::TFLExporter::Impl exporter{graph()}; + { + auto model = tflite::GetModel(exporter.getBufferPointer()); + auto buffers = model->buffers(); + + // 0'th empty buffer + ConstGen data + ConstGen node output + ASSERT_EQ(buffers->Length(), 3); + + // 0'th should be empty buffer + auto buffer_0 = (*buffers)[0]; + auto array_0 = buffer_0->data(); + ASSERT_EQ(array_0, nullptr); + + // 1'st should be ConstGen data which is two float + auto buffer_1 = (*buffers)[1]; + auto array_1 = buffer_1->data(); + size_t size_1 = array_1->size(); + ASSERT_EQ(size_1, 2 * sizeof(float)); + } +} diff --git a/compiler/exo/src/TFLite/TFLExporterUtils.cpp b/compiler/exo/src/TFLite/TFLExporterUtils.cpp new file mode 100644 index 000000000..d35afc9aa --- /dev/null +++ b/compiler/exo/src/TFLite/TFLExporterUtils.cpp @@ -0,0 +1,160 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "TFLExporterUtils.h" + +#include <oops/InternalExn.h> + +namespace exo +{ + +tflite::ActivationFunctionType to_tflite_actfunc(locoex::FusedActFunc func) +{ + switch (func) + { + case locoex::FusedActFunc::NONE: + return tflite::ActivationFunctionType_NONE; + case locoex::FusedActFunc::RELU: + return tflite::ActivationFunctionType_RELU; + case locoex::FusedActFunc::RELU6: + return tflite::ActivationFunctionType_RELU6; + default: + INTERNAL_EXN_V("Unsupported locoex FusedActFunc Type", oops::to_uint32(func)); + } +} + +} // namespace exo + +namespace exo +{ +namespace tflite_detail +{ + +uint32_t SerializedModelData::registerBuiltinOpcode(tflite::BuiltinOperator builtin_code) +{ + auto it = _operator_codes.find(OpCode{builtin_code}); + if (it != _operator_codes.end()) + { + return it->second; + } + auto idx = static_cast<uint32_t>(_operator_codes.size()); + _operator_codes.emplace(OpCode{builtin_code}, idx); + return idx; +} + +uint32_t SerializedModelData::registerCustomOpcode(const std::string &custom_op) +{ + tflite::BuiltinOperator custom_code = tflite::BuiltinOperator_CUSTOM; + auto idx = registerBuiltinOpcode(custom_code); + _custom_operator_codes.emplace(OpCode{custom_code}, custom_op); + return idx; +} + +tflite::Padding getOpPadding(const loco::Padding2D *pad, const loco::Stride<2> *stride, + const ShapeDescription &ifm, const ShapeDescription &ofm) +{ + // VALID padding + if (pad->top() == 0 && pad->bottom() == 0 && pad->left() == 0 && pad->right() == 0) + return tflite::Padding_VALID; + + // SAME padding + // + // For same padding, by definition, following equation should hold: + // O = floor((I - 1) / S) + 1 + // where input size I, output size O, stride S + // + // NOTE input and output 'feature' map are shape of NHWC + bool same_padding_criterion_1 = + (static_cast<uint32_t>(ofm._dims[1]) == (ifm._dims[1] - 1) / stride->vertical() + 1) && + (static_cast<uint32_t>(ofm._dims[2]) == (ifm._dims[2] - 1) / stride->horizontal() + 1); + + // For same padding, rear padding is same or bigger than front padding by at most 1 + bool same_padding_criterion_2 = + (pad->top() <= pad->bottom()) && (pad->bottom() <= pad->top() + 1) && + (pad->left() <= pad->right()) && (pad->right() <= pad->left() + 1); + + if (same_padding_criterion_1 && same_padding_criterion_2) + return tflite::Padding_SAME; + + INTERNAL_EXN("NYI for custom PAD"); +} + +tflite::Padding getOpPadding(const locoex::Padding pad) +{ + if (pad == locoex::Padding::VALID) + return tflite::Padding_VALID; + if (pad == locoex::Padding::SAME) + return tflite::Padding_SAME; + + INTERNAL_EXN_V("Unknown padding", oops::to_uint32(pad)); +} + +void registerGraphIOName(loco::Graph *graph, SerializedModelData &gd) +{ + for (uint32_t in = 0; in < graph->inputs()->size(); ++in) + { + auto pull = loco::pull_node(graph, in); + auto name = graph->inputs()->at(in)->name(); + + gd._pull_to_name[pull] = name; + } + for (uint32_t out = 0; out < graph->outputs()->size(); ++out) + { + auto push = loco::push_node(graph, out); + auto name = graph->outputs()->at(out)->name(); + + gd._push_to_name[push] = name; + } +} + +#include <stdex/Memory.h> + +#include <cassert> + +namespace +{ + +class TFLTensorIndexAnnotation final : public loco::NodeAnnotation +{ +public: + TFLTensorIndexAnnotation(const TFLTensorIndex &index) : _index{index} + { + // DO NOTHING + } + +public: + const TFLTensorIndex &index(void) const { return _index; } + +private: + TFLTensorIndex _index; +}; + +} // namespace + +void set_tensor_index(loco::Node *node, const 
TFLTensorIndex &tensor_id)
+{
+  assert(node->annot<TFLTensorIndexAnnotation>() == nullptr);
+  node->annot(stdex::make_unique<TFLTensorIndexAnnotation>(tensor_id));
+}
+
+TFLTensorIndex get_tensor_index(loco::Node *node)
+{
+  assert(node->annot<TFLTensorIndexAnnotation>() != nullptr);
+  return node->annot<TFLTensorIndexAnnotation>()->index();
+}
+
+} // namespace tflite_detail
+} // namespace exo
diff --git a/compiler/exo/src/TFLite/TFLExporterUtils.h b/compiler/exo/src/TFLite/TFLExporterUtils.h
new file mode 100644
index 000000000..dbd7a52fb
--- /dev/null
+++ b/compiler/exo/src/TFLite/TFLExporterUtils.h
@@ -0,0 +1,118 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __TFL_EXPORTER_UTILS_H__
+#define __TFL_EXPORTER_UTILS_H__
+
+#include "ExporterUtils.h"
+
+#include "schema_generated.h"
+
+#include "Dialect/IR/TFLNodes.h"
+
+#include <loco.h>
+
+#include <unordered_map>
+
+namespace exo
+{
+namespace tflite_detail
+{
+
+struct OpCode
+{
+  tflite::BuiltinOperator opcode;
+
+  bool operator==(const OpCode &rhs) const { return opcode == rhs.opcode; }
+};
+
+} // namespace tflite_detail
+} // namespace exo
+
+namespace exo
+{
+
+tflite::ActivationFunctionType to_tflite_actfunc(locoex::FusedActFunc func);
+
+} // namespace exo
+
+namespace std
+{
+
+template <> struct hash<exo::tflite_detail::OpCode>
+{
+  size_t operator()(const exo::tflite_detail::OpCode &x) const { return hash<int>()(x.opcode); }
+};
+
+} // namespace std
+
+namespace exo
+{
+namespace tflite_detail
+{
+
+/**
+ * @brief Record the information of a T/F Lite SubGraph and its mapping to loco
+ */
+struct SubGraphContext
+{
+  /// @brief SubGraph input tensor id
+  std::vector<int32_t> _inputs;
+  /// @brief SubGraph output tensor id
+  std::vector<int32_t> _outputs;
+};
+
+// Prerequisites for tflite::Model object creation
+struct SerializedModelData final : public SubGraphContext
+{
+  SerializedModelData() = default;
+  SerializedModelData(const SerializedModelData &) = delete;
+
+  std::unordered_map<OpCode, uint32_t> _operator_codes;
+  std::unordered_map<OpCode, std::string> _custom_operator_codes;
+  std::vector<flatbuffers::Offset<tflite::Operator>> _operators;
+  std::vector<flatbuffers::Offset<tflite::Tensor>> _tensors;
+  std::vector<flatbuffers::Offset<tflite::Buffer>> _buffers;
+
+  // Graph input and output names
+  std::unordered_map<loco::Pull *, std::string> _pull_to_name;
+  std::unordered_map<loco::Push *, std::string> _push_to_name;
+
+  /**
+   * @brief if opcode is not registered in the table of opcodes, add it
+   * @param builtin_code
+   * @return idx of opcode in the table of opcodes (see schema)
+   */
+  uint32_t registerBuiltinOpcode(tflite::BuiltinOperator builtin_code);
+  uint32_t registerCustomOpcode(const std::string &custom_op);
+};
+
+tflite::Padding getOpPadding(const loco::Padding2D *pad, const loco::Stride<2> *stride,
+                             const ShapeDescription &ifm, const ShapeDescription &ofm);
+tflite::Padding getOpPadding(const
locoex::Padding pad); + +/// @brief Register graph input and output names to SerializedModelData +void registerGraphIOName(loco::Graph *graph, SerializedModelData &gd); + +using TFLTensorIndex = int32_t; + +void set_tensor_index(loco::Node *node, const TFLTensorIndex &tensor_id); +TFLTensorIndex get_tensor_index(loco::Node *node); + +} // namespace tflite_detail +} // namespace exo + +#endif // __TFL_EXPORTER_UTILS_H__ diff --git a/compiler/exo/src/TFLite/TFLExporterUtils.test.cpp b/compiler/exo/src/TFLite/TFLExporterUtils.test.cpp new file mode 100644 index 000000000..d19f87d25 --- /dev/null +++ b/compiler/exo/src/TFLite/TFLExporterUtils.test.cpp @@ -0,0 +1,108 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "TFLExporterUtils.h" + +#include <gtest/gtest.h> + +using namespace exo::tflite_detail; + +TEST(ExporterUtilsTests, getOpPadding) +{ + loco::Padding2D pad; + loco::Stride<2> stride; + exo::ShapeDescription ifm; + exo::ShapeDescription ofm; + + ifm._dims.resize(4); + ofm._dims.resize(4); + + // VALID padding + { + pad.top(0); + pad.bottom(0); + pad.left(0); + pad.right(0); + + stride.vertical(2); + stride.horizontal(2); + + ifm._dims[1] = 5; + ifm._dims[2] = 5; + + ofm._dims[1] = 2; + ofm._dims[2] = 2; + + ASSERT_EQ(getOpPadding(&pad, &stride, ifm, ofm), tflite::Padding_VALID); + } + + // SAME padding + { + pad.top(1); + pad.bottom(1); + pad.left(1); + pad.right(1); + + stride.vertical(2); + stride.horizontal(2); + + ifm._dims[1] = 5; + ifm._dims[2] = 5; + + ofm._dims[1] = 3; + ofm._dims[2] = 3; + + ASSERT_EQ(getOpPadding(&pad, &stride, ifm, ofm), tflite::Padding_SAME); + } + + // Custom padding 1 - Not supported by tflite + { + pad.top(2); + pad.bottom(0); + pad.left(1); + pad.right(1); + + stride.vertical(2); + stride.horizontal(2); + + ifm._dims[1] = 5; + ifm._dims[2] = 5; + + ofm._dims[1] = 3; + ofm._dims[2] = 3; + + ASSERT_ANY_THROW(getOpPadding(&pad, &stride, ifm, ofm)); + } + + // Custom padding 2 - Not supported by tflite + { + pad.top(2); + pad.bottom(2); + pad.left(2); + pad.right(2); + + stride.vertical(2); + stride.horizontal(2); + + ifm._dims[1] = 5; + ifm._dims[2] = 5; + + ofm._dims[1] = 4; + ofm._dims[2] = 4; + + ASSERT_ANY_THROW(getOpPadding(&pad, &stride, ifm, ofm)); + } +} diff --git a/compiler/exo/src/TFLite/TFLOperationExporter.cpp b/compiler/exo/src/TFLite/TFLOperationExporter.cpp new file mode 100644 index 000000000..79b5b6287 --- /dev/null +++ b/compiler/exo/src/TFLite/TFLOperationExporter.cpp @@ -0,0 +1,1199 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "TFLOperationExporter.h" +#include "TFLExporterUtils.h" +#include "ShapeInference.h" + +#include "Dialect/IR/TFLNode.h" +#include "Dialect/IR/TFLNodes.h" +#include "Dialect/IR/TFLNodeVisitor.h" + +#include "Check.h" + +#include <loco/IR/CanonicalNode.h> +#include <loco/IR/CanonicalNodeVisitor.h> +#include <loco/Service/ShapeInference.h> +#include <locoex/COpCall.h> + +#include <oops/InternalExn.h> + +#include <flatbuffers/flexbuffers.h> + +using namespace flatbuffers; +using namespace tflite; + +namespace +{ + +using namespace exo; +using namespace exo::tflite_detail; + +class OperationExporter final : public locoex::TFLNodeMutableVisitor<void>, + public loco::CanonicalNodeMutableVisitor<void> +{ +public: + OperationExporter(FlatBufferBuilder &fbb, SerializedModelData &ctx) : builder{fbb}, gd{ctx} + { + // DO NOTHING + } + +public: + // FOR TFLNodes + void visit(locoex::TFLAdd *) final; + void visit(locoex::TFLAveragePool2D *) final; + void visit(locoex::TFLConcatenation *) final; + void visit(locoex::TFLConst *) final{/* skip, everything is done in exportOpDefinedTensors */}; + void visit(locoex::TFLConv2D *) final; + void visit(locoex::TFLDepthwiseConv2D *) final; + void visit(locoex::TFLDiv *) final; + void visit(locoex::TFLFullyConnected *) final; + void visit(locoex::TFLMaximum *) final; + void visit(locoex::TFLMaxPool2D *) final; + void visit(locoex::TFLMean *) final; + void visit(locoex::TFLMul *) final; + void visit(locoex::TFLRelu *) final; + void visit(locoex::TFLRelu6 *) final; + // TODO TFLReshape + void visit(locoex::TFLRsqrt *) final; + // TODO TFLSoftmax + void visit(locoex::TFLSqrt *) final; + void visit(locoex::TFLSquaredDifference *) final; + void visit(locoex::TFLSub *) final; + // TODO TFLTanh + void visit(locoex::TFLTranspose *) final; + void visit(locoex::TFLTransposeConv *) final; + + // FOR canonical nodes. 
These will be removed later + void visit(loco::ReLU *) final; + void visit(loco::ReLU6 *) final; + void visit(loco::Tanh *) final; + void visit(loco::Push *) final { /* DO NOTHING */} + void visit(loco::Pull *) final { /* DO NOTHING */} + void visit(loco::FeatureEncode *) final; + void visit(loco::FeatureDecode *) final; + void visit(loco::FilterEncode *) final; + void visit(loco::DepthwiseFilterEncode *) final; + void visit(loco::ConstGen *) final { /* skip, everything is done in exportOpDefinedTensors */} + void visit(loco::MaxPool2D *) final; + void visit(loco::AvgPool2D *) final; + void visit(loco::Conv2D *) final; + void visit(loco::TransposedConv2D *) final; + void visit(loco::DepthwiseConv2D *) final; + void visit(loco::TensorConcat *) final; + void visit(loco::TensorReduce *) final; + void visit(loco::TensorSoftmax *) final; + void visit(loco::BiasEncode *) final; + void visit(loco::TensorBiasAdd *) final; + void visit(loco::FeatureBiasAdd *) final; + void visit(loco::EltwiseAdd *) final; + void visit(loco::EltwiseMax *) final; + void visit(loco::EltwiseMul *) final; + void visit(loco::EltwiseSub *) final; + void visit(loco::EltwiseDiv *) final; + void visit(loco::EltwiseSqrt *) final; + void visit(loco::FixedReshape *) final; + void visit(loco::TensorBroadcast *) final; + void visit(loco::TensorConstantPad *) final; + + void visit(locoex::COpCall *); + +private: + /** + * @brief Exports TFLMaxPool2D or TFLAveragePool2D + * + * @note TFLPool2D should be one of TFLMaxPool2D or TFLAveragePool2D + */ + template <class TFLPool2D> + void export_pool_2d(TFLPool2D *node, tflite::BuiltinOperator builtin_op); + +private: + FlatBufferBuilder &builder; + SerializedModelData &gd; +}; + +void OperationExporter::visit(locoex::TFLAdd *node) +{ + uint32_t op_idx = gd.registerBuiltinOpcode(tflite::BuiltinOperator_ADD); + std::vector<int32_t> inputs_vec{get_tensor_index(node->x()), get_tensor_index(node->y())}; + std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))}; + auto inputs = builder.CreateVector(inputs_vec); + auto outputs = builder.CreateVector(outputs_vec); + auto options = CreateAddOptions(builder, to_tflite_actfunc(node->fusedActivationFunction())); + auto op_offset = CreateOperator(builder, op_idx, inputs, outputs, + tflite::BuiltinOptions_AddOptions, options.Union()); + gd._operators.push_back(op_offset); +} + +void OperationExporter::visit(locoex::TFLAveragePool2D *node) +{ + export_pool_2d<locoex::TFLAveragePool2D>(node, tflite::BuiltinOperator_AVERAGE_POOL_2D); +} + +void OperationExporter::visit(locoex::TFLConcatenation *node) +{ + uint32_t op_idx = gd.registerBuiltinOpcode(tflite::BuiltinOperator_CONCATENATION); + std::vector<int32_t> inputs_vec; + std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))}; + + for (uint32_t i = 0; i < node->numValues(); ++i) + inputs_vec.push_back(get_tensor_index(node->values(i))); + + auto inputs = builder.CreateVector(inputs_vec); + auto outputs = builder.CreateVector(outputs_vec); + auto options = CreateConcatenationOptions(builder, node->axis(), + to_tflite_actfunc(node->fusedActivationFunction())); + auto op_offset = CreateOperator(builder, op_idx, inputs, outputs, + tflite::BuiltinOptions_ConcatenationOptions, options.Union()); + gd._operators.push_back(op_offset); +} + +void OperationExporter::visit(locoex::TFLConv2D *node) +{ + uint32_t op_idx = gd.registerBuiltinOpcode(tflite::BuiltinOperator_CONV_2D); + + // Make input, output and options for operator + std::vector<int32_t> 
inputs_vec{get_tensor_index(node->input()), get_tensor_index(node->filter()), + get_tensor_index(node->bias())}; + std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))}; + auto inputs = builder.CreateVector(inputs_vec); + auto outputs = builder.CreateVector(outputs_vec); + tflite::Padding padding = getOpPadding(node->padding()); + auto options = CreateConv2DOptions(builder, padding, node->stride()->w(), node->stride()->h(), + to_tflite_actfunc(node->fusedActivationFunction())); + + // Make CONV_2D operator + auto op_offset = CreateOperator(builder, op_idx, inputs, outputs, + tflite::BuiltinOptions_Conv2DOptions, options.Union()); + gd._operators.push_back(op_offset); +} + +void OperationExporter::visit(locoex::TFLDepthwiseConv2D *node) +{ + uint32_t op_idx = gd.registerBuiltinOpcode(tflite::BuiltinOperator_DEPTHWISE_CONV_2D); + + // Make input, output and options for operator + std::vector<int32_t> inputs_vec{get_tensor_index(node->input()), get_tensor_index(node->filter()), + get_tensor_index(node->bias())}; + std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))}; + auto inputs = builder.CreateVector(inputs_vec); + auto outputs = builder.CreateVector(outputs_vec); + tflite::Padding padding = getOpPadding(node->padding()); + auto options = CreateDepthwiseConv2DOptions(builder, padding, node->stride()->w(), + node->stride()->h(), node->depthMultiplier(), + to_tflite_actfunc(node->fusedActivationFunction())); + + // Make DEPTHWISE_CONV_2D operator + auto op_offset = CreateOperator(builder, op_idx, inputs, outputs, + tflite::BuiltinOptions_DepthwiseConv2DOptions, options.Union()); + gd._operators.push_back(op_offset); +} + +void OperationExporter::visit(locoex::TFLDiv *node) +{ + uint32_t op_idx = gd.registerBuiltinOpcode(tflite::BuiltinOperator_DIV); + std::vector<int32_t> inputs_vec{get_tensor_index(node->x()), get_tensor_index(node->y())}; + std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))}; + auto inputs = builder.CreateVector(inputs_vec); + auto outputs = builder.CreateVector(outputs_vec); + auto options = CreateDivOptions(builder, to_tflite_actfunc(node->fusedActivationFunction())); + auto op_offset = CreateOperator(builder, op_idx, inputs, outputs, + tflite::BuiltinOptions_DivOptions, options.Union()); + gd._operators.push_back(op_offset); +} + +void OperationExporter::visit(locoex::TFLFullyConnected *node) +{ + uint32_t op_idx = gd.registerBuiltinOpcode(tflite::BuiltinOperator_FULLY_CONNECTED); + + // Make input, output and options for operator + std::vector<int32_t> inputs_vec{get_tensor_index(node->input()), + get_tensor_index(node->weights()), + get_tensor_index(node->bias())}; + std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))}; + auto inputs = builder.CreateVector(inputs_vec); + auto outputs = builder.CreateVector(outputs_vec); + auto options = + CreateFullyConnectedOptions(builder, to_tflite_actfunc(node->fusedActivationFunction())); + + // Make FULLY_CONNECTED operator + auto op_offset = CreateOperator(builder, op_idx, inputs, outputs, + tflite::BuiltinOptions_FullyConnectedOptions, options.Union()); + gd._operators.push_back(op_offset); +} + +void OperationExporter::visit(locoex::TFLMaximum *node) +{ + uint32_t op_idx = gd.registerBuiltinOpcode(tflite::BuiltinOperator_MAXIMUM); + std::vector<int32_t> inputs_vec{get_tensor_index(node->x()), get_tensor_index(node->y())}; + std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node 
*>(node))}; + auto inputs = builder.CreateVector(inputs_vec); + auto outputs = builder.CreateVector(outputs_vec); + auto options = CreateMaximumMinimumOptions(builder); + auto op_offset = CreateOperator(builder, op_idx, inputs, outputs, + tflite::BuiltinOptions_MaximumMinimumOptions, options.Union()); + gd._operators.push_back(op_offset); +} + +void OperationExporter::visit(locoex::TFLMaxPool2D *node) +{ + export_pool_2d<locoex::TFLMaxPool2D>(node, tflite::BuiltinOperator_MAX_POOL_2D); +} + +void OperationExporter::visit(locoex::TFLMean *node) +{ + uint32_t op_idx = gd.registerBuiltinOpcode(tflite::BuiltinOperator_MEAN); + std::vector<int32_t> inputs_vec{get_tensor_index(node->input()), + get_tensor_index(node->reduction_indices())}; + std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))}; + auto inputs = builder.CreateVector(inputs_vec); + auto outputs = builder.CreateVector(outputs_vec); + auto options = CreateReducerOptions(builder, node->keep_dims()); + auto op_offset = CreateOperator(builder, op_idx, inputs, outputs, + tflite::BuiltinOptions_ReducerOptions, options.Union()); + gd._operators.push_back(op_offset); +} + +void OperationExporter::visit(locoex::TFLMul *node) +{ + uint32_t op_idx = gd.registerBuiltinOpcode(tflite::BuiltinOperator_MUL); + std::vector<int32_t> inputs_vec{get_tensor_index(node->x()), get_tensor_index(node->y())}; + std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))}; + auto inputs = builder.CreateVector(inputs_vec); + auto outputs = builder.CreateVector(outputs_vec); + auto options = CreateMulOptions(builder, to_tflite_actfunc(node->fusedActivationFunction())); + auto op_offset = CreateOperator(builder, op_idx, inputs, outputs, + tflite::BuiltinOptions_MulOptions, options.Union()); + gd._operators.push_back(op_offset); +} + +void OperationExporter::visit(locoex::TFLRelu *node) +{ + uint32_t op_idx = gd.registerBuiltinOpcode(tflite::BuiltinOperator_RELU); + std::vector<int32_t> inputs_vec{get_tensor_index(node->features())}; + std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))}; + auto inputs = builder.CreateVector(inputs_vec); + auto outputs = builder.CreateVector(outputs_vec); + auto op_offset = CreateOperator(builder, op_idx, inputs, outputs); + gd._operators.push_back(op_offset); +} + +void OperationExporter::visit(locoex::TFLRelu6 *node) +{ + uint32_t op_idx = gd.registerBuiltinOpcode(tflite::BuiltinOperator_RELU6); + std::vector<int32_t> inputs_vec{get_tensor_index(node->features())}; + std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))}; + auto inputs = builder.CreateVector(inputs_vec); + auto outputs = builder.CreateVector(outputs_vec); + auto op_offset = CreateOperator(builder, op_idx, inputs, outputs); + gd._operators.push_back(op_offset); +} + +// TODO TFLReshape + +void OperationExporter::visit(locoex::TFLRsqrt *node) +{ + uint32_t op_idx = gd.registerBuiltinOpcode(tflite::BuiltinOperator_RSQRT); + std::vector<int32_t> inputs_vec{get_tensor_index(node->x())}; + std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))}; + auto inputs = builder.CreateVector(inputs_vec); + auto outputs = builder.CreateVector(outputs_vec); + auto op_offset = CreateOperator(builder, op_idx, inputs, outputs); + gd._operators.push_back(op_offset); +} + +// TODO TFLSoftmax + +void OperationExporter::visit(locoex::TFLSqrt *node) +{ + uint32_t op_idx = gd.registerBuiltinOpcode(tflite::BuiltinOperator_SQRT); + 
std::vector<int32_t> inputs_vec{get_tensor_index(node->x())}; + std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))}; + auto inputs = builder.CreateVector(inputs_vec); + auto outputs = builder.CreateVector(outputs_vec); + auto op_offset = CreateOperator(builder, op_idx, inputs, outputs); + gd._operators.push_back(op_offset); +} + +void OperationExporter::visit(locoex::TFLSquaredDifference *node) +{ + uint32_t op_idx = gd.registerBuiltinOpcode(tflite::BuiltinOperator_SQUARED_DIFFERENCE); + std::vector<int32_t> inputs_vec{get_tensor_index(node->x()), get_tensor_index(node->y())}; + std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))}; + auto inputs = builder.CreateVector(inputs_vec); + auto outputs = builder.CreateVector(outputs_vec); + auto options = CreateSquaredDifferenceOptions(builder); + auto op_offset = CreateOperator(builder, op_idx, inputs, outputs, + tflite::BuiltinOptions_SquaredDifferenceOptions, options.Union()); + gd._operators.push_back(op_offset); +} + +void OperationExporter::visit(locoex::TFLSub *node) +{ + uint32_t op_idx = gd.registerBuiltinOpcode(tflite::BuiltinOperator_SUB); + std::vector<int32_t> inputs_vec{get_tensor_index(node->x()), get_tensor_index(node->y())}; + std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))}; + auto inputs = builder.CreateVector(inputs_vec); + auto outputs = builder.CreateVector(outputs_vec); + auto options = CreateSubOptions(builder, to_tflite_actfunc(node->fusedActivationFunction())); + auto op_offset = CreateOperator(builder, op_idx, inputs, outputs, + tflite::BuiltinOptions_SubOptions, options.Union()); + gd._operators.push_back(op_offset); +} + +// TODO TFLTanh + +void OperationExporter::visit(locoex::TFLTranspose *node) +{ + uint32_t op_idx = gd.registerBuiltinOpcode(tflite::BuiltinOperator_TRANSPOSE); + std::vector<int32_t> inputs_vec{get_tensor_index(node->arg(0)), get_tensor_index(node->arg(1))}; + std::vector<int32_t> outputs_vec{get_tensor_index(node)}; + + auto inputs = builder.CreateVector(inputs_vec); + auto outputs = builder.CreateVector(outputs_vec); + auto options = CreateTransposeOptions(builder); + + auto op_offset = + CreateOperator(builder, op_idx, inputs, outputs, + tflite::BuiltinOptions::BuiltinOptions_TransposeOptions, options.Union()); + gd._operators.push_back(op_offset); +} + +void OperationExporter::visit(locoex::TFLTransposeConv *node) +{ + uint32_t op_idx = gd.registerBuiltinOpcode(tflite::BuiltinOperator_TRANSPOSE_CONV); + + // Make input, output and options for operator + std::vector<int32_t> inputs_vec{get_tensor_index(node->inputSizes()), + get_tensor_index(node->filter()), + get_tensor_index(node->outBackprop())}; + std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))}; + auto inputs = builder.CreateVector(inputs_vec); + auto outputs = builder.CreateVector(outputs_vec); + tflite::Padding padding = getOpPadding(node->padding()); + auto options = + CreateTransposeConvOptions(builder, padding, node->stride()->w(), node->stride()->h()); + + // Make TRANSPOSE_CONV operator + auto op_offset = CreateOperator(builder, op_idx, inputs, outputs, + tflite::BuiltinOptions_TransposeConvOptions, options.Union()); + gd._operators.push_back(op_offset); +} + +template <class TFLPool2D> +void OperationExporter::export_pool_2d(TFLPool2D *node, tflite::BuiltinOperator builtin_op) +{ + EXO_ASSERT(builtin_op == tflite::BuiltinOperator_MAX_POOL_2D || + builtin_op == tflite::BuiltinOperator_AVERAGE_POOL_2D, + 
"should be maxpool or avgpool"); + EXO_ASSERT(node->padding() != locoex::Padding::UNDEFINED, "Padding is not set"); + + uint32_t op_idx = gd.registerBuiltinOpcode(builtin_op); + std::vector<int32_t> inputs_vec{get_tensor_index(node->value())}; + std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))}; + auto inputs = builder.CreateVector(inputs_vec); + auto outputs = builder.CreateVector(outputs_vec); + + tflite::Padding padding = getOpPadding(node->padding()); + + auto options = CreatePool2DOptions(builder, padding, node->stride()->w(), node->stride()->h(), + node->filter()->w(), node->filter()->h(), + to_tflite_actfunc(node->fusedActivationFunction())); + auto op_offset = CreateOperator(builder, op_idx, inputs, outputs, + tflite::BuiltinOptions_Pool2DOptions, options.Union()); + gd._operators.push_back(op_offset); +} + +void OperationExporter::visit(loco::ReLU *node) +{ + uint32_t op_idx = gd.registerBuiltinOpcode(tflite::BuiltinOperator_RELU); + std::vector<int32_t> inputs_vec{get_tensor_index(node->input())}; + std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))}; + auto inputs = builder.CreateVector(inputs_vec); + auto outputs = builder.CreateVector(outputs_vec); + auto op_offset = CreateOperator(builder, op_idx, inputs, outputs); + gd._operators.push_back(op_offset); +} + +void OperationExporter::visit(loco::ReLU6 *node) +{ + uint32_t op_idx = gd.registerBuiltinOpcode(tflite::BuiltinOperator_RELU6); + std::vector<int32_t> inputs_vec{get_tensor_index(node->input())}; + std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))}; + auto inputs = builder.CreateVector(inputs_vec); + auto outputs = builder.CreateVector(outputs_vec); + auto op_offset = CreateOperator(builder, op_idx, inputs, outputs); + gd._operators.push_back(op_offset); +} + +void OperationExporter::visit(loco::Tanh *node) +{ + uint32_t op_idx = gd.registerBuiltinOpcode(tflite::BuiltinOperator_TANH); + std::vector<int32_t> inputs_vec{get_tensor_index(node->input())}; + std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))}; + auto inputs = builder.CreateVector(inputs_vec); + auto outputs = builder.CreateVector(outputs_vec); + auto op_offset = CreateOperator(builder, op_idx, inputs, outputs); + gd._operators.push_back(op_offset); +} + +void OperationExporter::visit(loco::MaxPool2D *node) +{ + uint32_t op_idx = gd.registerBuiltinOpcode(tflite::BuiltinOperator_MAX_POOL_2D); + std::vector<int32_t> inputs_vec{get_tensor_index(node->ifm())}; + std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))}; + auto inputs = builder.CreateVector(inputs_vec); + auto outputs = builder.CreateVector(outputs_vec); + tflite::Padding padding = getOpPadding( + node->pad(), node->stride(), ShapeInference::get(node->ifm()), ShapeInference::get(node)); + auto options = CreatePool2DOptions(builder, padding, node->stride()->horizontal(), + node->stride()->vertical(), node->window()->horizontal(), + node->window()->vertical()); + auto op_offset = CreateOperator(builder, op_idx, inputs, outputs, + tflite::BuiltinOptions_Pool2DOptions, options.Union()); + gd._operators.push_back(op_offset); +} + +void OperationExporter::visit(loco::AvgPool2D *node) +{ + // TFlite only support Valid convention of average pooling + assert(node->convention() == loco::AvgPool2D::Convention::Valid); + + uint32_t op_idx = gd.registerBuiltinOpcode(tflite::BuiltinOperator_AVERAGE_POOL_2D); + std::vector<int32_t> 
inputs_vec{get_tensor_index(node->ifm())}; + std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))}; + auto inputs = builder.CreateVector(inputs_vec); + auto outputs = builder.CreateVector(outputs_vec); + tflite::Padding padding = getOpPadding( + node->pad(), node->stride(), ShapeInference::get(node->ifm()), ShapeInference::get(node)); + auto options = CreatePool2DOptions(builder, padding, node->stride()->horizontal(), + node->stride()->vertical(), node->window()->horizontal(), + node->window()->vertical()); + auto op_offset = CreateOperator(builder, op_idx, inputs, outputs, + tflite::BuiltinOptions_Pool2DOptions, options.Union()); + gd._operators.push_back(op_offset); +} + +void OperationExporter::visit(loco::Conv2D *node) +{ + uint32_t op_idx = gd.registerBuiltinOpcode(tflite::BuiltinOperator_CONV_2D); + + // Third input of CONV_2D of tflite should be bias. We will make (and register to gd) dummy zero + // bias. Bias would be rank 1, have size of output kernel count, and have all zero values, i.e. + // zero bias. + auto *ker = dynamic_cast<loco::FilterEncode *>(node->ker()); + assert(ker); + int32_t bias_vec_size = ShapeInference::get(ker)._dims[0]; // output kernel count + + auto bias_vec_shape_offset = builder.CreateVector(std::vector<int32_t>{bias_vec_size}); + size_t raw_bias_vec_size = bias_vec_size * sizeof(int32_t); + + std::vector<float> bias_vec_data(bias_vec_size); // initialized as zero vector + + auto bias_vec_offset = + builder.CreateVector(reinterpret_cast<uint8_t *>(bias_vec_data.data()), raw_bias_vec_size); + + auto bias_buffer_offset = CreateBuffer(builder, bias_vec_offset); + + const auto bias_buffer_id = static_cast<uint32_t>(gd._buffers.size()); + + gd._buffers.push_back(bias_buffer_offset); + + auto bias_tensor_id = static_cast<int32_t>(gd._tensors.size()); + auto name_offset = builder.CreateString("t_" + std::to_string(bias_tensor_id)); + + auto bias_tensor_offset = + CreateTensor(builder, bias_vec_shape_offset, TensorType_FLOAT32, bias_buffer_id, name_offset); + gd._tensors.push_back(bias_tensor_offset); + + // Make input, output and options for operator + std::vector<int32_t> inputs_vec{get_tensor_index(node->ifm()), get_tensor_index(node->ker()), + bias_tensor_id}; + std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))}; + auto inputs = builder.CreateVector(inputs_vec); + auto outputs = builder.CreateVector(outputs_vec); + tflite::Padding padding = getOpPadding( + node->pad(), node->stride(), ShapeInference::get(node->ifm()), ShapeInference::get(node)); + auto options = CreateConv2DOptions(builder, padding, node->stride()->horizontal(), + node->stride()->vertical()); + + // Make CONV_2D operator + auto op_offset = CreateOperator(builder, op_idx, inputs, outputs, + tflite::BuiltinOptions_Conv2DOptions, options.Union()); + gd._operators.push_back(op_offset); +} + +void OperationExporter::visit(loco::TransposedConv2D *node) +{ + uint32_t op_idx = gd.registerBuiltinOpcode(tflite::BuiltinOperator_TRANSPOSE_CONV); + + // TRANSPOSE_CONV's first input is output shape array. 
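  // A small worked example (shape values assumed for illustration, not taken from the
  // upstream code): if the inferred output feature shape is count=1, height=16,
  // width=16, depth=32, the constant INT32 tensor built below holds {1, 16, 16, 32}
  // in NHWC order and becomes input #0 of TRANSPOSE_CONV, ahead of the filter and
  // the out_backprop tensors.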
+ const int32_t outshape_vec_size = 4; + auto outshape_vec_shape_offset = builder.CreateVector(std::vector<int32_t>{outshape_vec_size}); + size_t raw_outshape_vec_size = outshape_vec_size * sizeof(int32_t); + + std::vector<int32_t> outshape_vec_data(outshape_vec_size); + { + // Copy inferred output shape of node + auto out_feature_shape = loco::shape_get(node).as<loco::FeatureShape>(); + + // Feature tensor in TFlite is NHWC + outshape_vec_data.at(0) = out_feature_shape.count().value(); + outshape_vec_data.at(1) = out_feature_shape.height().value(); + outshape_vec_data.at(2) = out_feature_shape.width().value(); + outshape_vec_data.at(3) = out_feature_shape.depth().value(); + } + + auto outshape_vec_offset = builder.CreateVector( + reinterpret_cast<uint8_t *>(outshape_vec_data.data()), raw_outshape_vec_size); + + auto outshape_buffer_offset = CreateBuffer(builder, outshape_vec_offset); + + const auto outshape_buffer_id = static_cast<uint32_t>(gd._buffers.size()); + + gd._buffers.push_back(outshape_buffer_offset); + + auto outshape_tensor_id = static_cast<int32_t>(gd._tensors.size()); + auto name_offset = builder.CreateString("t_" + std::to_string(outshape_tensor_id)); + + auto outshape_tensor_offset = CreateTensor(builder, outshape_vec_shape_offset, TensorType_INT32, + outshape_buffer_id, name_offset); + gd._tensors.push_back(outshape_tensor_offset); + + // Make input, output and options for operator + std::vector<int32_t> inputs_vec{outshape_tensor_id, get_tensor_index(node->ker()), + get_tensor_index(node->ifm())}; + std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))}; + auto inputs = builder.CreateVector(inputs_vec); + auto outputs = builder.CreateVector(outputs_vec); + // NOTE input and output is inversed to use this function + tflite::Padding padding = getOpPadding(node->pad(), node->stride(), ShapeInference::get(node), + ShapeInference::get(node->ifm())); + auto options = CreateTransposeConvOptions(builder, padding, node->stride()->horizontal(), + node->stride()->vertical()); + + // Make TRANSPOSE_CONV operator + auto op_offset = CreateOperator(builder, op_idx, inputs, outputs, + tflite::BuiltinOptions_TransposeConvOptions, options.Union()); + gd._operators.push_back(op_offset); +} + +void OperationExporter::visit(loco::DepthwiseConv2D *node) +{ + uint32_t op_idx = gd.registerBuiltinOpcode(tflite::BuiltinOperator_DEPTHWISE_CONV_2D); + + // Third input of DEPTHWISE_CONV2D of tflite should be bias. We will make (and register to gd) + // dummy zero bias. Bias would be rank 1, have size of output kernel count, and have all zero + // values, i.e. zero bias. 
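  // A small worked example (channel counts assumed for illustration): with C = 8
  // input channels and multiplier M = 2, the depthwise filter is encoded as
  // [1, H, W, 16], so the dummy bias built below is a rank-1 FLOAT32 tensor of
  // 16 zeros, and the depth_multiplier written into the options is 16 / 8 = 2.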
+ auto *ker = dynamic_cast<loco::DepthwiseFilterEncode *>(node->ker()); + assert(ker); + + int32_t bias_vec_size = ShapeInference::get(ker)._dims[3]; // output_size(C*M) + auto bias_vec_shape_offset = builder.CreateVector(std::vector<int32_t>{bias_vec_size}); + + size_t raw_bias_vec_size = bias_vec_size * sizeof(int32_t); + std::vector<float> bias_vec_data(bias_vec_size); + auto bias_vec_offset = + builder.CreateVector(reinterpret_cast<uint8_t *>(bias_vec_data.data()), raw_bias_vec_size); + + auto bias_buffer_offset = CreateBuffer(builder, bias_vec_offset); + + const auto bias_buffer_id = static_cast<uint32_t>(gd._buffers.size()); + + gd._buffers.push_back(bias_buffer_offset); + + auto bias_tensor_id = static_cast<int32_t>(gd._tensors.size()); + auto name_offset = builder.CreateString("t_" + std::to_string(bias_tensor_id)); + + auto bias_tensor_offset = + CreateTensor(builder, bias_vec_shape_offset, TensorType_FLOAT32, bias_buffer_id, name_offset); + gd._tensors.push_back(bias_tensor_offset); + + std::vector<int32_t> inputs_vec{get_tensor_index(node->ifm()), get_tensor_index(node->ker()), + bias_tensor_id}; + std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))}; + auto inputs = builder.CreateVector(inputs_vec); + auto outputs = builder.CreateVector(outputs_vec); + tflite::Padding padding = getOpPadding( + node->pad(), node->stride(), ShapeInference::get(node->ifm()), ShapeInference::get(node)); + + int32_t ifm_channel_size = ShapeInference::get(node->ifm())._dims[3]; + // multiplier = bias_vec_size(output_size)/ifm_channel_size + auto options = + CreateDepthwiseConv2DOptions(builder, padding, node->stride()->horizontal(), + node->stride()->vertical(), bias_vec_size / ifm_channel_size); + + auto op_offset = CreateOperator(builder, op_idx, inputs, outputs, + tflite::BuiltinOptions_DepthwiseConv2DOptions, options.Union()); + gd._operators.push_back(op_offset); +} + +void OperationExporter::visit(loco::TensorReduce *node) +{ + uint32_t op_idx; + + switch (node->func()) + { + case loco::ReduceFunc::Mean: + op_idx = gd.registerBuiltinOpcode(tflite::BuiltinOperator_MEAN); + break; + + // TODO Support more reduce type operation + default: + INTERNAL_EXN_V("Not supported reduce type", oops::to_uint32(node->func())); + } + + // Create a vector for axes data + std::vector<int32_t> axes_vec; + auto rank = ShapeInference::get(node->input())._dims.size(); + for (uint32_t i = 0; i < rank; ++i) + if (node->axes()->defined(i)) + axes_vec.push_back(i); + + int32_t axes_vec_size = axes_vec.size(); + auto axes_vec_shape_offset = builder.CreateVector(std::vector<int32_t>{axes_vec_size}); + + size_t raw_axes_vec_size = axes_vec_size * sizeof(int32_t); + auto axes_vec_offset = + builder.CreateVector(reinterpret_cast<uint8_t *>(axes_vec.data()), raw_axes_vec_size); + + auto axes_buffer_offset = CreateBuffer(builder, axes_vec_offset); + + const auto axes_buffer_id = static_cast<uint32_t>(gd._buffers.size()); + + gd._buffers.push_back(axes_buffer_offset); + + auto axes_tensor_id = static_cast<int32_t>(gd._tensors.size()); + auto name_offset = builder.CreateString("t_" + std::to_string(axes_tensor_id)); + + auto axes_tensor_offset = + CreateTensor(builder, axes_vec_shape_offset, TensorType_INT32, axes_buffer_id, name_offset); + gd._tensors.push_back(axes_tensor_offset); + + std::vector<int32_t> inputs_vec{get_tensor_index(node->input()), axes_tensor_id}; + std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))}; + auto inputs = 
builder.CreateVector(inputs_vec); + auto outputs = builder.CreateVector(outputs_vec); + auto options = CreateReducerOptions(builder, true); // true is for keep_dims option + auto op_offset = CreateOperator(builder, op_idx, inputs, outputs, + tflite::BuiltinOptions_ReducerOptions, options.Union()); + gd._operators.push_back(op_offset); +} + +void OperationExporter::visit(loco::TensorSoftmax *node) +{ + // TODO Support when the input rank of TensorSoftmax is not 2 + assert(ShapeInference::get(node->input())._dims.size() == 2); + + // NOTE TFLite only accepts axis when the value is last dimension + assert(node->axis() == ShapeInference::get(node->input())._dims.size() - 1); + + uint32_t op_idx = gd.registerBuiltinOpcode(tflite::BuiltinOperator_SOFTMAX); + std::vector<int32_t> inputs_vec{get_tensor_index(node->input())}; + std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))}; + auto inputs = builder.CreateVector(inputs_vec); + auto outputs = builder.CreateVector(outputs_vec); + auto options = CreateSoftmaxOptions(builder, 1.0f); + auto op_offset = CreateOperator(builder, op_idx, inputs, outputs, + tflite::BuiltinOptions_SoftmaxOptions, options.Union()); + gd._operators.push_back(op_offset); +} + +/// @brief Export given node into identity, i.e. CONCATENATION with one input +template <typename NodeT> +void exportIdentity(NodeT *node, FlatBufferBuilder &builder, SerializedModelData &gd) +{ + uint32_t op_idx = gd.registerBuiltinOpcode(tflite::BuiltinOperator_CONCATENATION); + std::vector<int32_t> inputs_vec{get_tensor_index(node->arg(0))}; + std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))}; + auto inputs = builder.CreateVector(inputs_vec); + auto outputs = builder.CreateVector(outputs_vec); + auto options = CreateConcatenationOptions(builder); // use dummy 0 axis and NONE activation + auto op_offset = CreateOperator(builder, op_idx, inputs, outputs, + tflite::BuiltinOptions_ConcatenationOptions, options.Union()); + + gd._operators.push_back(op_offset); +} + +/// @brief Export loco nodes as TRANSPOSE +void exportAsTranspose(loco::Node *node, FlatBufferBuilder &builder, + std::vector<int32_t> &perm_vec_data, SerializedModelData &gd) +{ + uint32_t op_idx = gd.registerBuiltinOpcode(tflite::BuiltinOperator_TRANSPOSE); + + auto options = CreateTransposeOptions(builder); + + // Create constant tensor with perm vector + constexpr int perm_vec_size = 4; + assert(perm_vec_data.size() == perm_vec_size); + auto perm_vec_shape_offset = builder.CreateVector(std::vector<int32_t>{perm_vec_size}); + constexpr size_t raw_perm_vec_size = perm_vec_size * sizeof(int32_t); + + auto perm_vec_offset = + builder.CreateVector(reinterpret_cast<uint8_t *>(perm_vec_data.data()), raw_perm_vec_size); + + auto perm_buffer_offset = CreateBuffer(builder, perm_vec_offset); + + const auto perm_buffer_id = static_cast<uint32_t>(gd._buffers.size()); + + gd._buffers.push_back(perm_buffer_offset); + + auto perm_tensor_id = static_cast<int32_t>(gd._tensors.size()); + auto name_offset = builder.CreateString("t_" + std::to_string(perm_tensor_id)); + + auto perm_tensor_offset = + CreateTensor(builder, perm_vec_shape_offset, TensorType_INT32, perm_buffer_id, name_offset); + gd._tensors.push_back(perm_tensor_offset); + + // Create permutation node + + std::vector<int32_t> inputs_vec{get_tensor_index(node->arg(0)), perm_tensor_id}; + std::vector<int32_t> outputs_vec{get_tensor_index(node)}; + + auto inputs = builder.CreateVector(inputs_vec); + auto outputs = 
builder.CreateVector(outputs_vec); + + constexpr auto options_type = tflite::BuiltinOptions::BuiltinOptions_TransposeOptions; + + auto transpose_offset = + CreateOperator(builder, op_idx, inputs, outputs, options_type, options.Union()); + gd._operators.push_back(transpose_offset); +} + +void OperationExporter::visit(loco::FeatureEncode *node) +{ + auto encoder = dynamic_cast<loco::PermutingEncoder<loco::Domain::Feature> *>(node->encoder()); + auto perm = encoder->perm(); + + if (isNHWC(perm)) + { + // Note that tflite represents feature as NHWC + exportIdentity(node, builder, gd); + } + else + { + std::vector<int32_t> perm_vec_data(4); + perm_vec_data[0] = perm->axis(loco::FeatureAxis::Count); + perm_vec_data[1] = perm->axis(loco::FeatureAxis::Height); + perm_vec_data[2] = perm->axis(loco::FeatureAxis::Width); + perm_vec_data[3] = perm->axis(loco::FeatureAxis::Depth); + + exportAsTranspose(node, builder, perm_vec_data, gd); + } +} + +void OperationExporter::visit(loco::FeatureDecode *node) +{ + auto decoder = dynamic_cast<loco::PermutingDecoder<loco::Domain::Feature> *>(node->decoder()); + auto perm = decoder->perm(); + + if (isNHWC(perm)) + { + // Note that tflite represents feature as NHWC + exportIdentity(node, builder, gd); + } + else + { + std::vector<int32_t> perm_vec_data(4); + perm_vec_data[perm->axis(loco::FeatureAxis::Count)] = 0; + perm_vec_data[perm->axis(loco::FeatureAxis::Height)] = 1; + perm_vec_data[perm->axis(loco::FeatureAxis::Width)] = 2; + perm_vec_data[perm->axis(loco::FeatureAxis::Depth)] = 3; + + exportAsTranspose(node, builder, perm_vec_data, gd); + } +} + +void OperationExporter::visit(loco::FilterEncode *node) +{ + auto encoder = dynamic_cast<loco::PermutingEncoder<loco::Domain::Filter> *>(node->encoder()); + auto perm = encoder->perm(); + + if (isNHWC(perm)) + { + // Note that tflite represents filter as NHWC + exportIdentity(node, builder, gd); + } + else + { + std::vector<int32_t> perm_vec_data(4); + // NOTE In tflite, all tensors means NHWC, so 0 = N, 1 = H, 2 = W, 3 = C + perm_vec_data[0] = perm->axis(loco::FilterAxis::Count); + perm_vec_data[1] = perm->axis(loco::FilterAxis::Height); + perm_vec_data[2] = perm->axis(loco::FilterAxis::Width); + perm_vec_data[3] = perm->axis(loco::FilterAxis::Depth); + + exportAsTranspose(node, builder, perm_vec_data, gd); + } +} + +void exportAsReshape(loco::Node *node, FlatBufferBuilder &builder, + std::vector<int32_t> &new_shape_vec, SerializedModelData &gd) +{ + // NOTE TFLite has two ways to get new shape paramter, + // one is by attribute 'new_shape' and the other is by input 'shape'. + // Therefore TFLite interpreter calculates Reshape operation correctly + // if one of them is valid. + // However, since NN runtime usually get new shape parameter by input 'shape', + // passing new shape only by attribute can cause some problems. + // Of course, the opposite situation can be occurred in the future. + // To prevent those problems, we pass new shape parameter not only by attribute + // but also by input. 
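  // A small worked example (shape values assumed for illustration): for
  // new_shape_vec = {1, 5, 5, 4}, the RESHAPE operator emitted below carries the
  // same information twice:
  //   - a constant INT32 input tensor holding {1, 5, 5, 4}, and
  //   - ReshapeOptions.new_shape = {1, 5, 5, 4} in the builtin options,
  // so a runtime can read the new shape from whichever of the two it expects.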
+ + auto input_shape_shape_vec_offset = + builder.CreateVector(std::vector<int32_t>{(int32_t)new_shape_vec.size()}); + + size_t input_shape_vec_size = new_shape_vec.size() * sizeof(int32_t); + auto input_shape_input_vec_offset = + builder.CreateVector(reinterpret_cast<uint8_t *>(new_shape_vec.data()), input_shape_vec_size); + auto input_shape_buffer_offset = CreateBuffer(builder, input_shape_input_vec_offset); + + const auto input_shape_buffer_id = static_cast<uint32_t>(gd._buffers.size()); + gd._buffers.push_back(input_shape_buffer_offset); + + auto input_shape_tensor_id = static_cast<int32_t>(gd._tensors.size()); + auto name_offset = builder.CreateString("t_" + std::to_string(input_shape_tensor_id)); + auto input_shape_tensor_offset = CreateTensor( + builder, input_shape_shape_vec_offset, TensorType_INT32, input_shape_buffer_id, name_offset); + gd._tensors.push_back(input_shape_tensor_offset); + + uint32_t op_idx = gd.registerBuiltinOpcode(tflite::BuiltinOperator_RESHAPE); + + std::vector<int32_t> inputs_vec{get_tensor_index(node->arg(0)), input_shape_tensor_id}; + std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))}; + auto inputs = builder.CreateVector(inputs_vec); + auto outputs = builder.CreateVector(outputs_vec); + + auto new_shape_vec_offset = builder.CreateVector(new_shape_vec); + auto options = CreateReshapeOptions(builder, new_shape_vec_offset); + + auto op_offset = CreateOperator(builder, op_idx, inputs, outputs, + tflite::BuiltinOptions_ReshapeOptions, options.Union()); + + gd._operators.push_back(op_offset); +} + +void OperationExporter::visit(loco::DepthwiseFilterEncode *node) +{ + auto ker = node->input(); // [H, W, C, M] + + // tflite represents filter as [1, H, W, C*M] where M is multiplier. + std::vector<int32_t> new_shape_vec(4); + new_shape_vec[0] = 1; + new_shape_vec[1] = ShapeInference::get(ker)._dims[0]; + new_shape_vec[2] = ShapeInference::get(ker)._dims[1]; + new_shape_vec[3] = ShapeInference::get(ker)._dims[2] * ShapeInference::get(ker)._dims[3]; + + exportAsReshape(node, builder, new_shape_vec, gd); +} + +void OperationExporter::visit(loco::BiasAdd<loco::Domain::Tensor> *node) +{ + uint32_t op_idx = gd.registerBuiltinOpcode(tflite::BuiltinOperator_ADD); + std::vector<int32_t> inputs_vec{get_tensor_index(node->value()), get_tensor_index(node->bias())}; + std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))}; + auto inputs = builder.CreateVector(inputs_vec); + auto outputs = builder.CreateVector(outputs_vec); + auto options = CreateAddOptions(builder); // dummy option + auto op_offset = CreateOperator(builder, op_idx, inputs, outputs, + tflite::BuiltinOptions_AddOptions, options.Union()); + gd._operators.push_back(op_offset); +} + +void OperationExporter::visit(loco::FeatureBiasAdd *node) +{ + uint32_t op_idx = gd.registerBuiltinOpcode(tflite::BuiltinOperator_ADD); + std::vector<int32_t> inputs_vec{get_tensor_index(node->value()), get_tensor_index(node->bias())}; + std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))}; + auto inputs = builder.CreateVector(inputs_vec); + auto outputs = builder.CreateVector(outputs_vec); + auto options = CreateAddOptions(builder); // dummy option + auto op_offset = CreateOperator(builder, op_idx, inputs, outputs, + tflite::BuiltinOptions_AddOptions, options.Union()); + gd._operators.push_back(op_offset); +} + +/// @brief Export CONCATENATION of **TWO** tensors only +void OperationExporter::visit(loco::TensorConcat *node) +{ + uint32_t op_idx = 
gd.registerBuiltinOpcode(tflite::BuiltinOperator_CONCATENATION); + std::vector<int32_t> inputs_vec{get_tensor_index(node->lhs()), get_tensor_index(node->rhs())}; + std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))}; + auto inputs = builder.CreateVector(inputs_vec); + auto outputs = builder.CreateVector(outputs_vec); + auto options = CreateConcatenationOptions(builder, node->axis()); + auto op_offset = CreateOperator(builder, op_idx, inputs, outputs, + tflite::BuiltinOptions_ConcatenationOptions, options.Union()); + + gd._operators.push_back(op_offset); +} + +void OperationExporter::visit(loco::BiasEncode *encode) { exportIdentity(encode, builder, gd); } + +void OperationExporter::visit(loco::EltwiseAdd *node) +{ + uint32_t op_idx = gd.registerBuiltinOpcode(tflite::BuiltinOperator_ADD); + std::vector<int32_t> inputs_vec{get_tensor_index(node->lhs()), get_tensor_index(node->rhs())}; + std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))}; + auto inputs = builder.CreateVector(inputs_vec); + auto outputs = builder.CreateVector(outputs_vec); + auto options = CreateAddOptions(builder); // dummy option + auto op_offset = CreateOperator(builder, op_idx, inputs, outputs, + tflite::BuiltinOptions_AddOptions, options.Union()); + gd._operators.push_back(op_offset); +} + +void OperationExporter::visit(loco::EltwiseMax *node) +{ + uint32_t op_idx = gd.registerBuiltinOpcode(tflite::BuiltinOperator_MAXIMUM); + std::vector<int32_t> inputs_vec{get_tensor_index(node->lhs()), get_tensor_index(node->rhs())}; + std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))}; + auto inputs = builder.CreateVector(inputs_vec); + auto outputs = builder.CreateVector(outputs_vec); + auto options = CreateMaximumMinimumOptions(builder); // dummy option + auto op_offset = CreateOperator(builder, op_idx, inputs, outputs, + tflite::BuiltinOptions_MaximumMinimumOptions, options.Union()); + gd._operators.push_back(op_offset); +} + +void OperationExporter::visit(loco::EltwiseMul *node) +{ + uint32_t op_idx = gd.registerBuiltinOpcode(tflite::BuiltinOperator_MUL); + std::vector<int32_t> inputs_vec{get_tensor_index(node->lhs()), get_tensor_index(node->rhs())}; + std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))}; + auto inputs = builder.CreateVector(inputs_vec); + auto outputs = builder.CreateVector(outputs_vec); + auto options = CreateMulOptions(builder); // dummy option + auto op_offset = CreateOperator(builder, op_idx, inputs, outputs, + tflite::BuiltinOptions_MulOptions, options.Union()); + gd._operators.push_back(op_offset); +} + +void OperationExporter::visit(loco::EltwiseSub *node) +{ + uint32_t op_idx = gd.registerBuiltinOpcode(tflite::BuiltinOperator_SUB); + std::vector<int32_t> inputs_vec{get_tensor_index(node->lhs()), get_tensor_index(node->rhs())}; + std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))}; + auto inputs = builder.CreateVector(inputs_vec); + auto outputs = builder.CreateVector(outputs_vec); + auto options = CreateSubOptions(builder); // dummy option + auto op_offset = CreateOperator(builder, op_idx, inputs, outputs, + tflite::BuiltinOptions_SubOptions, options.Union()); + gd._operators.push_back(op_offset); +} + +void OperationExporter::visit(loco::EltwiseDiv *node) +{ + uint32_t op_idx = gd.registerBuiltinOpcode(tflite::BuiltinOperator_DIV); + std::vector<int32_t> inputs_vec{get_tensor_index(node->lhs()), get_tensor_index(node->rhs())}; + 
std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))}; + auto inputs = builder.CreateVector(inputs_vec); + auto outputs = builder.CreateVector(outputs_vec); + auto options = CreateDivOptions(builder); // dummy option + auto op_offset = CreateOperator(builder, op_idx, inputs, outputs, + tflite::BuiltinOptions_DivOptions, options.Union()); + gd._operators.push_back(op_offset); +} + +void OperationExporter::visit(loco::EltwiseSqrt *node) +{ + uint32_t op_idx = gd.registerBuiltinOpcode(tflite::BuiltinOperator_SQRT); + std::vector<int32_t> inputs_vec{get_tensor_index(node->input())}; + std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))}; + auto inputs = builder.CreateVector(inputs_vec); + auto outputs = builder.CreateVector(outputs_vec); + auto op_offset = CreateOperator(builder, op_idx, inputs, outputs); + gd._operators.push_back(op_offset); +} + +void OperationExporter::visit(loco::FixedReshape *node) +{ + std::vector<int32_t> new_shape_vec; + for (uint32_t axis = 0; axis < node->rank(); ++axis) + { + assert(node->dim(axis).known()); + new_shape_vec.push_back(node->dim(axis).value()); + } + + exportAsReshape(node, builder, new_shape_vec, gd); +} + +void OperationExporter::visit(loco::TensorBroadcast *) +{ + INTERNAL_EXN("TensorBroadcast should not exist in the graph"); +} + +void OperationExporter::visit(loco::TensorConstantPad *node) +{ + uint32_t op_idx = gd.registerBuiltinOpcode(tflite::BuiltinOperator_PAD); + + // make padding attribute an input + auto padding = node->padding(); + // get padding vector size + int32_t padding_vec_size = padding->rank(); + // get byte size of vector + size_t padding_vec_byte_size = padding_vec_size * sizeof(int32_t) * 2; // [rank, 2] + // create vector for data + std::vector<int32_t> padding_vec_data(padding_vec_size * 2); + // set data + for (int32_t i = 0; i < padding_vec_size; i++) + { + padding_vec_data.at(i * 2) = padding->front(i); + padding_vec_data.at(i * 2 + 1) = padding->back(i); + } + // create FlatBuffer vector + auto padding_vec_ptr = builder.CreateVector(reinterpret_cast<uint8_t *>(padding_vec_data.data()), + padding_vec_byte_size); + + // create buffer + auto padding_buffer_ptr = CreateBuffer(builder, padding_vec_ptr); + // get buffer id + const auto padding_buffer_id = static_cast<uint32_t>(gd._buffers.size()); + + gd._buffers.push_back(padding_buffer_ptr); + + // create padding shape vector + auto padding_shape_vec_ptr = builder.CreateVector(std::vector<int32_t>{padding_vec_size, 2}); + // create tensor + auto padding_tensor_ptr = + CreateTensor(builder, padding_shape_vec_ptr, TensorType_INT32, padding_buffer_id); + // get tensor id + const auto padding_tensor_id = static_cast<int32_t>(gd._tensors.size()); + + gd._tensors.push_back(padding_tensor_ptr); + + std::vector<int32_t> inputs_vec{get_tensor_index(node->input()), padding_tensor_id}; + std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))}; + auto inputs = builder.CreateVector(inputs_vec); + auto outputs = builder.CreateVector(outputs_vec); + auto op_offset = CreateOperator(builder, op_idx, inputs, outputs); + gd._operators.push_back(op_offset); +} + +inline flatbuffers::Offset<flatbuffers::Vector<uint8_t>> +CreateCOpCallOptions(flatbuffers::FlatBufferBuilder &fbb, locoex::COpCall *copCall) +{ + // read attrs in FlexBuffer format and pass them to FlatBuffer builder + flexbuffers::Builder flexbuf; + { + size_t map_start = flexbuf.StartMap(); + + // Note: among attrs of COpCall, 'op' and 'name' 
won't be included into tflite file + auto names = copCall->attr_names(); + for (auto name : names) + { + if (auto int_val = copCall->attr<locoex::COpAttrType::Int>(name)) + flexbuf.Int(name.c_str(), int_val->val()); + else if (auto float_val = copCall->attr<locoex::COpAttrType::Float>(name)) + flexbuf.Float(name.c_str(), float_val->val()); + else + // TODO Support more attribute types + INTERNAL_EXN("Not supported type while writing flexbuffer"); + } + + flexbuf.EndMap(map_start); + flexbuf.Finish(); + } + + auto offset = fbb.CreateVector(flexbuf.GetBuffer()); + + return offset; +} + +void OperationExporter::visit(locoex::COpCall *call) +{ + // Registering this custom op name into tflite Operator Codes table + uint32_t op_idx = gd.registerCustomOpcode(call->op()); + + std::vector<int32_t> inputs_vec; + { + inputs_vec.resize(call->arity()); + for (uint32_t i = 0; i < call->arity(); i++) + inputs_vec[i] = get_tensor_index(call->arg(i)); + } + + std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(call))}; + + auto inputs = builder.CreateVector(inputs_vec); + auto outputs = builder.CreateVector(outputs_vec); + + auto custom_options = CreateCOpCallOptions(builder, call); + auto op_offset = CreateOperator(builder, op_idx, inputs, outputs, + tflite::BuiltinOptions_NONE, // builtin_options_type + 0, // built-in option + custom_options, // custom options + tflite::CustomOptionsFormat_FLEXBUFFERS); + + gd._operators.push_back(op_offset); +} + +void exportNode(loco::Node *node, flatbuffers::FlatBufferBuilder &builder, + SerializedModelData &data) +{ + // TODO Use explicit tagging to prevent possible mistake + auto isNoOp = [](loco::Node *node) { + if (node->arity() == 1) + { + assert(node->arg(0) != nullptr); + return get_tensor_index(node) == get_tensor_index(node->arg(0)); + } + return false; + }; + + if (isNoOp(node)) + { + // Skip if a given node is marked as NoOp (op with no effect) before + return; + } + + if (auto canonical_node = dynamic_cast<loco::CanonicalNode *>(node)) + { // TODO Consider removing this later + OperationExporter exporter{builder, data}; + canonical_node->accept(&exporter); + } + else if (auto tfl_node = dynamic_cast<locoex::TFLNode *>(node)) + { + OperationExporter exporter{builder, data}; + tfl_node->accept(&exporter); + } + else if (dynamic_cast<locoex::COpNode *>(node)) + { + OperationExporter exporter{builder, data}; + exporter.visit(dynamic_cast<locoex::COpCall *>(node)); + } + else + { + assert(false && "unsupported node found"); + } +} + +} // namespace + +namespace exo +{ +namespace tflite_detail +{ + +void exportNodes(loco::Graph *g, FlatBufferBuilder &builder, SerializedModelData &gd) +{ + for (auto node : loco::postorder_traversal(loco::output_nodes(g))) + { + exportNode(node, builder, gd); + } +} + +} // namespace tflite_detail +} // namespace exo diff --git a/compiler/exo/src/TFLite/TFLOperationExporter.h b/compiler/exo/src/TFLite/TFLOperationExporter.h new file mode 100644 index 000000000..60f2b5eb2 --- /dev/null +++ b/compiler/exo/src/TFLite/TFLOperationExporter.h @@ -0,0 +1,39 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __TFL_OPERATION_EXPORTER_H__ +#define __TFL_OPERATION_EXPORTER_H__ + +#include "TFLExporterUtils.h" + +#include <loco/IR/Graph.h> + +namespace exo +{ +namespace tflite_detail +{ + +/** + * @brief create Operators corresponding to model nodes + * @param nodes container with nodes + * @param gd information about serializer parts of model + */ +void exportNodes(loco::Graph *g, flatbuffers::FlatBufferBuilder &builder, SerializedModelData &gd); + +} // namespace tflite_detail +} // namespace exo + +#endif // __TFL_OPERATION_EXPORTER_H__ diff --git a/compiler/exo/src/TFLite/TFLTensorExporter.cpp b/compiler/exo/src/TFLite/TFLTensorExporter.cpp new file mode 100644 index 000000000..66854ef87 --- /dev/null +++ b/compiler/exo/src/TFLite/TFLTensorExporter.cpp @@ -0,0 +1,249 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "TFLTensorExporter.h" +#include "TFLTypeInference.h" +#include "ShapeInference.h" + +// TODO Fix include style +#include "loco/IR/Algorithm.h" +#include "loco/IR/CanonicalNode.h" +#include "loco/IR/CanonicalNodeVisitor.h" +#include "loco/IR/DataTypeTraits.h" + +#include "Dialect/IR/TFLNodes.h" + +#include <oops/InternalExn.h> + +using namespace tflite; +using namespace flatbuffers; + +namespace +{ + +using namespace exo; +using namespace exo::tflite_detail; + +class TFLTensorInfo +{ +public: + TFLTensorInfo() = default; + +public: + void name(const std::string &name) { _name = name; } + const std::string &name(void) const { return _name; } + +public: + const tflite::TensorType &dtype(void) const { return _dtype; } + void dtype(const tflite::TensorType &dtype) { _dtype = dtype; } + + const ShapeDescription &shape(void) const { return _shape; } + void shape(const ShapeDescription &shape) { _shape = shape; } + +public: + locoex::TFLConst *tfl_content(void) const { return _tfl_content; } + void tfl_content(locoex::TFLConst *c) { _tfl_content = c; } + +private: + std::string _name; + + tflite::TensorType _dtype; + ShapeDescription _shape; + + // TODO Find a better design + loco::ConstGen *_content = nullptr; // TODO deprecate + locoex::TFLConst *_tfl_content = nullptr; +}; + +using TFLTensorContext = std::vector<TFLTensorInfo>; + +struct NoOpDetector final : public loco::CanonicalNodeMutableVisitor<bool> +{ + bool visit(loco::BiasEncode *) final + { + // BiasEncode is always noop + return true; + } + + bool visit(loco::FilterEncode *node) final + { + auto encoder = dynamic_cast<loco::PermutingEncoder<loco::Domain::Filter> *>(node->encoder()); + auto perm = encoder->perm(); + + return isNHWC(perm); + } + + bool visit(loco::FeatureEncode *node) final + { + auto encoder = dynamic_cast<loco::PermutingEncoder<loco::Domain::Feature> *>(node->encoder()); + auto perm = encoder->perm(); + return isNHWC(perm); + } + + bool visit(loco::FeatureDecode *node) final + { + auto decoder = dynamic_cast<loco::PermutingDecoder<loco::Domain::Feature> *>(node->decoder()); + auto perm = decoder->perm(); + return isNHWC(perm); + } + + // Return false by default + bool visit(loco::Node *) final { return false; } +}; + +bool isNoOp(loco::Node *node) +{ + if (auto canonical_node = dynamic_cast<loco::CanonicalNode *>(node)) + { + NoOpDetector d; + return canonical_node->accept(&d); + } + return false; +} + +void allocateTFLiteTensor(loco::Node *node, TFLTensorContext &ctx) +{ + if (isNoOp(node)) + { + assert(node->arity() == 1 && node->arg(0) != nullptr); + set_tensor_index(node, get_tensor_index(node->arg(0))); + return; + } + + auto tensor_index = static_cast<TFLTensorIndex>(ctx.size()); + // TODO Use Graph-level metadata for Input & Output + auto tensor_name = "t_" + std::to_string(tensor_index); + + TFLTensorInfo tensor_info; + + tensor_info.name(tensor_name); + tensor_info.dtype(TypeInference::get(node)); + tensor_info.shape(ShapeInference::get(node)); + + tensor_info.tfl_content(dynamic_cast<locoex::TFLConst *>(node)); + + set_tensor_index(node, tensor_index); + + ctx.emplace_back(tensor_info); +} + +} // namespace + +namespace +{ + +flatbuffers::Offset<Vector<int32_t>> encodeShape(FlatBufferBuilder &builder, + const ShapeDescription &shape) +{ + assert(shape._rank_known && "unknown number of dimensions is not supported"); + return builder.CreateVector(shape._dims); +} + +flatbuffers::Offset<tflite::Buffer> encodeOpBuffer(FlatBufferBuilder &builder) +{ + return CreateBuffer(builder); +} + +template 
<typename NodeT> +flatbuffers::Offset<tflite::Buffer> encodeOpBuffer(FlatBufferBuilder &builder, NodeT *) +{ + return CreateBuffer(builder); +} + +template <loco::DataType DT> +flatbuffers::Offset<tflite::Buffer> encodeOpBufferByDType(FlatBufferBuilder &builder, + locoex::TFLConst *c) +{ + using NativeType = typename loco::DataTypeImpl<DT>::Type; + + std::vector<NativeType> raw_data; + const uint32_t size = c->size<DT>(); + raw_data.reserve(size); + for (uint32_t i = 0; i < size; ++i) + { + raw_data.push_back(c->at<DT>(i)); + } + const size_t raw_size = size * sizeof(NativeType); + auto array_offset = builder.CreateVector(reinterpret_cast<uint8_t *>(raw_data.data()), raw_size); + return CreateBuffer(builder, array_offset); +} + +template <> +flatbuffers::Offset<tflite::Buffer> encodeOpBuffer(FlatBufferBuilder &builder, locoex::TFLConst *c) +{ + if (c->dtype() == loco::DataType::FLOAT32) + { + return encodeOpBufferByDType<loco::DataType::FLOAT32>(builder, c); + } + else if (c->dtype() == loco::DataType::S32) + { + return encodeOpBufferByDType<loco::DataType::S32>(builder, c); + } + + INTERNAL_EXN_V("Unsupported datatype", oops::to_uint32(c->dtype())); +} + +} // namespace + +namespace exo +{ +namespace tflite_detail +{ + +void exportOpDefinedTensor(const TFLTensorInfo &info, FlatBufferBuilder &builder, + SerializedModelData &gd) +{ + // Create and register output tensor shape + auto shape_offset = encodeShape(builder, info.shape()); + + // encode and register output tensor buffer + auto buffer = info.tfl_content() == nullptr ? encodeOpBuffer(builder) + : encodeOpBuffer(builder, info.tfl_content()); + + auto buffer_id = static_cast<uint32_t>(gd._buffers.size()); + gd._buffers.push_back(buffer); + + auto name_offset = builder.CreateString(info.name()); + auto tensor_offset = CreateTensor(builder, shape_offset, info.dtype(), buffer_id, name_offset, + /*quantization*/ 0, /*is_variable*/ false); + gd._tensors.push_back(tensor_offset); +} + +void exportOpDefinedTensors(loco::Graph *g, FlatBufferBuilder &builder, SerializedModelData &gd) +{ + TFLTensorContext tensor_ctx; + + for (auto node : loco::postorder_traversal(loco::output_nodes(g))) + { + allocateTFLiteTensor(node, tensor_ctx); + } + + // add one empty buffer + // note: there's a comment in tflite fbs file + // - Note the 0th entry of this array must be an empty buffer (sentinel). + // - This is a convention so that tensors without a buffer can provide 0 as + // - their buffer. + auto buffer = encodeOpBuffer(builder); + gd._buffers.push_back(buffer); + + for (const auto &tensor_info : tensor_ctx) + { + exportOpDefinedTensor(tensor_info, builder, gd); + } +} + +} // namespace tflite_detail +} // namespace exo diff --git a/compiler/exo/src/TFLite/TFLTensorExporter.h b/compiler/exo/src/TFLite/TFLTensorExporter.h new file mode 100644 index 000000000..97e702665 --- /dev/null +++ b/compiler/exo/src/TFLite/TFLTensorExporter.h @@ -0,0 +1,42 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __TFL_TENSOR_EXPORTER_H__ +#define __TFL_TENSOR_EXPORTER_H__ + +#include "TFLExporterUtils.h" + +#include <loco/IR/Graph.h> + +#include <flatbuffers/flatbuffers.h> + +namespace exo +{ +namespace tflite_detail +{ + +/** + * @brief create Tensors corresponding to results of all nodes in graph + * @param computational graph + * @param gd information about serialized parts of model + */ +void exportOpDefinedTensors(loco::Graph *g, flatbuffers::FlatBufferBuilder &builder, + SerializedModelData &gd); + +} // namespace tflite_detail +} // namespace exo + +#endif // __TFL_TENSOR_EXPORTER_H__ diff --git a/compiler/exo/src/TFLite/TFLTypeInference.cpp b/compiler/exo/src/TFLite/TFLTypeInference.cpp new file mode 100644 index 000000000..8d6bb8d8c --- /dev/null +++ b/compiler/exo/src/TFLite/TFLTypeInference.cpp @@ -0,0 +1,82 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "TFLTypeInference.h" + +#include "schema_generated.h" + +#include "Dialect/Service/TFLTypeInferenceRule.h" +#include "Dialect/IR/TFLDialect.h" + +#include <loco/IR/CanonicalNode.h> +#include <loco/IR/CanonicalNodeVisitor.h> +#include <loco/IR/CanonicalDialect.h> +#include <loco/Service/TypeInference.h> + +#include <locoex/COpDialect.h> +#include <locoex/Service/COpTypeInference.h> + +#include <oops/InternalExn.h> + +#include <stdex/Memory.h> + +#include <stdexcept> +#include <type_traits> + +namespace +{ + +tflite::TensorType translateLocoTypeToTFLite(loco::DataType dtype) +{ + switch (dtype) + { + case loco::DataType::U8: + return tflite::TensorType_UINT8; + // case loco::DataType::U16: unsupported + // case loco::DataType::U32: unsupported + // case loco::DataType::U64: unsupported + case loco::DataType::S8: + return tflite::TensorType_INT8; + case loco::DataType::S16: + return tflite::TensorType_INT16; + case loco::DataType::S32: + return tflite::TensorType_INT32; + case loco::DataType::S64: + return tflite::TensorType_INT64; + case loco::DataType::FLOAT16: + return tflite::TensorType_FLOAT16; + case loco::DataType::FLOAT32: + return tflite::TensorType_FLOAT32; + // case loco::DataType::FLOAT64: unsupported + default: + break; + } + + INTERNAL_EXN_V("Trying to converte unsupported loco dtype", oops::to_uint32(dtype)); +} + +} // namespace + +namespace exo +{ + +tflite::TensorType TypeInference::get(loco::Node *node) +{ + assert(loco::dtype_known(node)); + return translateLocoTypeToTFLite(loco::dtype_get(node)); +} + +} // namespace exo diff --git a/compiler/exo/src/TFLite/TFLTypeInference.h b/compiler/exo/src/TFLite/TFLTypeInference.h new file mode 100644 index 000000000..3d3a2e480 --- /dev/null +++ b/compiler/exo/src/TFLite/TFLTypeInference.h @@ -0,0 +1,42 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. 
All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __TFL_TYPE_INFERENCE_H__ +#define __TFL_TYPE_INFERENCE_H__ + +#include "TFLExporterUtils.h" + +#include <loco/IR/Nodes.h> + +namespace exo +{ + +/** + * @brief Get the type of each node as NodeAnnotation + * + * HOW TO USE + * + * TypeInference::get(g->nodes()->at(0)); + * TypeInference::get(g->nodes()->at(...)); + */ +struct TypeInference +{ + static tflite::TensorType get(loco::Node *node); +}; + +} // namespace exo + +#endif // __TFL_TYPE_INFERENCE_H__ diff --git a/compiler/exo/src/TFLite/TFLTypeInference.test.cpp b/compiler/exo/src/TFLite/TFLTypeInference.test.cpp new file mode 100644 index 000000000..0712f0a25 --- /dev/null +++ b/compiler/exo/src/TFLite/TFLTypeInference.test.cpp @@ -0,0 +1,118 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "TFLTypeInference.h" +#include "Pass/TypeInferencePass.h" + +#include <loco/IR/PermutingCodec.h> +#include <stdex/Memory.h> + +#include <gtest/gtest.h> + +using stdex::make_unique; + +namespace +{ + +class Sequential +{ +public: + loco::Pull *addPullLayer(const loco::DataType &dtype = loco::DataType::FLOAT32) + { + loco::Pull *pull = _graph.nodes()->create<loco::Pull>(); + + auto graph_input = _graph.inputs()->create(); + graph_input->name("graph_input"); + loco::link(graph_input, pull); + + pull->dtype(dtype); + setSampleShape(pull); + + return last(pull); + } + + loco::ReLU *addReLULayer(void) + { + loco::ReLU *relu = _graph.nodes()->create<loco::ReLU>(); + + relu->input(_last); + + return last(relu); + } + + loco::Push *addPushLayer(void) + { + loco::Push *push = _graph.nodes()->create<loco::Push>(); + + auto graph_output = _graph.outputs()->create(); + graph_output->name("graph_output"); + loco::link(graph_output, push); + + push->from(_last); + + return last(push); + } + + loco::Graph *graph() { return &_graph; } + +private: + template <typename T> uint32_t setSampleShape(T *op) + { + const uint32_t n = 1; + const uint32_t h = 100; + const uint32_t w = 100; + const uint32_t c = 3; + op->rank(4); + op->dim(0).set(n); + op->dim(1).set(c); + op->dim(2).set(h); + op->dim(3).set(w); + return n * h * w * c; + } + + template <typename T> T *last(T *node) + { + _last = node; + return node; + } + +private: + loco::Graph _graph; + loco::Node *_last; +}; + +struct TypeInferenceTest : public Sequential, public ::testing::Test +{ + virtual ~TypeInferenceTest() = default; +}; + +} // namespace + +// TypeInference SHOULD PROPAGATE type information properly +TEST_F(TypeInferenceTest, Regression_0000) +{ + auto pull = addPullLayer(loco::DataType::S8); + auto relu = addReLULayer(); + auto push = addPushLayer(); + + using namespace exo; + + TypeInferencePass type_inf_pass; + type_inf_pass.run(graph()); + + ASSERT_EQ(TypeInference::get(relu), tflite::TensorType_INT8); + ASSERT_EQ(TypeInference::get(push), tflite::TensorType_INT8); +} diff --git a/compiler/exo/src/TestGraph.h b/compiler/exo/src/TestGraph.h new file mode 100644 index 000000000..f919cc9ae --- /dev/null +++ b/compiler/exo/src/TestGraph.h @@ -0,0 +1,315 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef __TEST_GRAPH_H__ +#define __TEST_GRAPH_H__ + +#include "Dialect/IR/TFLNodes.h" +#include "GraphBlock.h" +#include "TestHelper.h" + +#include <loco.h> + +#include <stdex/Memory.h> + +#include <cassert> + +namespace exo +{ +namespace test +{ + +class TestGraph +{ +public: + std::unique_ptr<loco::Graph> g; + loco::Pull *pull; + loco::Push *push; + + TestGraph() // creates Pull and Push + { + g = loco::make_graph(); + + pull = g->nodes()->create<loco::Pull>(); + + push = g->nodes()->create<loco::Push>(); + + auto input = g->inputs()->create(); + { + input->name("input"); + loco::link(input, pull); + } + auto output = g->outputs()->create(); + { + output->name("output"); + loco::link(output, push); + } + + _next_input = pull; + } + + loco::Graph *graph() { return g.get(); } + + /// @brief Creates node with NO arg and appends it to graph + template <class T> T *append() + { + auto node = g->nodes()->create<T>(); + _next_input = node; + + return node; + } + + /// @brief Creates op T (arity=1) with arg1 as an input and appends it to graph + template <class T> T *append(loco::Node *arg1) + { + auto node = g->nodes()->create<T>(); + setInput(node, arg1); + _next_input = node; + + return node; + } + + /// @brief Creates op T (arity=2) with arg1, arg2 as inputs and appends it to graph + template <class T> T *append(loco::Node *arg1, loco::Node *arg2) + { + auto node = g->nodes()->create<T>(); + setInput(node, arg1, arg2); + _next_input = node; + + return node; + } + + /// @brief Creates op T (arity=3) with arg1, arg2, arg3 as inputs and appends it to graph + template <class T> T *append(loco::Node *arg1, loco::Node *arg2, loco::Node *arg3) + { + auto node = g->nodes()->create<T>(); + setInput(node, arg1, arg2, arg3); + _next_input = node; + + return node; + } + + // push will get the last appended node + void complete() { push->from(_next_input); } + + void complete(loco::Node *last_node) { push->from(last_node); } + +private: + // arity 1 + void setInput(loco::Node *node, loco::Node *) { assert(false && "NYI"); }; + + void setInput(loco::AvgPool2D *node, loco::Node *input) { node->ifm(input); } + void setInput(loco::BiasDecode *node, loco::Node *input) { node->input(input); }; + void setInput(loco::BiasEncode *node, loco::Node *input) { node->input(input); }; + void setInput(loco::FeatureDecode *node, loco::Node *input) { node->input(input); }; + void setInput(loco::FeatureEncode *node, loco::Node *input) { node->input(input); }; + void setInput(loco::MaxPool2D *node, loco::Node *input) { node->ifm(input); } + void setInput(loco::Push *node, loco::Node *input) { node->from(input); }; + void setInput(loco::ReLU *node, loco::Node *input) { node->input(input); }; + void setInput(loco::ReLU6 *node, loco::Node *input) { node->input(input); }; + void setInput(loco::Tanh *node, loco::Node *input) { node->input(input); }; + void setInput(loco::TensorTranspose *node, loco::Node *input) { node->input(input); }; + + void setInput(locoex::TFLAveragePool2D *node, loco::Node *input) { node->value(input); }; + void setInput(locoex::TFLMaxPool2D *node, loco::Node *input) { node->value(input); }; + void setInput(locoex::TFLRelu *node, loco::Node *input) { node->features(input); }; + void setInput(locoex::TFLRelu6 *node, loco::Node *input) { node->features(input); }; + + // arity 2 + void setInput(loco::Node *node, loco::Node *, loco::Node *) { assert(false && "NYI"); }; + + void setInput(loco::Conv2D *node, loco::Node *input, loco::Node *filter) + { + node->ifm(input); + node->ker(filter); + } + + 
void setInput(loco::EltwiseAdd *node, loco::Node *arg1, loco::Node *arg2)
+  {
+    node->lhs(arg1);
+    node->rhs(arg2);
+  };
+
+  void setInput(loco::FeatureBiasAdd *node, loco::Node *arg1, loco::Node *arg2)
+  {
+    node->value(arg1);
+    node->bias(arg2);
+  };
+
+  void setInput(locoex::TFLAdd *node, loco::Node *arg1, loco::Node *arg2)
+  {
+    node->x(arg1);
+    node->y(arg2);
+  };
+
+  void setInput(locoex::TFLMul *node, loco::Node *arg1, loco::Node *arg2)
+  {
+    node->x(arg1);
+    node->y(arg2);
+  };
+
+  void setInput(locoex::TFLSub *node, loco::Node *arg1, loco::Node *arg2)
+  {
+    node->x(arg1);
+    node->y(arg2);
+  };
+
+  void setInput(locoex::TFLTranspose *node, loco::Node *arg1, loco::Node *arg2)
+  {
+    node->a(arg1);
+    node->perm(arg2);
+  };
+
+  // arity 3
+  void setInput(loco::Node *node, loco::Node *, loco::Node *, loco::Node *)
+  {
+    assert(false && "NYI");
+  };
+
+  void setInput(locoex::TFLConv2D *node, loco::Node *input, loco::Node *filter, loco::Node *bias)
+  {
+    node->input(input);
+    node->filter(filter);
+    node->bias(bias);
+  }
+
+private:
+  loco::Node *_next_input;
+};
+
+enum class ExampleGraphType
+{
+  FeatureBiasAdd,
+  ConstGen_ReLU,
+  FilterEncode_FilterDecode,
+  Transpose,
+
+  TFLTranspose,
+};
+
+template <ExampleGraphType T> class ExampleGraph;
+
+/**
+ * @brief Class to create the following:
+ *
+ *    Pull - FeatureEncode - FeatureBiasAdd - FeatureDecode - Push
+ *                                 |
+ *          ConstGen - BiasEncode --+
+ */
+template <> class ExampleGraph<ExampleGraphType::FeatureBiasAdd> : public TestGraph
+{
+public:
+  loco::FeatureEncode *fea_enc = nullptr;
+  loco::ConstGen *constgen = nullptr;
+  loco::BiasEncode *bias_enc = nullptr;
+  loco::FeatureBiasAdd *fea_bias_add = nullptr;
+  loco::FeatureDecode *fea_dec = nullptr;
+
+public:
+  ExampleGraph()
+  {
+    fea_enc = exo::make_feature_encode<exo::FeatureLayout::NHWC>(pull);
+    constgen = append<loco::ConstGen>();
+    bias_enc = append<loco::BiasEncode>(constgen);
+    fea_bias_add = append<loco::FeatureBiasAdd>(fea_enc, bias_enc);
+    fea_dec = exo::make_feature_decode<exo::FeatureLayout::NHWC>(fea_bias_add);
+    complete(fea_dec);
+  }
+};
+
+/**
+ * @brief Class to create the following:
+ *
+ *    ConstGen -- ReLU -- Push
+ */
+template <> class ExampleGraph<ExampleGraphType::ConstGen_ReLU> : public TestGraph
+{
+public:
+  loco::ConstGen *constgen = nullptr;
+  loco::ReLU *relu = nullptr;
+
+public:
+  ExampleGraph()
+  {
+    constgen = append<loco::ConstGen>();
+    relu = append<loco::ReLU>(constgen);
+    complete(relu);
+  }
+};
+
+/**
+ * @brief Class to create the following:
+ *
+ *    Pull -- Transpose -- Push
+ */
+template <> class ExampleGraph<ExampleGraphType::Transpose> : public TestGraph
+{
+public:
+  loco::TensorTranspose *transpose = nullptr;
+
+public:
+  ExampleGraph()
+  {
+    transpose = append<loco::TensorTranspose>(pull);
+    complete(transpose);
+  }
+};
+
+/**
+ * @brief Class to create the following:
+ *
+ *    Pull -- FilterEncode -- FilterDecode -- Push
+ */
+template <> class ExampleGraph<ExampleGraphType::FilterEncode_FilterDecode> : public TestGraph
+{
+public:
+  loco::FilterEncode *filterEncode = nullptr;
+  loco::FilterDecode *filterDecode = nullptr;
+
+public:
+  ExampleGraph()
+  {
+    filterEncode = exo::make_filter_encode<exo::FilterLayout::HWIO>(pull); // from TensorFlow
+    filterDecode =
+        exo::make_filter_decode<exo::FilterLayout::OHWI>(filterEncode); // to TensorFlow Lite
+    complete(filterDecode);
+  }
+};
+
+/**
+ * @brief Class to create the following:
+ *
+ *    Pull -- TFLTranspose -- Push
+ */
+template <> class ExampleGraph<ExampleGraphType::TFLTranspose> : public TestGraph
+{
+public:
+  loco::ConstGen *const_perm = nullptr;
+  locoex::TFLTranspose *tfl_transpose = nullptr;
+
+public:
+  ExampleGraph()
+  {
+    const_perm = append<loco::ConstGen>();
+    tfl_transpose = append<locoex::TFLTranspose>(pull, const_perm);
+    complete(tfl_transpose);
+  }
+};
+
+} // namespace test
+} // namespace exo
+
+#endif // __TEST_GRAPH_H__
diff --git a/compiler/exo/src/TestHelper.h b/compiler/exo/src/TestHelper.h
new file mode 100644
index 000000000..1a3de50f5
--- /dev/null
+++ b/compiler/exo/src/TestHelper.h
@@ -0,0 +1,110 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __TEST_HELPER_H__
+#define __TEST_HELPER_H__
+
+#include "Check.h"
+#include "ProgressReporter.h"
+#include "Passes.h"
+
+#include <logo/Pass.h>
+#include <logo/Phase.h>
+
+#include <loco.h>
+
+#include <stdex/Memory.h>
+
+#include <gtest/gtest.h>
+
+/**
+ * @brief Check the number of nodes in a graph starting from OUTPUTS
+ */
+#define EXO_TEST_ASSERT_NODE_COUNT(OUTPUTS, COUNT)  \
+  {                                                 \
+    auto v = loco::postorder_traversal(OUTPUTS);    \
+    ASSERT_EQ(v.size(), (COUNT));                   \
+  }
+
+namespace exo
+{
+namespace test
+{
+
+/**
+ * @brief Phase for tests that exercise a pass. This phase initially adds TypeInferencePass
+ *        and ShapeInferencePass
+ */
+class TypeShapeReadyPhase
+{
+public:
+  TypeShapeReadyPhase()
+  {
+    // Type and shape inference are prerequisites for running the other passes under test
+    _phase.emplace_back(stdex::make_unique<::exo::TypeInferencePass>());
+    _phase.emplace_back(stdex::make_unique<::exo::ShapeInferencePass>());
+  }
+
+  template <typename PassT> void add_pass() { _phase.emplace_back(stdex::make_unique<PassT>()); }
+
+  void run(loco::Graph *g)
+  {
+    const auto restart = logo::PhaseStrategy::Restart;
+    logo::PhaseRunner<restart> phase_runner{g};
+
+    ::exo::ProgressReporter prog(g, restart);
+    phase_runner.attach(&prog);
+    phase_runner.run(_phase);
+  }
+
+private:
+  logo::Phase _phase;
+};
+
+/**
+ * @brief Get the only succ object of type LocoNodeT. (The name `only succ` comes from the English
+ *        term `only child`.)
+ *        parent must have exactly one succ.
+ *        When that succ is not of type LocoNodeT, nullptr will be returned.
+ */
+template <typename LocoNodeT> inline LocoNodeT *get_only_succ(loco::Node *parent)
+{
+  auto succs = loco::succs(parent);
+  EXO_ASSERT(succs.size() == 1, "parent must have exactly one succ.");
+
+  return dynamic_cast<LocoNodeT *>(*succs.begin());
+}
+
+template <typename T> inline T *find_first_node_bytype(loco::Graph *g)
+{
+  T *first_node = nullptr;
+  loco::Graph::NodeContext *nodes = g->nodes();
+  uint32_t count = nodes->size();
+
+  for (uint32_t i = 0; i < count; ++i)
+  {
+    first_node = dynamic_cast<T *>(nodes->at(i));
+    if (first_node != nullptr)
+      break;
+  }
+
+  return first_node;
+}
+
+} // namespace test
+} // namespace exo
+
+#endif // __TEST_HELPER_H__ |
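The two headers above are only consumed by the unit tests elsewhere in compiler/exo/src. A minimal sketch of how they are typically combined, assuming a graph built with TestGraph and only the default inference passes; the test name and the commented-out PassUnderTest are hypothetical placeholders, while every other identifier is declared in TestGraph.h, TestHelper.h, or loco:

#include "TestGraph.h"
#include "TestHelper.h"

#include <loco.h>

#include <gtest/gtest.h>

TEST(TestHelperUsageSketch, relu_graph)
{
  exo::test::TestGraph g;

  // Give the Pull node a dtype and a shape so the inference passes have something to propagate
  g.pull->dtype(loco::DataType::FLOAT32);
  g.pull->rank(2);
  g.pull->dim(0).set(2);
  g.pull->dim(1).set(3);

  // Append a ReLU after Pull and wire it to Push
  auto relu = g.append<loco::ReLU>(g.pull);
  g.complete(relu);

  // Runs TypeInferencePass and ShapeInferencePass; a pass under test would be
  // registered via add_pass<PassT>() (hypothetical name below)
  exo::test::TypeShapeReadyPhase phase;
  // phase.add_pass<PassUnderTest>();
  phase.run(g.graph());

  // Helpers from TestHelper.h to inspect the resulting graph
  ASSERT_EQ(exo::test::find_first_node_bytype<loco::ReLU>(g.graph()), relu);
  ASSERT_EQ(exo::test::get_only_succ<loco::Push>(relu), g.push);
  EXO_TEST_ASSERT_NODE_COUNT(loco::output_nodes(g.graph()), 3); // pull, relu, push
}

The real tests in this directory follow the same shape: build a small graph, run TypeShapeReadyPhase plus the pass under test, then assert on the resulting node structure.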