path: root/compiler/exo/src
author    Chunseok Lee <chunseok.lee@samsung.com>  2020-04-23 14:45:49 +0900
committer Chunseok Lee <chunseok.lee@samsung.com>  2020-04-23 14:45:49 +0900
commit e2ef8438a24f7c56a0744eb579a6e293ee2fbf8e (patch)
tree   44a1a7951d168dd4370e13593ed03f4bc6d920c5 /compiler/exo/src
parent 302e6564a7a76109e1178207e44e45a58631c477 (diff)
Imported Upstream version 1.4.0 (tags: upstream/1.4.0, submit/tizen/20200423.054851)
Diffstat (limited to 'compiler/exo/src')
-rw-r--r--  compiler/exo/src/Check.h | 37
-rw-r--r--  compiler/exo/src/Circle/CircleExporter.cpp | 49
-rw-r--r--  compiler/exo/src/Circle/CircleExporterImpl.cpp | 181
-rw-r--r--  compiler/exo/src/Circle/CircleExporterImpl.h | 78
-rw-r--r--  compiler/exo/src/Circle/CircleExporterUtils.cpp | 163
-rw-r--r--  compiler/exo/src/Circle/CircleExporterUtils.h | 120
-rw-r--r--  compiler/exo/src/Circle/CircleOperationExporter.cpp | 1228
-rw-r--r--  compiler/exo/src/Circle/CircleOperationExporter.h | 39
-rw-r--r--  compiler/exo/src/Circle/CircleTensorExporter.cpp | 261
-rw-r--r--  compiler/exo/src/Circle/CircleTensorExporter.h | 42
-rw-r--r--  compiler/exo/src/Circle/CircleTypeInference.cpp | 85
-rw-r--r--  compiler/exo/src/Circle/CircleTypeInference.h | 45
-rw-r--r--  compiler/exo/src/Conversion/AvgPool2DConverter.cpp | 79
-rw-r--r--  compiler/exo/src/Conversion/AvgPool2DConverter.h | 41
-rw-r--r--  compiler/exo/src/Conversion/CanonicalNodeConverter.cpp | 19
-rw-r--r--  compiler/exo/src/Conversion/CanonicalNodeConverter.h | 71
-rw-r--r--  compiler/exo/src/Conversion/ConstGenConverter.cpp | 60
-rw-r--r--  compiler/exo/src/Conversion/ConstGenConverter.h | 38
-rw-r--r--  compiler/exo/src/Conversion/ConstGenConverter.test.cpp | 65
-rw-r--r--  compiler/exo/src/Conversion/Conv2DConverter.cpp | 97
-rw-r--r--  compiler/exo/src/Conversion/Conv2DConverter.h | 41
-rw-r--r--  compiler/exo/src/Conversion/DepthwiseConv2DConverter.cpp | 114
-rw-r--r--  compiler/exo/src/Conversion/DepthwiseConv2DConverter.h | 61
-rw-r--r--  compiler/exo/src/Conversion/EltwiseAddConverter.cpp | 29
-rw-r--r--  compiler/exo/src/Conversion/EltwiseAddConverter.h | 41
-rw-r--r--  compiler/exo/src/Conversion/EltwiseBinaryConverter.h | 110
-rw-r--r--  compiler/exo/src/Conversion/EltwiseDivConverter.cpp | 29
-rw-r--r--  compiler/exo/src/Conversion/EltwiseDivConverter.h | 41
-rw-r--r--  compiler/exo/src/Conversion/EltwiseMaxConverter.cpp | 75
-rw-r--r--  compiler/exo/src/Conversion/EltwiseMaxConverter.h | 41
-rw-r--r--  compiler/exo/src/Conversion/EltwiseMulConverter.cpp | 29
-rw-r--r--  compiler/exo/src/Conversion/EltwiseMulConverter.h | 41
-rw-r--r--  compiler/exo/src/Conversion/EltwiseSqrtConverter.cpp | 68
-rw-r--r--  compiler/exo/src/Conversion/EltwiseSqrtConverter.h | 41
-rw-r--r--  compiler/exo/src/Conversion/EltwiseSubConverter.cpp | 29
-rw-r--r--  compiler/exo/src/Conversion/EltwiseSubConverter.h | 41
-rw-r--r--  compiler/exo/src/Conversion/FeatureBiasAddConverter.cpp | 91
-rw-r--r--  compiler/exo/src/Conversion/FeatureBiasAddConverter.h | 38
-rw-r--r--  compiler/exo/src/Conversion/FeatureBiasAddConverter.test.cpp | 102
-rw-r--r--  compiler/exo/src/Conversion/MatMulConverter.cpp | 103
-rw-r--r--  compiler/exo/src/Conversion/MatMulConverter.h | 41
-rw-r--r--  compiler/exo/src/Conversion/MaxPool2DConverter.cpp | 67
-rw-r--r--  compiler/exo/src/Conversion/MaxPool2DConverter.h | 41
-rw-r--r--  compiler/exo/src/Conversion/Relu6Converter.cpp | 68
-rw-r--r--  compiler/exo/src/Conversion/Relu6Converter.h | 41
-rw-r--r--  compiler/exo/src/Conversion/ReluConverter.cpp | 68
-rw-r--r--  compiler/exo/src/Conversion/ReluConverter.h | 41
-rw-r--r--  compiler/exo/src/Conversion/ReluConverter.test.cpp | 97
-rw-r--r--  compiler/exo/src/Conversion/TensorBroadcastConverter.cpp | 189
-rw-r--r--  compiler/exo/src/Conversion/TensorBroadcastConverter.h | 40
-rw-r--r--  compiler/exo/src/Conversion/TensorConcatConverter.cpp | 66
-rw-r--r--  compiler/exo/src/Conversion/TensorConcatConverter.h | 41
-rw-r--r--  compiler/exo/src/Conversion/TensorReduceConverter.cpp | 95
-rw-r--r--  compiler/exo/src/Conversion/TensorReduceConverter.h | 46
-rw-r--r--  compiler/exo/src/Conversion/TensorTransposeConverter.cpp | 102
-rw-r--r--  compiler/exo/src/Conversion/TensorTransposeConverter.h | 41
-rw-r--r--  compiler/exo/src/Conversion/TransposedConv2DConverter.cpp | 92
-rw-r--r--  compiler/exo/src/Conversion/TransposedConv2DConverter.h | 62
-rw-r--r--  compiler/exo/src/Conversions.h | 46
-rw-r--r--  compiler/exo/src/Convert.cpp | 97
-rw-r--r--  compiler/exo/src/Convert.h | 29
-rw-r--r--  compiler/exo/src/Dialect/IR/CircleDialect.cpp | 28
-rw-r--r--  compiler/exo/src/Dialect/IR/CircleDialect.h | 40
-rw-r--r--  compiler/exo/src/Dialect/IR/CircleDialect.test.cpp | 31
-rw-r--r--  compiler/exo/src/Dialect/IR/CircleNode.cpp | 26
-rw-r--r--  compiler/exo/src/Dialect/IR/CircleNode.h | 23
-rw-r--r--  compiler/exo/src/Dialect/IR/CircleNodeDecl.h | 50
-rw-r--r--  compiler/exo/src/Dialect/IR/CircleNodeImpl.h | 70
-rw-r--r--  compiler/exo/src/Dialect/IR/CircleNodeVisitor.forward.h | 30
-rw-r--r--  compiler/exo/src/Dialect/IR/CircleNodeVisitor.h | 86
-rw-r--r--  compiler/exo/src/Dialect/IR/CircleNodes.cpp | 18
-rw-r--r--  compiler/exo/src/Dialect/IR/CircleNodes.h | 79
-rw-r--r--  compiler/exo/src/Dialect/IR/CircleNodes.lst | 8
-rw-r--r--  compiler/exo/src/Dialect/IR/CircleNodes.test.cpp | 36
-rw-r--r--  compiler/exo/src/Dialect/IR/CircleOpcode.h | 32
-rw-r--r--  compiler/exo/src/Dialect/IR/FusedActFunc.h | 35
-rw-r--r--  compiler/exo/src/Dialect/IR/NodeMixins.cpp | 18
-rw-r--r--  compiler/exo/src/Dialect/IR/NodeMixins.h | 66
-rw-r--r--  compiler/exo/src/Dialect/IR/TFLDialect.cpp | 28
-rw-r--r--  compiler/exo/src/Dialect/IR/TFLDialect.h | 40
-rw-r--r--  compiler/exo/src/Dialect/IR/TFLDialect.test.cpp | 31
-rw-r--r--  compiler/exo/src/Dialect/IR/TFLNode.cpp | 26
-rw-r--r--  compiler/exo/src/Dialect/IR/TFLNode.h | 23
-rw-r--r--  compiler/exo/src/Dialect/IR/TFLNodeDecl.h | 50
-rw-r--r--  compiler/exo/src/Dialect/IR/TFLNodeImpl.h | 70
-rw-r--r--  compiler/exo/src/Dialect/IR/TFLNodeVisitor.forward.h | 30
-rw-r--r--  compiler/exo/src/Dialect/IR/TFLNodeVisitor.h | 86
-rw-r--r--  compiler/exo/src/Dialect/IR/TFLNodes.cpp | 91
-rw-r--r--  compiler/exo/src/Dialect/IR/TFLNodes.h | 551
-rw-r--r--  compiler/exo/src/Dialect/IR/TFLNodes.lst | 30
-rw-r--r--  compiler/exo/src/Dialect/IR/TFLNodes.test.cpp | 159
-rw-r--r--  compiler/exo/src/Dialect/IR/TFLOpcode.h | 32
-rw-r--r--  compiler/exo/src/Dialect/Service/CircleShapeInferenceRule.cpp | 67
-rw-r--r--  compiler/exo/src/Dialect/Service/CircleShapeInferenceRule.h | 33
-rw-r--r--  compiler/exo/src/Dialect/Service/CircleTypeInferenceRule.cpp | 58
-rw-r--r--  compiler/exo/src/Dialect/Service/CircleTypeInferenceRule.h | 36
-rw-r--r--  compiler/exo/src/Dialect/Service/TFLShapeInferenceRule.cpp | 627
-rw-r--r--  compiler/exo/src/Dialect/Service/TFLShapeInferenceRule.h | 33
-rw-r--r--  compiler/exo/src/Dialect/Service/TFLShapeInferenceRule.test.cpp | 277
-rw-r--r--  compiler/exo/src/Dialect/Service/TFLTypeInferenceRule.cpp | 141
-rw-r--r--  compiler/exo/src/Dialect/Service/TFLTypeInferenceRule.h | 37
-rw-r--r--  compiler/exo/src/Dialect/Service/TFLTypeInferenceRule.test.cpp | 57
-rw-r--r--  compiler/exo/src/ExoFormattedGraph.cpp | 525
-rw-r--r--  compiler/exo/src/ExoFormattedGraph.h | 56
-rw-r--r--  compiler/exo/src/ExoOptimize.cpp | 74
-rw-r--r--  compiler/exo/src/ExoOptimize.h | 34
-rw-r--r--  compiler/exo/src/ExporterUtils.cpp | 139
-rw-r--r--  compiler/exo/src/ExporterUtils.h | 57
-rw-r--r--  compiler/exo/src/GraphBlock.cpp | 243
-rw-r--r--  compiler/exo/src/GraphBlock.h | 199
-rw-r--r--  compiler/exo/src/Knob.cpp | 122
-rw-r--r--  compiler/exo/src/Knob.h | 51
-rw-r--r--  compiler/exo/src/Knob.lst | 11
-rw-r--r--  compiler/exo/src/Log.cpp | 84
-rw-r--r--  compiler/exo/src/Log.h | 75
-rw-r--r--  compiler/exo/src/LogHelper.cpp | 79
-rw-r--r--  compiler/exo/src/LogHelper.h | 70
-rw-r--r--  compiler/exo/src/LoggingContext.cpp | 40
-rw-r--r--  compiler/exo/src/Pass/FoldReshapeOfConstPass.cpp | 116
-rw-r--r--  compiler/exo/src/Pass/FoldReshapeOfConstPass.h | 46
-rw-r--r--  compiler/exo/src/Pass/FoldTransposeOfConstPass.cpp | 154
-rw-r--r--  compiler/exo/src/Pass/FoldTransposeOfConstPass.h | 46
-rw-r--r--  compiler/exo/src/Pass/FuseBiasAddPass.cpp | 362
-rw-r--r--  compiler/exo/src/Pass/FuseBiasAddPass.h | 61
-rw-r--r--  compiler/exo/src/Pass/FuseBiasAddPass.test.cpp | 361
-rw-r--r--  compiler/exo/src/Pass/FuseInstanceNormPass.cpp | 402
-rw-r--r--  compiler/exo/src/Pass/FuseInstanceNormPass.h | 40
-rw-r--r--  compiler/exo/src/Pass/FuseReluPass.cpp | 115
-rw-r--r--  compiler/exo/src/Pass/FuseReluPass.h | 40
-rw-r--r--  compiler/exo/src/Pass/FuseReluPass.test.cpp | 115
-rw-r--r--  compiler/exo/src/Pass/FuseRsqrtPass.cpp | 95
-rw-r--r--  compiler/exo/src/Pass/FuseRsqrtPass.h | 47
-rw-r--r--  compiler/exo/src/Pass/FuseSquaredDifferencePass.cpp | 86
-rw-r--r--  compiler/exo/src/Pass/FuseSquaredDifferencePass.h | 49
-rw-r--r--  compiler/exo/src/Pass/MergeConcatNodesPass.cpp | 191
-rw-r--r--  compiler/exo/src/Pass/MergeConcatNodesPass.h | 41
-rw-r--r--  compiler/exo/src/Pass/ShapeInferencePass.cpp | 59
-rw-r--r--  compiler/exo/src/Pass/ShapeInferencePass.h | 40
-rw-r--r--  compiler/exo/src/Pass/TypeInferencePass.cpp | 57
-rw-r--r--  compiler/exo/src/Pass/TypeInferencePass.h | 42
-rw-r--r--  compiler/exo/src/Passes.cpp | 19
-rw-r--r--  compiler/exo/src/Passes.h | 38
-rw-r--r--  compiler/exo/src/ProgressReporter.cpp | 84
-rw-r--r--  compiler/exo/src/ProgressReporter.h | 53
-rw-r--r--  compiler/exo/src/ShapeInference.cpp | 44
-rw-r--r--  compiler/exo/src/ShapeInference.h | 41
-rw-r--r--  compiler/exo/src/TFLite/TFLExporter.cpp | 49
-rw-r--r--  compiler/exo/src/TFLite/TFLExporterImpl.cpp | 179
-rw-r--r--  compiler/exo/src/TFLite/TFLExporterImpl.h | 78
-rw-r--r--  compiler/exo/src/TFLite/TFLExporterImpl.test.cpp | 413
-rw-r--r--  compiler/exo/src/TFLite/TFLExporterUtils.cpp | 160
-rw-r--r--  compiler/exo/src/TFLite/TFLExporterUtils.h | 118
-rw-r--r--  compiler/exo/src/TFLite/TFLExporterUtils.test.cpp | 108
-rw-r--r--  compiler/exo/src/TFLite/TFLOperationExporter.cpp | 1199
-rw-r--r--  compiler/exo/src/TFLite/TFLOperationExporter.h | 39
-rw-r--r--  compiler/exo/src/TFLite/TFLTensorExporter.cpp | 249
-rw-r--r--  compiler/exo/src/TFLite/TFLTensorExporter.h | 42
-rw-r--r--  compiler/exo/src/TFLite/TFLTypeInference.cpp | 82
-rw-r--r--  compiler/exo/src/TFLite/TFLTypeInference.h | 42
-rw-r--r--  compiler/exo/src/TFLite/TFLTypeInference.test.cpp | 118
-rw-r--r--  compiler/exo/src/TestGraph.h | 315
-rw-r--r--  compiler/exo/src/TestHelper.h | 110
162 files changed, 16666 insertions(+), 0 deletions(-)
diff --git a/compiler/exo/src/Check.h b/compiler/exo/src/Check.h
new file mode 100644
index 000000000..79dac50dd
--- /dev/null
+++ b/compiler/exo/src/Check.h
@@ -0,0 +1,37 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __CHECK_H__
+#define __CHECK_H__
+
+#include <pepper/str.h>
+
+#include <stdexcept>
+#include <cassert>
+#include <iostream>
+
+// TODO Add macro for Release version
+
+#define EXO_ASSERT(condition, msg) \
+ { \
+ if (!(condition)) \
+ { \
+ std::cerr << "[assert failed] " << (msg) << ". " << std::endl; \
+ assert((condition)); \
+ } \
+ }
+
+#endif // __CHECK_H__
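A minimal usage sketch of EXO_ASSERT (the caller and its arguments are hypothetical; pepper::str is the variadic string-building helper pulled in by the include above):

    #include "Check.h"

    #include <cstdint>

    // Hypothetical precondition check: on failure, the macro prints the
    // message to stderr and then triggers assert() (active in debug builds).
    void check_rank(uint32_t rank, uint32_t expected)
    {
      EXO_ASSERT(rank == expected,
                 pepper::str("rank mismatch: got ", rank, ", expected ", expected));
    }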
diff --git a/compiler/exo/src/Circle/CircleExporter.cpp b/compiler/exo/src/Circle/CircleExporter.cpp
new file mode 100644
index 000000000..797749090
--- /dev/null
+++ b/compiler/exo/src/Circle/CircleExporter.cpp
@@ -0,0 +1,49 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "exo/CircleExporter.h"
+
+#include "CircleExporterImpl.h"
+
+#include <stdex/Memory.h>
+
+#include <oops/InternalExn.h>
+
+#include <fstream>
+
+namespace exo
+{
+
+CircleExporter::CircleExporter(loco::Graph *graph) : _impl(stdex::make_unique<Impl>(graph))
+{
+ // NOTHING TO DO
+}
+
+CircleExporter::~CircleExporter() = default;
+
+void CircleExporter::dumpToFile(const char *path) const
+{
+ const char *ptr = _impl->getBufferPointer();
+ const size_t size = _impl->getBufferSize();
+
+ if (!ptr)
+ INTERNAL_EXN("Graph was not serialized by FlatBuffer for some reason");
+
+ std::ofstream file(path, std::ofstream::binary);
+ file.write(ptr, size);
+}
+
+} // namespace exo
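A minimal sketch of the public API this file implements (the driver function is hypothetical; both calls below are defined in this diff):

    #include "exo/CircleExporter.h"

    #include <loco.h>

    // Hypothetical driver: serialize a loco::Graph and write it to disk.
    void export_circle(loco::Graph *graph, const char *path)
    {
      exo::CircleExporter exporter{graph}; // serialization happens in the Impl ctor
      exporter.dumpToFile(path);           // writes the FlatBuffer binary to 'path'
    }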
diff --git a/compiler/exo/src/Circle/CircleExporterImpl.cpp b/compiler/exo/src/Circle/CircleExporterImpl.cpp
new file mode 100644
index 000000000..4cba33da1
--- /dev/null
+++ b/compiler/exo/src/Circle/CircleExporterImpl.cpp
@@ -0,0 +1,181 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleExporterImpl.h"
+
+#include "Convert.h"
+#include "ExoOptimize.h"
+
+#include "CircleTensorExporter.h"
+#include "CircleOperationExporter.h"
+#include "CircleExporterUtils.h"
+
+#include "Log.h"
+#include "Knob.h"
+
+#include <oops/InternalExn.h>
+
+#include <cassert>
+#include <unordered_map>
+#include <string>
+#include <stdexcept>
+
+namespace
+{
+
+using namespace exo::circle_detail;
+
+void registerGraphInputTensors(loco::Graph *graph, SubGraphContext &ctx)
+{
+ for (uint32_t n = 0; n < graph->inputs()->size(); ++n)
+ {
+ auto node = loco::pull_node(graph, n);
+ assert(node != nullptr);
+ ctx._inputs.push_back(get_tensor_index(node));
+ }
+}
+
+void registerGraphOutputTensors(loco::Graph *graph, SubGraphContext &ctx)
+{
+ for (uint32_t n = 0; n < graph->outputs()->size(); ++n)
+ {
+ auto push = loco::push_node(graph, n);
+ assert(push != nullptr);
+ auto node = push->from();
+ assert(node != nullptr);
+ ctx._outputs.push_back(get_tensor_index(node));
+ }
+}
+
+} // namespace
+
+namespace
+{
+
+using namespace circle;
+using namespace flatbuffers;
+
+Offset<Vector<Offset<OperatorCode>>>
+encodeOperatorCodes(FlatBufferBuilder &builder, std::unordered_map<OpCode, uint32_t> &opcodes,
+ std::unordered_map<OpCode, std::string> &custom_opcodes)
+{
+ std::vector<Offset<OperatorCode>> operator_codes_vec(opcodes.size());
+ for (auto it : opcodes)
+ {
+ uint32_t idx = it.second;
+ if (it.first.opcode != BuiltinOperator_CUSTOM)
+ {
+ operator_codes_vec[idx] = CreateOperatorCode(builder, it.first.opcode);
+ }
+ else // custom op
+ {
+ auto opCode = it.first;
+ auto custom_code = custom_opcodes.find(opCode);
+ if (custom_code == custom_opcodes.end())
+ INTERNAL_EXN("Cannot find code for customop even though opcode is BuiltinOperator_CUSTOM");
+
+ operator_codes_vec[idx] =
+ CreateOperatorCode(builder, it.first.opcode, builder.CreateString(custom_code->second));
+ }
+ }
+ return builder.CreateVector(operator_codes_vec);
+}
+
+} // namespace
+
+namespace exo
+{
+
+using namespace exo::circle_detail;
+using namespace circle;
+using namespace flatbuffers;
+
+CircleExporter::Impl::Impl(loco::Graph *graph) { exportGraph(graph); }
+
+::flatbuffers::Offset<::circle::SubGraph>
+CircleExporter::Impl::exportSubgraph(SerializedModelData &gd)
+{
+ auto tensors = _builder.CreateVector(gd._tensors);
+ auto inputs = _builder.CreateVector(gd._inputs);
+ auto outputs = _builder.CreateVector(gd._outputs);
+ auto operators = _builder.CreateVector(gd._operators);
+ auto df = gd._data_format;
+ auto subgraph = CreateSubGraph(_builder, tensors, inputs, outputs, operators, df);
+ return subgraph;
+}
+
+void CircleExporter::Impl::exportGraph(loco::Graph *graph)
+{
+ LOGGER(l);
+
+ // IR-level conversion and optimization
+ {
+ convert_to_TFLNodes(graph);
+ set(Dialect::CIRCLE);
+ optimize(graph);
+ }
+
+ _builder.Clear();
+
+ SerializedModelData gd;
+
+  // This version is taken from the comment in the circle schema (.fbs) file
+ constexpr uint32_t version = 0;
+
+ registerGraphIOName(graph, gd);
+
+ // parse graph into SerializedModelData structure
+ exportOpDefinedTensors(graph, _builder, gd);
+
+ // NOTE Invoke these register functions only after each node is annotated with its tensor_index
+ registerGraphInputTensors(graph, gd);
+ registerGraphOutputTensors(graph, gd);
+
+ exportNodes(graph, _builder, gd);
+
+ // encode operator codes
+ auto operator_codes =
+ encodeOperatorCodes(_builder, gd._operator_codes, gd._custom_operator_codes);
+
+ // Subgraphs
+ Offset<SubGraph> subgraph = exportSubgraph(gd);
+ auto subgraphs = _builder.CreateVector(std::vector<Offset<SubGraph>>{subgraph});
+
+ // Description
+ std::string description_str = "nnpackage";
+ auto description = _builder.CreateString(description_str);
+
+ // create array of buffers
+ auto buffers = _builder.CreateVector(gd._buffers);
+
+ // empty metadata
+ std::vector<int> metadata_buffer_vec;
+ auto metadata_buffer = _builder.CreateVector(metadata_buffer_vec);
+
+ // Model
+ auto model_offset = CreateModel(_builder, version, operator_codes, subgraphs, description,
+ buffers, metadata_buffer);
+ FinishModelBuffer(_builder, model_offset);
+}
+
+const char *CircleExporter::Impl::getBufferPointer() const
+{
+ return reinterpret_cast<const char *>(_builder.GetBufferPointer());
+}
+
+size_t CircleExporter::Impl::getBufferSize() const { return _builder.GetSize(); }
+
+} // namespace exo
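For reference, a minimal read-back sketch of the buffer produced above; circle::GetModel is the root-table accessor FlatBuffers generates for the schema's root type (treat the exact generated names as assumptions), and the check values follow the constants used in exportGraph:

    #include "circle_schema_generated.h"

    #include <cassert>

    // Hypothetical sanity check; 'buffer' is the pointer returned by
    // CircleExporter::Impl::getBufferPointer().
    void verify_export(const char *buffer)
    {
      const circle::Model *model = circle::GetModel(buffer);
      assert(model->version() == 0);           // matches the 'version' constant above
      assert(model->subgraphs()->size() == 1); // exportGraph emits a single subgraph
    }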
diff --git a/compiler/exo/src/Circle/CircleExporterImpl.h b/compiler/exo/src/Circle/CircleExporterImpl.h
new file mode 100644
index 000000000..b1138fbad
--- /dev/null
+++ b/compiler/exo/src/Circle/CircleExporterImpl.h
@@ -0,0 +1,78 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __CIRCLE_EXPORTER_IMPL_H__
+#define __CIRCLE_EXPORTER_IMPL_H__
+
+#include "exo/CircleExporter.h"
+#include "circle_schema_generated.h"
+
+#include <loco.h>
+
+namespace exo
+{
+
+namespace circle_detail
+{
+
+struct SerializedModelData;
+
+} // namespace circle_detail
+
+using namespace circle_detail;
+
+/**
+ * @brief Internal implementation of the CircleExporter interface class
+ */
+class CircleExporter::Impl
+{
+public:
+ Impl() = delete;
+ ~Impl() = default;
+
+ explicit Impl(loco::Graph *graph);
+
+ /**
+ * @return pointer to buffer with serialized graph
+ */
+ const char *getBufferPointer() const;
+
+ /**
+ * @return size of buffer with serialized graph
+ */
+ size_t getBufferSize() const;
+
+private:
+ /**
+   * @brief Create a SubGraph from the data stored in SerializedModelData
+   * @param gd information about the serialized parts of the model
+   * @return offset in the buffer corresponding to the serialized subgraph
+ */
+ flatbuffers::Offset<circle::SubGraph> exportSubgraph(SerializedModelData &gd);
+
+ /**
+   * @brief Root function that writes the graph into the internal buffer
+   * @param graph graph to serialize
+ */
+ void exportGraph(loco::Graph *graph);
+
+private:
+ flatbuffers::FlatBufferBuilder _builder;
+};
+
+} // namespace exo
+
+#endif // __CIRCLE_EXPORTER_IMPL_H__
diff --git a/compiler/exo/src/Circle/CircleExporterUtils.cpp b/compiler/exo/src/Circle/CircleExporterUtils.cpp
new file mode 100644
index 000000000..12b204ce7
--- /dev/null
+++ b/compiler/exo/src/Circle/CircleExporterUtils.cpp
@@ -0,0 +1,163 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleExporterUtils.h"
+
+#include <oops/InternalExn.h>
+
+namespace exo
+{
+
+circle::ActivationFunctionType to_circle_actfunc(locoex::FusedActFunc func)
+{
+ switch (func)
+ {
+ case locoex::FusedActFunc::NONE:
+ return circle::ActivationFunctionType_NONE;
+ case locoex::FusedActFunc::RELU:
+ return circle::ActivationFunctionType_RELU;
+ case locoex::FusedActFunc::RELU6:
+ return circle::ActivationFunctionType_RELU6;
+ default:
+ INTERNAL_EXN_V("trying to convert unsupported locoex::FusedActFunc", oops::to_uint32(func));
+ }
+}
+
+} // namespace exo
+
+namespace exo
+{
+namespace circle_detail
+{
+
+uint32_t SerializedModelData::registerBuiltinOpcode(circle::BuiltinOperator builtin_code)
+{
+ auto it = _operator_codes.find(OpCode{builtin_code});
+ if (it != _operator_codes.end())
+ {
+ return it->second;
+ }
+ auto idx = static_cast<uint32_t>(_operator_codes.size());
+ _operator_codes.emplace(OpCode{builtin_code}, idx);
+ return idx;
+}
+
+uint32_t SerializedModelData::registerCustomOpcode(const std::string &custom_op)
+{
+ circle::BuiltinOperator custom_code = circle::BuiltinOperator_CUSTOM;
+ auto idx = registerBuiltinOpcode(custom_code);
+ _custom_operator_codes.emplace(OpCode{custom_code}, custom_op);
+ return idx;
+}
+
+circle::Padding getOpPadding(const loco::Padding2D *pad, const loco::Stride<2> *stride,
+ const ShapeDescription &ifm, const ShapeDescription &ofm)
+{
+ // VALID padding
+ if (pad->top() == 0 && pad->bottom() == 0 && pad->left() == 0 && pad->right() == 0)
+ return circle::Padding_VALID;
+
+ // SAME padding
+ //
+ // For same padding, by definition, following equation should hold:
+ // O = floor((I - 1) / S) + 1
+ // where input size I, output size O, stride S
+ //
+ // NOTE input and output 'feature' map are shape of NHWC
+ bool same_padding_criterion_1 =
+ (static_cast<uint32_t>(ofm._dims[1]) == (ifm._dims[1] - 1) / stride->vertical() + 1) &&
+ (static_cast<uint32_t>(ofm._dims[2]) == (ifm._dims[2] - 1) / stride->horizontal() + 1);
+
+  // For SAME padding, the rear padding equals the front padding or exceeds it by at most 1
+ bool same_padding_criterion_2 =
+ (pad->top() <= pad->bottom()) && (pad->bottom() <= pad->top() + 1) &&
+ (pad->left() <= pad->right()) && (pad->right() <= pad->left() + 1);
+
+ if (same_padding_criterion_1 && same_padding_criterion_2)
+ return circle::Padding_SAME;
+
+ INTERNAL_EXN("Unsupported padding criteria");
+}
+
+circle::Padding getOpPadding(const locoex::Padding pad)
+{
+ if (pad == locoex::Padding::VALID)
+ return circle::Padding_VALID;
+ if (pad == locoex::Padding::SAME)
+ return circle::Padding_SAME;
+
+ INTERNAL_EXN_V("Unsupported locoex::Padding", oops::to_uint32(pad));
+}
+
+void registerGraphIOName(loco::Graph *graph, SerializedModelData &gd)
+{
+ for (uint32_t in = 0; in < graph->inputs()->size(); ++in)
+ {
+ auto pull = loco::pull_node(graph, in);
+ auto name = graph->inputs()->at(in)->name();
+
+ gd._pull_to_name[pull] = name;
+ }
+ for (uint32_t out = 0; out < graph->outputs()->size(); ++out)
+ {
+ auto push = loco::push_node(graph, out);
+ auto name = graph->outputs()->at(out)->name();
+
+ gd._push_to_name[push] = name;
+ }
+
+ // TODO set this value properly
+ gd._data_format = circle::DataFormat::DataFormat_CHANNELS_LAST;
+}
+
+#include <stdex/Memory.h>
+
+#include <cassert>
+
+namespace
+{
+
+class TFLTensorIndexAnnotation final : public loco::NodeAnnotation
+{
+public:
+ TFLTensorIndexAnnotation(const TFLTensorIndex &index) : _index{index}
+ {
+ // DO NOTHING
+ }
+
+public:
+ const TFLTensorIndex &index(void) const { return _index; }
+
+private:
+ TFLTensorIndex _index;
+};
+
+} // namespace
+
+void set_tensor_index(loco::Node *node, const TFLTensorIndex &tensor_id)
+{
+ assert(node->annot<TFLTensorIndexAnnotation>() == nullptr);
+ node->annot(stdex::make_unique<TFLTensorIndexAnnotation>(tensor_id));
+}
+
+TFLTensorIndex get_tensor_index(loco::Node *node)
+{
+ assert(node->annot<TFLTensorIndexAnnotation>() != nullptr);
+ return node->annot<TFLTensorIndexAnnotation>()->index();
+}
+
+} // namespace circle_detail
+} // namespace exo
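As a worked instance of the SAME-padding criterion in getOpPadding above: with input extent I = 5 and stride S = 2, O = floor((5 - 1) / 2) + 1 = 3, so an output extent of 3 satisfies criterion 1. A one-axis sketch (the helper name is hypothetical):

    #include <cstdint>

    // Mirrors same_padding_criterion_1 for a single spatial axis.
    bool same_extent(uint32_t I, uint32_t O, uint32_t S)
    {
      return O == (I - 1) / S + 1; // O = floor((I - 1) / S) + 1
    }
    // same_extent(5, 3, 2) -> true; same_extent(5, 2, 2) -> false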
diff --git a/compiler/exo/src/Circle/CircleExporterUtils.h b/compiler/exo/src/Circle/CircleExporterUtils.h
new file mode 100644
index 000000000..fdd162bae
--- /dev/null
+++ b/compiler/exo/src/Circle/CircleExporterUtils.h
@@ -0,0 +1,120 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __CIRCLE_EXPORTER_UTILS_H__
+#define __CIRCLE_EXPORTER_UTILS_H__
+
+#include "ExporterUtils.h"
+
+#include "circle_schema_generated.h"
+
+#include "Dialect/IR/TFLNodes.h"
+
+#include <loco.h>
+
+#include <unordered_map>
+
+namespace exo
+{
+namespace circle_detail
+{
+
+struct OpCode
+{
+ circle::BuiltinOperator opcode;
+
+ bool operator==(const OpCode &rhs) const { return opcode == rhs.opcode; }
+};
+
+} // namespace circle_detail
+} // namespace exo
+
+namespace exo
+{
+
+circle::ActivationFunctionType to_circle_actfunc(locoex::FusedActFunc func);
+
+} // namespace exo
+
+namespace std
+{
+
+template <> struct hash<exo::circle_detail::OpCode>
+{
+ size_t operator()(const exo::circle_detail::OpCode &x) const { return hash<int>()(x.opcode); }
+};
+
+} // namespace std
+
+namespace exo
+{
+namespace circle_detail
+{
+
+/**
+ * @brief Record the information of a T/F Lite SubGraph and its mapping to loco
+ */
+struct SubGraphContext
+{
+ /// @brief SubGraph input tensor id
+ std::vector<int32_t> _inputs;
+ /// @brief SubGraph output tensor id
+ std::vector<int32_t> _outputs;
+  /// @brief DataFormat for SubGraph
+ circle::DataFormat _data_format{circle::DataFormat::DataFormat_CHANNELS_LAST};
+};
+
+// Prerequisites for circle::Model object creation
+struct SerializedModelData final : public SubGraphContext
+{
+ SerializedModelData() = default;
+ SerializedModelData(const SerializedModelData &) = delete;
+
+ std::unordered_map<OpCode, uint32_t> _operator_codes;
+ std::unordered_map<OpCode, std::string> _custom_operator_codes;
+ std::vector<flatbuffers::Offset<circle::Operator>> _operators;
+ std::vector<flatbuffers::Offset<circle::Tensor>> _tensors;
+ std::vector<flatbuffers::Offset<circle::Buffer>> _buffers;
+
+ // Graph input and output names
+ std::unordered_map<loco::Pull *, std::string> _pull_to_name;
+ std::unordered_map<loco::Push *, std::string> _push_to_name;
+
+ /**
+   * @brief Register the opcode in the opcode table if it is not present yet
+   * @param builtin_code builtin operator code to register
+   * @return index of the opcode in the opcode table (see schema)
+ */
+ uint32_t registerBuiltinOpcode(circle::BuiltinOperator builtin_code);
+ uint32_t registerCustomOpcode(const std::string &custom_op);
+};
+
+circle::Padding getOpPadding(const loco::Padding2D *pad, const loco::Stride<2> *stride,
+ const ShapeDescription &ifm, const ShapeDescription &ofm);
+circle::Padding getOpPadding(const locoex::Padding pad);
+
+/// @brief Register graph input and output names to SerializedModelData
+void registerGraphIOName(loco::Graph *graph, SerializedModelData &gd);
+
+using TFLTensorIndex = int32_t;
+
+void set_tensor_index(loco::Node *node, const TFLTensorIndex &tensor_id);
+TFLTensorIndex get_tensor_index(loco::Node *node);
+
+} // namespace circle_detail
+} // namespace exo
+
+#endif // __CIRCLE_EXPORTER_UTILS_H__
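A small sketch of the opcode-registration contract declared above (the demo function is hypothetical; deduplication relies on the std::hash<OpCode> specialization, as implemented in CircleExporterUtils.cpp):

    #include "Circle/CircleExporterUtils.h"

    #include <cassert>

    // Illustrative: registering the same builtin twice yields one table entry.
    void opcode_dedup_demo()
    {
      exo::circle_detail::SerializedModelData gd;
      auto add1 = gd.registerBuiltinOpcode(circle::BuiltinOperator_ADD); // 0
      auto mul = gd.registerBuiltinOpcode(circle::BuiltinOperator_MUL);  // 1
      auto add2 = gd.registerBuiltinOpcode(circle::BuiltinOperator_ADD); // 0 again
      assert(add1 == add2 && add1 != mul);
    }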
diff --git a/compiler/exo/src/Circle/CircleOperationExporter.cpp b/compiler/exo/src/Circle/CircleOperationExporter.cpp
new file mode 100644
index 000000000..390e2ec99
--- /dev/null
+++ b/compiler/exo/src/Circle/CircleOperationExporter.cpp
@@ -0,0 +1,1228 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleOperationExporter.h"
+#include "CircleExporterUtils.h"
+#include "ShapeInference.h"
+
+#include "Dialect/IR/TFLNode.h"
+#include "Dialect/IR/TFLNodes.h"
+#include "Dialect/IR/TFLNodeVisitor.h"
+
+#include "Dialect/IR/CircleNode.h"
+#include "Dialect/IR/CircleNodes.h"
+#include "Dialect/IR/CircleNodeVisitor.h"
+
+#include "Check.h"
+
+#include <loco/IR/CanonicalNode.h>
+#include <loco/IR/CanonicalNodeVisitor.h>
+#include <loco/Service/ShapeInference.h>
+#include <locoex/COpCall.h>
+
+#include <oops/InternalExn.h>
+
+#include <flatbuffers/flexbuffers.h>
+
+using namespace flatbuffers;
+using namespace circle;
+
+namespace
+{
+
+using namespace exo;
+using namespace exo::circle_detail;
+
+class OperationExporter final : public locoex::TFLNodeMutableVisitor<void>,
+ public locoex::CircleNodeMutableVisitor<void>,
+ public loco::CanonicalNodeMutableVisitor<void>
+{
+public:
+ OperationExporter(FlatBufferBuilder &fbb, SerializedModelData &ctx) : builder{fbb}, gd{ctx}
+ {
+ // DO NOTHING
+ }
+
+public:
+ // FOR TFLNodes
+ void visit(locoex::TFLAdd *) final;
+ void visit(locoex::TFLAveragePool2D *) final;
+ void visit(locoex::TFLConcatenation *) final;
+ void visit(locoex::TFLConst *) final{/* skip, everything is done in exportOpDefinedTensors */};
+ void visit(locoex::TFLConv2D *) final;
+ void visit(locoex::TFLDepthwiseConv2D *) final;
+ void visit(locoex::TFLDiv *) final;
+ void visit(locoex::TFLFullyConnected *) final;
+ void visit(locoex::TFLMaximum *) final;
+ void visit(locoex::TFLMaxPool2D *) final;
+ void visit(locoex::TFLMean *) final;
+ void visit(locoex::TFLMul *) final;
+ void visit(locoex::TFLRelu *) final;
+ void visit(locoex::TFLRelu6 *) final;
+ // TODO TFLReshape
+ void visit(locoex::TFLRsqrt *) final;
+ // TODO TFLSoftmax
+ void visit(locoex::TFLSqrt *) final;
+ void visit(locoex::TFLSquaredDifference *) final;
+ void visit(locoex::TFLSub *) final;
+ // TODO TFLTanh
+ void visit(locoex::TFLTranspose *) final;
+ void visit(locoex::TFLTransposeConv *) final;
+
+ // FOR CircleNodes
+ void visit(locoex::CircleInstanceNorm *) final;
+
+ // FOR canonical nodes. These will be removed later
+ void visit(loco::ReLU *) final;
+ void visit(loco::ReLU6 *) final;
+ void visit(loco::Tanh *) final;
+ void visit(loco::Push *) final { /* DO NOTHING */}
+ void visit(loco::Pull *) final { /* DO NOTHING */}
+ void visit(loco::FeatureEncode *) final;
+ void visit(loco::FeatureDecode *) final;
+ void visit(loco::FilterEncode *) final;
+ void visit(loco::DepthwiseFilterEncode *) final;
+ void visit(loco::ConstGen *) final { /* skip, everything is done in exportOpDefinedTensors */}
+ void visit(loco::MaxPool2D *) final;
+ void visit(loco::AvgPool2D *) final;
+ void visit(loco::Conv2D *) final;
+ void visit(loco::TransposedConv2D *) final;
+ void visit(loco::DepthwiseConv2D *) final;
+ void visit(loco::TensorConcat *) final;
+ void visit(loco::TensorReduce *) final;
+ void visit(loco::TensorSoftmax *) final;
+ void visit(loco::BiasEncode *) final;
+ void visit(loco::TensorBiasAdd *) final;
+ void visit(loco::FeatureBiasAdd *) final;
+ void visit(loco::EltwiseAdd *) final;
+ void visit(loco::EltwiseMax *) final;
+ void visit(loco::EltwiseMul *) final;
+ void visit(loco::EltwiseSub *) final;
+ void visit(loco::EltwiseDiv *) final;
+ void visit(loco::EltwiseSqrt *) final;
+ void visit(loco::FixedReshape *) final;
+ void visit(loco::TensorBroadcast *) final;
+ void visit(loco::TensorConstantPad *) final;
+
+ void visit(locoex::COpCall *);
+
+private:
+ /**
+ * @brief Exports TFLMaxPool2D or TFLAveragePool2D
+ *
+ * @note TFLPool2D should be one of TFLMaxPool2D or TFLAveragePool2D
+ */
+ template <class TFLPool2D>
+ void export_pool_2d(TFLPool2D *node, circle::BuiltinOperator builtin_op);
+
+private:
+ FlatBufferBuilder &builder;
+ SerializedModelData &gd;
+};
+
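+// NOTE Most visit() methods below follow the same pattern: register the
+// builtin opcode, collect input/output tensor indices via get_tensor_index(),
+// build the operator-specific options table, and append the resulting
+// Operator to gd._operators.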
+void OperationExporter::visit(locoex::TFLAdd *node)
+{
+ uint32_t op_idx = gd.registerBuiltinOpcode(circle::BuiltinOperator_ADD);
+ std::vector<int32_t> inputs_vec{get_tensor_index(node->x()), get_tensor_index(node->y())};
+ std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))};
+ auto inputs = builder.CreateVector(inputs_vec);
+ auto outputs = builder.CreateVector(outputs_vec);
+ auto options = CreateAddOptions(builder, to_circle_actfunc(node->fusedActivationFunction()));
+ auto op_offset = CreateOperator(builder, op_idx, inputs, outputs,
+ circle::BuiltinOptions_AddOptions, options.Union());
+ gd._operators.push_back(op_offset);
+}
+
+void OperationExporter::visit(locoex::TFLAveragePool2D *node)
+{
+ export_pool_2d<locoex::TFLAveragePool2D>(node, circle::BuiltinOperator_AVERAGE_POOL_2D);
+}
+
+void OperationExporter::visit(locoex::TFLConcatenation *node)
+{
+ uint32_t op_idx = gd.registerBuiltinOpcode(circle::BuiltinOperator_CONCATENATION);
+ std::vector<int32_t> inputs_vec;
+ std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))};
+
+ for (uint32_t i = 0; i < node->numValues(); ++i)
+ inputs_vec.push_back(get_tensor_index(node->values(i)));
+
+ auto inputs = builder.CreateVector(inputs_vec);
+ auto outputs = builder.CreateVector(outputs_vec);
+ auto options = CreateConcatenationOptions(builder, node->axis(),
+ to_circle_actfunc(node->fusedActivationFunction()));
+ auto op_offset = CreateOperator(builder, op_idx, inputs, outputs,
+ circle::BuiltinOptions_ConcatenationOptions, options.Union());
+ gd._operators.push_back(op_offset);
+}
+
+void OperationExporter::visit(locoex::TFLConv2D *node)
+{
+ uint32_t op_idx = gd.registerBuiltinOpcode(circle::BuiltinOperator_CONV_2D);
+
+ // Make input, output and options for operator
+ std::vector<int32_t> inputs_vec{get_tensor_index(node->input()), get_tensor_index(node->filter()),
+ get_tensor_index(node->bias())};
+ std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))};
+ auto inputs = builder.CreateVector(inputs_vec);
+ auto outputs = builder.CreateVector(outputs_vec);
+ circle::Padding padding = getOpPadding(node->padding());
+ auto options = CreateConv2DOptions(builder, padding, node->stride()->w(), node->stride()->h(),
+ to_circle_actfunc(node->fusedActivationFunction()));
+
+ // Make CONV_2D operator
+ auto op_offset = CreateOperator(builder, op_idx, inputs, outputs,
+ circle::BuiltinOptions_Conv2DOptions, options.Union());
+ gd._operators.push_back(op_offset);
+}
+
+void OperationExporter::visit(locoex::TFLDepthwiseConv2D *node)
+{
+ uint32_t op_idx = gd.registerBuiltinOpcode(circle::BuiltinOperator_DEPTHWISE_CONV_2D);
+
+ // Make input, output and options for operator
+ std::vector<int32_t> inputs_vec{get_tensor_index(node->input()), get_tensor_index(node->filter()),
+ get_tensor_index(node->bias())};
+ std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))};
+ auto inputs = builder.CreateVector(inputs_vec);
+ auto outputs = builder.CreateVector(outputs_vec);
+ circle::Padding padding = getOpPadding(node->padding());
+ auto options = CreateDepthwiseConv2DOptions(builder, padding, node->stride()->w(),
+ node->stride()->h(), node->depthMultiplier(),
+ to_circle_actfunc(node->fusedActivationFunction()));
+
+ // Make DEPTHWISE_CONV_2D operator
+ auto op_offset = CreateOperator(builder, op_idx, inputs, outputs,
+ circle::BuiltinOptions_DepthwiseConv2DOptions, options.Union());
+ gd._operators.push_back(op_offset);
+}
+
+void OperationExporter::visit(locoex::TFLDiv *node)
+{
+ uint32_t op_idx = gd.registerBuiltinOpcode(circle::BuiltinOperator_DIV);
+ std::vector<int32_t> inputs_vec{get_tensor_index(node->x()), get_tensor_index(node->y())};
+ std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))};
+ auto inputs = builder.CreateVector(inputs_vec);
+ auto outputs = builder.CreateVector(outputs_vec);
+ auto options = CreateDivOptions(builder, to_circle_actfunc(node->fusedActivationFunction()));
+ auto op_offset = CreateOperator(builder, op_idx, inputs, outputs,
+ circle::BuiltinOptions_DivOptions, options.Union());
+ gd._operators.push_back(op_offset);
+}
+
+void OperationExporter::visit(locoex::TFLFullyConnected *node)
+{
+ uint32_t op_idx = gd.registerBuiltinOpcode(circle::BuiltinOperator_FULLY_CONNECTED);
+
+ // Make input, output and options for operator
+ std::vector<int32_t> inputs_vec{get_tensor_index(node->input()),
+ get_tensor_index(node->weights()),
+ get_tensor_index(node->bias())};
+ std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))};
+ auto inputs = builder.CreateVector(inputs_vec);
+ auto outputs = builder.CreateVector(outputs_vec);
+ auto options =
+ CreateFullyConnectedOptions(builder, to_circle_actfunc(node->fusedActivationFunction()));
+
+ // Make FULLY_CONNECTED operator
+ auto op_offset = CreateOperator(builder, op_idx, inputs, outputs,
+ circle::BuiltinOptions_FullyConnectedOptions, options.Union());
+ gd._operators.push_back(op_offset);
+}
+
+void OperationExporter::visit(locoex::TFLMaximum *node)
+{
+ uint32_t op_idx = gd.registerBuiltinOpcode(circle::BuiltinOperator_MAXIMUM);
+ std::vector<int32_t> inputs_vec{get_tensor_index(node->x()), get_tensor_index(node->y())};
+ std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))};
+ auto inputs = builder.CreateVector(inputs_vec);
+ auto outputs = builder.CreateVector(outputs_vec);
+ auto options = CreateMaximumMinimumOptions(builder);
+ auto op_offset = CreateOperator(builder, op_idx, inputs, outputs,
+ circle::BuiltinOptions_MaximumMinimumOptions, options.Union());
+ gd._operators.push_back(op_offset);
+}
+
+void OperationExporter::visit(locoex::TFLMaxPool2D *node)
+{
+ export_pool_2d<locoex::TFLMaxPool2D>(node, circle::BuiltinOperator_MAX_POOL_2D);
+}
+
+void OperationExporter::visit(locoex::TFLMean *node)
+{
+ uint32_t op_idx = gd.registerBuiltinOpcode(circle::BuiltinOperator_MEAN);
+ std::vector<int32_t> inputs_vec{get_tensor_index(node->input()),
+ get_tensor_index(node->reduction_indices())};
+ std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))};
+ auto inputs = builder.CreateVector(inputs_vec);
+ auto outputs = builder.CreateVector(outputs_vec);
+ auto options = CreateReducerOptions(builder, node->keep_dims());
+ auto op_offset = CreateOperator(builder, op_idx, inputs, outputs,
+ circle::BuiltinOptions_ReducerOptions, options.Union());
+ gd._operators.push_back(op_offset);
+}
+
+void OperationExporter::visit(locoex::TFLMul *node)
+{
+ uint32_t op_idx = gd.registerBuiltinOpcode(circle::BuiltinOperator_MUL);
+ std::vector<int32_t> inputs_vec{get_tensor_index(node->x()), get_tensor_index(node->y())};
+ std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))};
+ auto inputs = builder.CreateVector(inputs_vec);
+ auto outputs = builder.CreateVector(outputs_vec);
+ auto options = CreateMulOptions(builder, to_circle_actfunc(node->fusedActivationFunction()));
+ auto op_offset = CreateOperator(builder, op_idx, inputs, outputs,
+ circle::BuiltinOptions_MulOptions, options.Union());
+ gd._operators.push_back(op_offset);
+}
+
+void OperationExporter::visit(locoex::TFLRelu *node)
+{
+ uint32_t op_idx = gd.registerBuiltinOpcode(circle::BuiltinOperator_RELU);
+ std::vector<int32_t> inputs_vec{get_tensor_index(node->features())};
+ std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))};
+ auto inputs = builder.CreateVector(inputs_vec);
+ auto outputs = builder.CreateVector(outputs_vec);
+ auto op_offset = CreateOperator(builder, op_idx, inputs, outputs);
+ gd._operators.push_back(op_offset);
+}
+
+void OperationExporter::visit(locoex::TFLRelu6 *node)
+{
+ uint32_t op_idx = gd.registerBuiltinOpcode(circle::BuiltinOperator_RELU6);
+ std::vector<int32_t> inputs_vec{get_tensor_index(node->features())};
+ std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))};
+ auto inputs = builder.CreateVector(inputs_vec);
+ auto outputs = builder.CreateVector(outputs_vec);
+ auto op_offset = CreateOperator(builder, op_idx, inputs, outputs);
+ gd._operators.push_back(op_offset);
+}
+
+// TODO TFLReshape
+
+void OperationExporter::visit(locoex::TFLRsqrt *node)
+{
+ uint32_t op_idx = gd.registerBuiltinOpcode(circle::BuiltinOperator_RSQRT);
+ std::vector<int32_t> inputs_vec{get_tensor_index(node->x())};
+ std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))};
+ auto inputs = builder.CreateVector(inputs_vec);
+ auto outputs = builder.CreateVector(outputs_vec);
+ auto op_offset = CreateOperator(builder, op_idx, inputs, outputs);
+ gd._operators.push_back(op_offset);
+}
+
+// TODO TFLSoftmax
+
+void OperationExporter::visit(locoex::TFLSqrt *node)
+{
+ uint32_t op_idx = gd.registerBuiltinOpcode(circle::BuiltinOperator_SQRT);
+ std::vector<int32_t> inputs_vec{get_tensor_index(node->x())};
+ std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))};
+ auto inputs = builder.CreateVector(inputs_vec);
+ auto outputs = builder.CreateVector(outputs_vec);
+ auto op_offset = CreateOperator(builder, op_idx, inputs, outputs);
+ gd._operators.push_back(op_offset);
+}
+
+void OperationExporter::visit(locoex::TFLSquaredDifference *node)
+{
+ uint32_t op_idx = gd.registerBuiltinOpcode(circle::BuiltinOperator_SQUARED_DIFFERENCE);
+ std::vector<int32_t> inputs_vec{get_tensor_index(node->x()), get_tensor_index(node->y())};
+ std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))};
+ auto inputs = builder.CreateVector(inputs_vec);
+ auto outputs = builder.CreateVector(outputs_vec);
+ auto options = CreateSquaredDifferenceOptions(builder);
+ auto op_offset = CreateOperator(builder, op_idx, inputs, outputs,
+ circle::BuiltinOptions_SquaredDifferenceOptions, options.Union());
+ gd._operators.push_back(op_offset);
+}
+
+void OperationExporter::visit(locoex::TFLSub *node)
+{
+ uint32_t op_idx = gd.registerBuiltinOpcode(circle::BuiltinOperator_SUB);
+ std::vector<int32_t> inputs_vec{get_tensor_index(node->x()), get_tensor_index(node->y())};
+ std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))};
+ auto inputs = builder.CreateVector(inputs_vec);
+ auto outputs = builder.CreateVector(outputs_vec);
+ auto options = CreateSubOptions(builder, to_circle_actfunc(node->fusedActivationFunction()));
+ auto op_offset = CreateOperator(builder, op_idx, inputs, outputs,
+ circle::BuiltinOptions_SubOptions, options.Union());
+ gd._operators.push_back(op_offset);
+}
+
+// TODO TFLTanh
+
+void OperationExporter::visit(locoex::TFLTranspose *node)
+{
+ uint32_t op_idx = gd.registerBuiltinOpcode(circle::BuiltinOperator_TRANSPOSE);
+ std::vector<int32_t> inputs_vec{get_tensor_index(node->arg(0)), get_tensor_index(node->arg(1))};
+ std::vector<int32_t> outputs_vec{get_tensor_index(node)};
+
+ auto inputs = builder.CreateVector(inputs_vec);
+ auto outputs = builder.CreateVector(outputs_vec);
+ auto options = CreateTransposeOptions(builder);
+
+ auto op_offset =
+ CreateOperator(builder, op_idx, inputs, outputs,
+ circle::BuiltinOptions::BuiltinOptions_TransposeOptions, options.Union());
+ gd._operators.push_back(op_offset);
+}
+
+void OperationExporter::visit(locoex::TFLTransposeConv *node)
+{
+ uint32_t op_idx = gd.registerBuiltinOpcode(circle::BuiltinOperator_TRANSPOSE_CONV);
+
+ // Make input, output and options for operator
+ std::vector<int32_t> inputs_vec{get_tensor_index(node->inputSizes()),
+ get_tensor_index(node->filter()),
+ get_tensor_index(node->outBackprop())};
+ std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))};
+ auto inputs = builder.CreateVector(inputs_vec);
+ auto outputs = builder.CreateVector(outputs_vec);
+ circle::Padding padding = getOpPadding(node->padding());
+ auto options =
+ CreateTransposeConvOptions(builder, padding, node->stride()->w(), node->stride()->h());
+
+ // Make TRANSPOSE_CONV operator
+ auto op_offset = CreateOperator(builder, op_idx, inputs, outputs,
+ circle::BuiltinOptions_TransposeConvOptions, options.Union());
+ gd._operators.push_back(op_offset);
+}
+
+template <class TFLPool2D>
+void OperationExporter::export_pool_2d(TFLPool2D *node, circle::BuiltinOperator builtin_op)
+{
+ EXO_ASSERT(builtin_op == circle::BuiltinOperator_MAX_POOL_2D ||
+ builtin_op == circle::BuiltinOperator_AVERAGE_POOL_2D,
+ "should be maxpool or avgpool");
+ EXO_ASSERT(node->padding() != locoex::Padding::UNDEFINED, "Padding is not set");
+
+ uint32_t op_idx = gd.registerBuiltinOpcode(builtin_op);
+ std::vector<int32_t> inputs_vec{get_tensor_index(node->value())};
+ std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))};
+ auto inputs = builder.CreateVector(inputs_vec);
+ auto outputs = builder.CreateVector(outputs_vec);
+
+ circle::Padding padding = getOpPadding(node->padding());
+
+ auto options = CreatePool2DOptions(builder, padding, node->stride()->w(), node->stride()->h(),
+ node->filter()->w(), node->filter()->h(),
+ to_circle_actfunc(node->fusedActivationFunction()));
+ auto op_offset = CreateOperator(builder, op_idx, inputs, outputs,
+ circle::BuiltinOptions_Pool2DOptions, options.Union());
+ gd._operators.push_back(op_offset);
+}
+
+void OperationExporter::visit(locoex::CircleInstanceNorm *node)
+{
+ uint32_t op_idx = gd.registerBuiltinOpcode(circle::BuiltinOperator_INSTANCE_NORM);
+ std::vector<int32_t> inputs_vec{get_tensor_index(node->input()), get_tensor_index(node->gamma()),
+ get_tensor_index(node->beta())};
+ std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))};
+ auto inputs = builder.CreateVector(inputs_vec);
+ auto outputs = builder.CreateVector(outputs_vec);
+ auto options = CreateInstanceNormOptions(builder, node->epsilon(),
+ to_circle_actfunc(node->fusedActivationFunction()));
+ auto op_offset = CreateOperator(builder, op_idx, inputs, outputs,
+ circle::BuiltinOptions_InstanceNormOptions, options.Union());
+ gd._operators.push_back(op_offset);
+}
+
+void OperationExporter::visit(loco::ReLU *node)
+{
+ uint32_t op_idx = gd.registerBuiltinOpcode(circle::BuiltinOperator_RELU);
+ std::vector<int32_t> inputs_vec{get_tensor_index(node->input())};
+ std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))};
+ auto inputs = builder.CreateVector(inputs_vec);
+ auto outputs = builder.CreateVector(outputs_vec);
+ auto op_offset = CreateOperator(builder, op_idx, inputs, outputs);
+ gd._operators.push_back(op_offset);
+}
+
+void OperationExporter::visit(loco::ReLU6 *node)
+{
+ uint32_t op_idx = gd.registerBuiltinOpcode(circle::BuiltinOperator_RELU6);
+ std::vector<int32_t> inputs_vec{get_tensor_index(node->input())};
+ std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))};
+ auto inputs = builder.CreateVector(inputs_vec);
+ auto outputs = builder.CreateVector(outputs_vec);
+ auto op_offset = CreateOperator(builder, op_idx, inputs, outputs);
+ gd._operators.push_back(op_offset);
+}
+
+void OperationExporter::visit(loco::Tanh *node)
+{
+ uint32_t op_idx = gd.registerBuiltinOpcode(circle::BuiltinOperator_TANH);
+ std::vector<int32_t> inputs_vec{get_tensor_index(node->input())};
+ std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))};
+ auto inputs = builder.CreateVector(inputs_vec);
+ auto outputs = builder.CreateVector(outputs_vec);
+ auto op_offset = CreateOperator(builder, op_idx, inputs, outputs);
+ gd._operators.push_back(op_offset);
+}
+
+void OperationExporter::visit(loco::MaxPool2D *node)
+{
+ uint32_t op_idx = gd.registerBuiltinOpcode(circle::BuiltinOperator_MAX_POOL_2D);
+ std::vector<int32_t> inputs_vec{get_tensor_index(node->ifm())};
+ std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))};
+ auto inputs = builder.CreateVector(inputs_vec);
+ auto outputs = builder.CreateVector(outputs_vec);
+ circle::Padding padding = getOpPadding(
+ node->pad(), node->stride(), ShapeInference::get(node->ifm()), ShapeInference::get(node));
+ auto options = CreatePool2DOptions(builder, padding, node->stride()->horizontal(),
+ node->stride()->vertical(), node->window()->horizontal(),
+ node->window()->vertical());
+ auto op_offset = CreateOperator(builder, op_idx, inputs, outputs,
+ circle::BuiltinOptions_Pool2DOptions, options.Union());
+ gd._operators.push_back(op_offset);
+}
+
+void OperationExporter::visit(loco::AvgPool2D *node)
+{
+  // Circle only supports the Valid convention of average pooling
+ assert(node->convention() == loco::AvgPool2D::Convention::Valid);
+
+ uint32_t op_idx = gd.registerBuiltinOpcode(circle::BuiltinOperator_AVERAGE_POOL_2D);
+ std::vector<int32_t> inputs_vec{get_tensor_index(node->ifm())};
+ std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))};
+ auto inputs = builder.CreateVector(inputs_vec);
+ auto outputs = builder.CreateVector(outputs_vec);
+ circle::Padding padding = getOpPadding(
+ node->pad(), node->stride(), ShapeInference::get(node->ifm()), ShapeInference::get(node));
+ auto options = CreatePool2DOptions(builder, padding, node->stride()->horizontal(),
+ node->stride()->vertical(), node->window()->horizontal(),
+ node->window()->vertical());
+ auto op_offset = CreateOperator(builder, op_idx, inputs, outputs,
+ circle::BuiltinOptions_Pool2DOptions, options.Union());
+ gd._operators.push_back(op_offset);
+}
+
+void OperationExporter::visit(loco::Conv2D *node)
+{
+ uint32_t op_idx = gd.registerBuiltinOpcode(circle::BuiltinOperator_CONV_2D);
+
+  // The third input of Circle's CONV_2D should be a bias. We make (and register
+  // to gd) a dummy zero bias: rank 1, sized to the output kernel count, with
+  // all values zero.
+ auto *ker = dynamic_cast<loco::FilterEncode *>(node->ker());
+ assert(ker);
+ int32_t bias_vec_size = ShapeInference::get(ker)._dims[0]; // output kernel count
+
+ auto bias_vec_shape_offset = builder.CreateVector(std::vector<int32_t>{bias_vec_size});
+  size_t raw_bias_vec_size = bias_vec_size * sizeof(float); // bias data is FLOAT32
+
+ std::vector<float> bias_vec_data(bias_vec_size); // initialized as zero vector
+
+ auto bias_vec_offset =
+ builder.CreateVector(reinterpret_cast<uint8_t *>(bias_vec_data.data()), raw_bias_vec_size);
+
+ auto bias_buffer_offset = CreateBuffer(builder, bias_vec_offset);
+
+ const auto bias_buffer_id = static_cast<uint32_t>(gd._buffers.size());
+
+ gd._buffers.push_back(bias_buffer_offset);
+
+ auto bias_tensor_id = static_cast<int32_t>(gd._tensors.size());
+ auto name_offset = builder.CreateString("t_" + std::to_string(bias_tensor_id));
+
+ auto bias_tensor_offset =
+ CreateTensor(builder, bias_vec_shape_offset, TensorType_FLOAT32, bias_buffer_id, name_offset);
+ gd._tensors.push_back(bias_tensor_offset);
+
+ // Make input, output and options for operator
+ std::vector<int32_t> inputs_vec{get_tensor_index(node->ifm()), get_tensor_index(node->ker()),
+ bias_tensor_id};
+ std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))};
+ auto inputs = builder.CreateVector(inputs_vec);
+ auto outputs = builder.CreateVector(outputs_vec);
+ circle::Padding padding = getOpPadding(
+ node->pad(), node->stride(), ShapeInference::get(node->ifm()), ShapeInference::get(node));
+ auto options = CreateConv2DOptions(builder, padding, node->stride()->horizontal(),
+ node->stride()->vertical());
+
+ // Make CONV_2D operator
+ auto op_offset = CreateOperator(builder, op_idx, inputs, outputs,
+ circle::BuiltinOptions_Conv2DOptions, options.Union());
+ gd._operators.push_back(op_offset);
+}
+
+void OperationExporter::visit(loco::TransposedConv2D *node)
+{
+ uint32_t op_idx = gd.registerBuiltinOpcode(circle::BuiltinOperator_TRANSPOSE_CONV);
+
+  // The first input of TRANSPOSE_CONV is the output shape array.
+ const int32_t outshape_vec_size = 4;
+ auto outshape_vec_shape_offset = builder.CreateVector(std::vector<int32_t>{outshape_vec_size});
+ size_t raw_outshape_vec_size = outshape_vec_size * sizeof(int32_t);
+
+ std::vector<int32_t> outshape_vec_data(outshape_vec_size);
+ {
+ // Copy inferred output shape of node
+ auto out_feature_shape = loco::shape_get(node).as<loco::FeatureShape>();
+
+ // Feature tensor in Circle is NHWC
+ outshape_vec_data.at(0) = out_feature_shape.count().value();
+ outshape_vec_data.at(1) = out_feature_shape.height().value();
+ outshape_vec_data.at(2) = out_feature_shape.width().value();
+ outshape_vec_data.at(3) = out_feature_shape.depth().value();
+ }
+
+ auto outshape_vec_offset = builder.CreateVector(
+ reinterpret_cast<uint8_t *>(outshape_vec_data.data()), raw_outshape_vec_size);
+
+ auto outshape_buffer_offset = CreateBuffer(builder, outshape_vec_offset);
+
+ const auto outshape_buffer_id = static_cast<uint32_t>(gd._buffers.size());
+
+ gd._buffers.push_back(outshape_buffer_offset);
+
+ auto outshape_tensor_id = static_cast<int32_t>(gd._tensors.size());
+ auto name_offset = builder.CreateString("t_" + std::to_string(outshape_tensor_id));
+
+ auto outshape_tensor_offset = CreateTensor(builder, outshape_vec_shape_offset, TensorType_INT32,
+ outshape_buffer_id, name_offset);
+ gd._tensors.push_back(outshape_tensor_offset);
+
+ // Make input, output and options for operator
+ std::vector<int32_t> inputs_vec{outshape_tensor_id, get_tensor_index(node->ker()),
+ get_tensor_index(node->ifm())};
+ std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))};
+ auto inputs = builder.CreateVector(inputs_vec);
+ auto outputs = builder.CreateVector(outputs_vec);
+  // NOTE input and output are swapped when calling this function
+ circle::Padding padding = getOpPadding(node->pad(), node->stride(), ShapeInference::get(node),
+ ShapeInference::get(node->ifm()));
+ auto options = CreateTransposeConvOptions(builder, padding, node->stride()->horizontal(),
+ node->stride()->vertical());
+
+ // Make TRANSPOSE_CONV operator
+ auto op_offset = CreateOperator(builder, op_idx, inputs, outputs,
+ circle::BuiltinOptions_TransposeConvOptions, options.Union());
+ gd._operators.push_back(op_offset);
+}
+
+void OperationExporter::visit(loco::DepthwiseConv2D *node)
+{
+ uint32_t op_idx = gd.registerBuiltinOpcode(circle::BuiltinOperator_DEPTHWISE_CONV_2D);
+
+  // The third input of Circle's DEPTHWISE_CONV_2D should be a bias. We make
+  // (and register to gd) a dummy zero bias: rank 1, sized to the output kernel
+  // count, with all values zero.
+ auto *ker = dynamic_cast<loco::DepthwiseFilterEncode *>(node->ker());
+ assert(ker);
+
+ int32_t bias_vec_size = ShapeInference::get(ker)._dims[3]; // output_size(C*M)
+ auto bias_vec_shape_offset = builder.CreateVector(std::vector<int32_t>{bias_vec_size});
+
+  size_t raw_bias_vec_size = bias_vec_size * sizeof(float); // bias data is FLOAT32
+ std::vector<float> bias_vec_data(bias_vec_size);
+ auto bias_vec_offset =
+ builder.CreateVector(reinterpret_cast<uint8_t *>(bias_vec_data.data()), raw_bias_vec_size);
+
+ auto bias_buffer_offset = CreateBuffer(builder, bias_vec_offset);
+
+ const auto bias_buffer_id = static_cast<uint32_t>(gd._buffers.size());
+
+ gd._buffers.push_back(bias_buffer_offset);
+
+ auto bias_tensor_id = static_cast<int32_t>(gd._tensors.size());
+ auto name_offset = builder.CreateString("t_" + std::to_string(bias_tensor_id));
+
+ auto bias_tensor_offset =
+ CreateTensor(builder, bias_vec_shape_offset, TensorType_FLOAT32, bias_buffer_id, name_offset);
+ gd._tensors.push_back(bias_tensor_offset);
+
+ std::vector<int32_t> inputs_vec{get_tensor_index(node->ifm()), get_tensor_index(node->ker()),
+ bias_tensor_id};
+ std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))};
+ auto inputs = builder.CreateVector(inputs_vec);
+ auto outputs = builder.CreateVector(outputs_vec);
+ circle::Padding padding = getOpPadding(
+ node->pad(), node->stride(), ShapeInference::get(node->ifm()), ShapeInference::get(node));
+
+ int32_t ifm_channel_size = ShapeInference::get(node->ifm())._dims[3];
+ // depth multiplier = output channel count (C*M) / input channel count (C)
+ auto options =
+ CreateDepthwiseConv2DOptions(builder, padding, node->stride()->horizontal(),
+ node->stride()->vertical(), bias_vec_size / ifm_channel_size);
+
+ auto op_offset = CreateOperator(builder, op_idx, inputs, outputs,
+ circle::BuiltinOptions_DepthwiseConv2DOptions, options.Union());
+ gd._operators.push_back(op_offset);
+}
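+
+// For example, with hypothetical sizes C = 8 and M = 3, the encoded kernel's
+// last dimension is C*M = 24, so bias_vec_size = 24, ifm_channel_size = 8,
+// and the exported depth multiplier is 24 / 8 = 3.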
+
+void OperationExporter::visit(loco::TensorReduce *node)
+{
+ uint32_t op_idx;
+
+ switch (node->func())
+ {
+ case loco::ReduceFunc::Mean:
+ op_idx = gd.registerBuiltinOpcode(circle::BuiltinOperator_MEAN);
+ break;
+
+ // TODO Support more reduce types
+ default:
+ INTERNAL_EXN_V("Unsupported reduce type", oops::to_uint32(node->func()));
+ }
+
+ // Create a vector for axes data
+ std::vector<int32_t> axes_vec;
+ auto rank = ShapeInference::get(node->input())._dims.size();
+ for (uint32_t i = 0; i < rank; ++i)
+ if (node->axes()->defined(i))
+ axes_vec.push_back(i);
+
+ int32_t axes_vec_size = axes_vec.size();
+ auto axes_vec_shape_offset = builder.CreateVector(std::vector<int32_t>{axes_vec_size});
+
+ size_t raw_axes_vec_size = axes_vec_size * sizeof(int32_t);
+ auto axes_vec_offset =
+ builder.CreateVector(reinterpret_cast<uint8_t *>(axes_vec.data()), raw_axes_vec_size);
+
+ auto axes_buffer_offset = CreateBuffer(builder, axes_vec_offset);
+
+ const auto axes_buffer_id = static_cast<uint32_t>(gd._buffers.size());
+
+ gd._buffers.push_back(axes_buffer_offset);
+
+ auto axes_tensor_id = static_cast<int32_t>(gd._tensors.size());
+ auto name_offset = builder.CreateString("t_" + std::to_string(axes_tensor_id));
+
+ auto axes_tensor_offset =
+ CreateTensor(builder, axes_vec_shape_offset, TensorType_INT32, axes_buffer_id, name_offset);
+ gd._tensors.push_back(axes_tensor_offset);
+
+ std::vector<int32_t> inputs_vec{get_tensor_index(node->input()), axes_tensor_id};
+ std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))};
+ auto inputs = builder.CreateVector(inputs_vec);
+ auto outputs = builder.CreateVector(outputs_vec);
+ auto options = CreateReducerOptions(builder, true); // true is for keep_dims option
+ auto op_offset = CreateOperator(builder, op_idx, inputs, outputs,
+ circle::BuiltinOptions_ReducerOptions, options.Union());
+ gd._operators.push_back(op_offset);
+}
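+
+// For example, for a hypothetical rank-4 input with axes 1 and 2 defined,
+// axes_vec becomes {1, 2} and the exported MEAN reduces over H and W with
+// keep_dims = true.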
+
+void OperationExporter::visit(loco::TensorSoftmax *node)
+{
+ // TODO Support when the input rank of TensorSoftmax is not 2
+ assert(ShapeInference::get(node->input())._dims.size() == 2);
+
+ // NOTE Circle accepts softmax only along the last dimension
+ assert(node->axis() == ShapeInference::get(node->input())._dims.size() - 1);
+
+ uint32_t op_idx = gd.registerBuiltinOpcode(circle::BuiltinOperator_SOFTMAX);
+ std::vector<int32_t> inputs_vec{get_tensor_index(node->input())};
+ std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))};
+ auto inputs = builder.CreateVector(inputs_vec);
+ auto outputs = builder.CreateVector(outputs_vec);
+ auto options = CreateSoftmaxOptions(builder, 1.0f);
+ auto op_offset = CreateOperator(builder, op_idx, inputs, outputs,
+ circle::BuiltinOptions_SoftmaxOptions, options.Union());
+ gd._operators.push_back(op_offset);
+}
+
+/// @brief Export the given node as an identity, i.e. a CONCATENATION with a single input
+template <typename NodeT>
+void exportIdentity(NodeT *node, FlatBufferBuilder &builder, SerializedModelData &gd)
+{
+ uint32_t op_idx = gd.registerBuiltinOpcode(circle::BuiltinOperator_CONCATENATION);
+ std::vector<int32_t> inputs_vec{get_tensor_index(node->arg(0))};
+ std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))};
+ auto inputs = builder.CreateVector(inputs_vec);
+ auto outputs = builder.CreateVector(outputs_vec);
+ auto options = CreateConcatenationOptions(builder); // use dummy 0 axis and NONE activation
+ auto op_offset = CreateOperator(builder, op_idx, inputs, outputs,
+ circle::BuiltinOptions_ConcatenationOptions, options.Union());
+
+ gd._operators.push_back(op_offset);
+}
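+
+// For example, loco::BiasEncode is exported through this helper below: a
+// CONCATENATION with a single input is effectively an identity copy, so the
+// node contributes no real computation to the exported model.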
+
+/// @brief Export the given loco node as TRANSPOSE with the given permutation vector
+void exportAsTranspose(loco::Node *node, FlatBufferBuilder &builder,
+ std::vector<int32_t> &perm_vec_data, SerializedModelData &gd)
+{
+ uint32_t op_idx = gd.registerBuiltinOpcode(circle::BuiltinOperator_TRANSPOSE);
+
+ auto options = CreateTransposeOptions(builder);
+
+ // Create constant tensor with perm vector
+ constexpr int perm_vec_size = 4;
+ assert(perm_vec_data.size() == perm_vec_size);
+ auto perm_vec_shape_offset = builder.CreateVector(std::vector<int32_t>{perm_vec_size});
+ constexpr size_t raw_perm_vec_size = perm_vec_size * sizeof(int32_t);
+
+ auto perm_vec_offset =
+ builder.CreateVector(reinterpret_cast<uint8_t *>(perm_vec_data.data()), raw_perm_vec_size);
+
+ auto perm_buffer_offset = CreateBuffer(builder, perm_vec_offset);
+
+ const auto perm_buffer_id = static_cast<uint32_t>(gd._buffers.size());
+
+ gd._buffers.push_back(perm_buffer_offset);
+
+ auto perm_tensor_id = static_cast<int32_t>(gd._tensors.size());
+ auto name_offset = builder.CreateString("t_" + std::to_string(perm_tensor_id));
+
+ auto perm_tensor_offset =
+ CreateTensor(builder, perm_vec_shape_offset, TensorType_INT32, perm_buffer_id, name_offset);
+ gd._tensors.push_back(perm_tensor_offset);
+
+ // Create the TRANSPOSE operator
+
+ std::vector<int32_t> inputs_vec{get_tensor_index(node->arg(0)), perm_tensor_id};
+ std::vector<int32_t> outputs_vec{get_tensor_index(node)};
+
+ auto inputs = builder.CreateVector(inputs_vec);
+ auto outputs = builder.CreateVector(outputs_vec);
+
+ constexpr auto options_type = circle::BuiltinOptions::BuiltinOptions_TransposeOptions;
+
+ auto transpose_offset =
+ CreateOperator(builder, op_idx, inputs, outputs, options_type, options.Union());
+ gd._operators.push_back(transpose_offset);
+}
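+
+// The visitors above repeat one pattern: build a constant INT32 tensor,
+// register its backing buffer, then register the tensor itself. A minimal
+// sketch of that pattern factored into a helper (illustrative only; the
+// name is hypothetical and the visitors inline the pattern instead):
+int32_t registerConstInt32Tensor(FlatBufferBuilder &builder, SerializedModelData &gd,
+ const std::vector<int32_t> &shape, const std::vector<int32_t> &data)
+{
+ // Serialize the raw data and register the backing buffer
+ auto shape_offset = builder.CreateVector(shape);
+ auto data_offset = builder.CreateVector(reinterpret_cast<const uint8_t *>(data.data()),
+ data.size() * sizeof(int32_t));
+ const auto buffer_id = static_cast<uint32_t>(gd._buffers.size());
+ gd._buffers.push_back(CreateBuffer(builder, data_offset));
+
+ // Register a tensor that references the buffer
+ auto tensor_id = static_cast<int32_t>(gd._tensors.size());
+ auto name_offset = builder.CreateString("t_" + std::to_string(tensor_id));
+ gd._tensors.push_back(
+ CreateTensor(builder, shape_offset, TensorType_INT32, buffer_id, name_offset));
+ return tensor_id;
+}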
+
+void OperationExporter::visit(loco::FeatureEncode *node)
+{
+ auto encoder = dynamic_cast<loco::PermutingEncoder<loco::Domain::Feature> *>(node->encoder());
+ auto perm = encoder->perm();
+
+ if (isNHWC(perm))
+ {
+ // Note that Circle represents feature as NHWC
+ exportIdentity(node, builder, gd);
+ }
+ else
+ {
+ std::vector<int32_t> perm_vec_data(4);
+ perm_vec_data[0] = perm->axis(loco::FeatureAxis::Count);
+ perm_vec_data[1] = perm->axis(loco::FeatureAxis::Height);
+ perm_vec_data[2] = perm->axis(loco::FeatureAxis::Width);
+ perm_vec_data[3] = perm->axis(loco::FeatureAxis::Depth);
+
+ exportAsTranspose(node, builder, perm_vec_data, gd);
+ }
+}
+
+void OperationExporter::visit(loco::FeatureDecode *node)
+{
+ auto decoder = dynamic_cast<loco::PermutingDecoder<loco::Domain::Feature> *>(node->decoder());
+ auto perm = decoder->perm();
+
+ if (isNHWC(perm))
+ {
+ // Note that Circle represents feature as NHWC
+ exportIdentity(node, builder, gd);
+ }
+ else
+ {
+ std::vector<int32_t> perm_vec_data(4);
+ perm_vec_data[perm->axis(loco::FeatureAxis::Count)] = 0;
+ perm_vec_data[perm->axis(loco::FeatureAxis::Height)] = 1;
+ perm_vec_data[perm->axis(loco::FeatureAxis::Width)] = 2;
+ perm_vec_data[perm->axis(loco::FeatureAxis::Depth)] = 3;
+
+ exportAsTranspose(node, builder, perm_vec_data, gd);
+ }
+}
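+
+// For example, with an NCHW user layout (axis(Count) = 0, axis(Depth) = 1,
+// axis(Height) = 2, axis(Width) = 3), FeatureEncode above emits TRANSPOSE with
+// perm {0, 2, 3, 1} (NCHW -> NHWC) while FeatureDecode emits the inverse
+// permutation {0, 3, 1, 2} (NHWC -> NCHW); this is why the two visitors fill
+// perm_vec_data in opposite directions.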
+
+void OperationExporter::visit(loco::FilterEncode *node)
+{
+ auto encoder = dynamic_cast<loco::PermutingEncoder<loco::Domain::Filter> *>(node->encoder());
+ auto perm = encoder->perm();
+
+ if (isNHWC(perm))
+ {
+ // Note that Circle represents filter as NHWC
+ exportIdentity(node, builder, gd);
+ }
+ else
+ {
+ std::vector<int32_t> perm_vec_data(4);
+ // NOTE In Circle, all tensors are NHWC, so 0 = N, 1 = H, 2 = W, 3 = C
+ perm_vec_data[0] = perm->axis(loco::FilterAxis::Count);
+ perm_vec_data[1] = perm->axis(loco::FilterAxis::Height);
+ perm_vec_data[2] = perm->axis(loco::FilterAxis::Width);
+ perm_vec_data[3] = perm->axis(loco::FilterAxis::Depth);
+
+ exportAsTranspose(node, builder, perm_vec_data, gd);
+ }
+}
+
+void exportAsReshape(loco::Node *node, FlatBufferBuilder &builder,
+ std::vector<int32_t> &new_shape_vec, SerializedModelData &gd)
+{
+ // NOTE Circle currently follows TFLite for this.
+ // NOTE TFLite has two ways to receive the new shape parameter:
+ // one is the attribute 'new_shape' and the other is the input 'shape'.
+ // The TFLite interpreter therefore calculates the Reshape operation correctly
+ // if either of them is valid.
+ // However, since NN runtimes usually take the new shape parameter from the
+ // input 'shape', passing the new shape only by attribute can cause problems.
+ // Of course, the opposite situation may also occur in the future.
+ // To prevent those problems, we pass the new shape parameter not only by
+ // attribute but also by input.
+
+ auto input_shape_shape_vec_offset =
+ builder.CreateVector(std::vector<int32_t>{(int32_t)new_shape_vec.size()});
+
+ size_t input_shape_vec_size = new_shape_vec.size() * sizeof(int32_t);
+ auto input_shape_input_vec_offset =
+ builder.CreateVector(reinterpret_cast<uint8_t *>(new_shape_vec.data()), input_shape_vec_size);
+ auto input_shape_buffer_offset = CreateBuffer(builder, input_shape_input_vec_offset);
+
+ const auto input_shape_buffer_id = static_cast<uint32_t>(gd._buffers.size());
+ gd._buffers.push_back(input_shape_buffer_offset);
+
+ auto input_shape_tensor_id = static_cast<int32_t>(gd._tensors.size());
+ auto name_offset = builder.CreateString("t_" + std::to_string(input_shape_tensor_id));
+ auto input_shape_tensor_offset = CreateTensor(
+ builder, input_shape_shape_vec_offset, TensorType_INT32, input_shape_buffer_id, name_offset);
+ gd._tensors.push_back(input_shape_tensor_offset);
+
+ uint32_t op_idx = gd.registerBuiltinOpcode(circle::BuiltinOperator_RESHAPE);
+
+ std::vector<int32_t> inputs_vec{get_tensor_index(node->arg(0)), input_shape_tensor_id};
+ std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))};
+ auto inputs = builder.CreateVector(inputs_vec);
+ auto outputs = builder.CreateVector(outputs_vec);
+
+ auto new_shape_vec_offset = builder.CreateVector(new_shape_vec);
+ auto options = CreateReshapeOptions(builder, new_shape_vec_offset);
+
+ auto op_offset = CreateOperator(builder, op_idx, inputs, outputs,
+ circle::BuiltinOptions_ReshapeOptions, options.Union());
+
+ gd._operators.push_back(op_offset);
+}
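+
+// For example (hypothetical shape), reshaping to {1, 3, 3, 16} yields a
+// RESHAPE operator whose second input is a constant INT32 tensor holding
+// {1, 3, 3, 16} and whose ReshapeOptions new_shape attribute repeats the
+// same four values.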
+
+void OperationExporter::visit(loco::DepthwiseFilterEncode *node)
+{
+ auto ker = node->input(); // [H, W, C, M]
+
+ // Circle represents filter as [1, H, W, C*M] where M is multiplier.
+ std::vector<int32_t> new_shape_vec(4);
+ new_shape_vec[0] = 1;
+ new_shape_vec[1] = ShapeInference::get(ker)._dims[0];
+ new_shape_vec[2] = ShapeInference::get(ker)._dims[1];
+ new_shape_vec[3] = ShapeInference::get(ker)._dims[2] * ShapeInference::get(ker)._dims[3];
+
+ exportAsReshape(node, builder, new_shape_vec, gd);
+}
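+
+// For example, a hypothetical HWCM kernel of shape {3, 3, 8, 2} (H = 3, W = 3,
+// C = 8, M = 2) is reshaped to {1, 3, 3, 16} to match the Circle depthwise
+// filter layout.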
+
+void OperationExporter::visit(loco::BiasAdd<loco::Domain::Tensor> *node)
+{
+ uint32_t op_idx = gd.registerBuiltinOpcode(circle::BuiltinOperator_ADD);
+ std::vector<int32_t> inputs_vec{get_tensor_index(node->value()), get_tensor_index(node->bias())};
+ std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))};
+ auto inputs = builder.CreateVector(inputs_vec);
+ auto outputs = builder.CreateVector(outputs_vec);
+ auto options = CreateAddOptions(builder); // dummy option
+ auto op_offset = CreateOperator(builder, op_idx, inputs, outputs,
+ circle::BuiltinOptions_AddOptions, options.Union());
+ gd._operators.push_back(op_offset);
+}
+
+void OperationExporter::visit(loco::FeatureBiasAdd *node)
+{
+ uint32_t op_idx = gd.registerBuiltinOpcode(circle::BuiltinOperator_ADD);
+ std::vector<int32_t> inputs_vec{get_tensor_index(node->value()), get_tensor_index(node->bias())};
+ std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))};
+ auto inputs = builder.CreateVector(inputs_vec);
+ auto outputs = builder.CreateVector(outputs_vec);
+ auto options = CreateAddOptions(builder); // dummy option
+ auto op_offset = CreateOperator(builder, op_idx, inputs, outputs,
+ circle::BuiltinOptions_AddOptions, options.Union());
+ gd._operators.push_back(op_offset);
+}
+
+/// @brief Export CONCATENATION of **TWO** tensors only
+void OperationExporter::visit(loco::TensorConcat *node)
+{
+ uint32_t op_idx = gd.registerBuiltinOpcode(circle::BuiltinOperator_CONCATENATION);
+ std::vector<int32_t> inputs_vec{get_tensor_index(node->lhs()), get_tensor_index(node->rhs())};
+ std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))};
+ auto inputs = builder.CreateVector(inputs_vec);
+ auto outputs = builder.CreateVector(outputs_vec);
+ auto options = CreateConcatenationOptions(builder, node->axis());
+ auto op_offset = CreateOperator(builder, op_idx, inputs, outputs,
+ circle::BuiltinOptions_ConcatenationOptions, options.Union());
+
+ gd._operators.push_back(op_offset);
+}
+
+void OperationExporter::visit(loco::BiasEncode *encode) { exportIdentity(encode, builder, gd); }
+
+void OperationExporter::visit(loco::EltwiseAdd *node)
+{
+ uint32_t op_idx = gd.registerBuiltinOpcode(circle::BuiltinOperator_ADD);
+ std::vector<int32_t> inputs_vec{get_tensor_index(node->lhs()), get_tensor_index(node->rhs())};
+ std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))};
+ auto inputs = builder.CreateVector(inputs_vec);
+ auto outputs = builder.CreateVector(outputs_vec);
+ auto options = CreateAddOptions(builder); // dummy option
+ auto op_offset = CreateOperator(builder, op_idx, inputs, outputs,
+ circle::BuiltinOptions_AddOptions, options.Union());
+ gd._operators.push_back(op_offset);
+}
+
+void OperationExporter::visit(loco::EltwiseMax *node)
+{
+ uint32_t op_idx = gd.registerBuiltinOpcode(circle::BuiltinOperator_MAXIMUM);
+ std::vector<int32_t> inputs_vec{get_tensor_index(node->lhs()), get_tensor_index(node->rhs())};
+ std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))};
+ auto inputs = builder.CreateVector(inputs_vec);
+ auto outputs = builder.CreateVector(outputs_vec);
+ auto options = CreateMaximumMinimumOptions(builder); // dummy option
+ auto op_offset = CreateOperator(builder, op_idx, inputs, outputs,
+ circle::BuiltinOptions_MaximumMinimumOptions, options.Union());
+ gd._operators.push_back(op_offset);
+}
+
+void OperationExporter::visit(loco::EltwiseMul *node)
+{
+ uint32_t op_idx = gd.registerBuiltinOpcode(circle::BuiltinOperator_MUL);
+ std::vector<int32_t> inputs_vec{get_tensor_index(node->lhs()), get_tensor_index(node->rhs())};
+ std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))};
+ auto inputs = builder.CreateVector(inputs_vec);
+ auto outputs = builder.CreateVector(outputs_vec);
+ auto options = CreateMulOptions(builder); // dummy option
+ auto op_offset = CreateOperator(builder, op_idx, inputs, outputs,
+ circle::BuiltinOptions_MulOptions, options.Union());
+ gd._operators.push_back(op_offset);
+}
+
+void OperationExporter::visit(loco::EltwiseSub *node)
+{
+ uint32_t op_idx = gd.registerBuiltinOpcode(circle::BuiltinOperator_SUB);
+ std::vector<int32_t> inputs_vec{get_tensor_index(node->lhs()), get_tensor_index(node->rhs())};
+ std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))};
+ auto inputs = builder.CreateVector(inputs_vec);
+ auto outputs = builder.CreateVector(outputs_vec);
+ auto options = CreateSubOptions(builder); // dummy option
+ auto op_offset = CreateOperator(builder, op_idx, inputs, outputs,
+ circle::BuiltinOptions_SubOptions, options.Union());
+ gd._operators.push_back(op_offset);
+}
+
+void OperationExporter::visit(loco::EltwiseDiv *node)
+{
+ uint32_t op_idx = gd.registerBuiltinOpcode(circle::BuiltinOperator_DIV);
+ std::vector<int32_t> inputs_vec{get_tensor_index(node->lhs()), get_tensor_index(node->rhs())};
+ std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))};
+ auto inputs = builder.CreateVector(inputs_vec);
+ auto outputs = builder.CreateVector(outputs_vec);
+ auto options = CreateDivOptions(builder); // dummy option
+ auto op_offset = CreateOperator(builder, op_idx, inputs, outputs,
+ circle::BuiltinOptions_DivOptions, options.Union());
+ gd._operators.push_back(op_offset);
+}
+
+void OperationExporter::visit(loco::EltwiseSqrt *node)
+{
+ uint32_t op_idx = gd.registerBuiltinOpcode(circle::BuiltinOperator_SQRT);
+ std::vector<int32_t> inputs_vec{get_tensor_index(node->input())};
+ std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))};
+ auto inputs = builder.CreateVector(inputs_vec);
+ auto outputs = builder.CreateVector(outputs_vec);
+ auto op_offset = CreateOperator(builder, op_idx, inputs, outputs);
+ gd._operators.push_back(op_offset);
+}
+
+void OperationExporter::visit(loco::FixedReshape *node)
+{
+ std::vector<int32_t> new_shape_vec;
+ for (uint32_t axis = 0; axis < node->rank(); ++axis)
+ {
+ assert(node->dim(axis).known());
+ new_shape_vec.push_back(node->dim(axis).value());
+ }
+
+ exportAsReshape(node, builder, new_shape_vec, gd);
+}
+
+void OperationExporter::visit(loco::TensorBroadcast *)
+{
+ INTERNAL_EXN("loco graph has loco::TensorBroadcast, which should not exist in the graph");
+}
+
+void OperationExporter::visit(loco::TensorConstantPad *node)
+{
+ uint32_t op_idx = gd.registerBuiltinOpcode(circle::BuiltinOperator_PAD);
+
+ // make padding attribute an input
+ auto padding = node->padding();
+ // get padding vector size
+ int32_t padding_vec_size = padding->rank();
+ // get byte size of vector
+ size_t padding_vec_byte_size = padding_vec_size * sizeof(int32_t) * 2; // [rank, 2]
+ // create vector for data
+ std::vector<int32_t> padding_vec_data(padding_vec_size * 2);
+ // set data
+ for (int32_t i = 0; i < padding_vec_size; i++)
+ {
+ padding_vec_data.at(i * 2) = padding->front(i);
+ padding_vec_data.at(i * 2 + 1) = padding->back(i);
+ }
+ // create FlatBuffer vector
+ auto padding_vec_ptr = builder.CreateVector(reinterpret_cast<uint8_t *>(padding_vec_data.data()),
+ padding_vec_byte_size);
+
+ // create buffer
+ auto padding_buffer_ptr = CreateBuffer(builder, padding_vec_ptr);
+ // get buffer id
+ const auto padding_buffer_id = static_cast<uint32_t>(gd._buffers.size());
+
+ gd._buffers.push_back(padding_buffer_ptr);
+
+ // create padding shape vector
+ auto padding_shape_vec_ptr = builder.CreateVector(std::vector<int32_t>{padding_vec_size, 2});
+ // create tensor
+ auto padding_tensor_ptr =
+ CreateTensor(builder, padding_shape_vec_ptr, TensorType_INT32, padding_buffer_id);
+ // get tensor id
+ const auto padding_tensor_id = static_cast<int32_t>(gd._tensors.size());
+
+ gd._tensors.push_back(padding_tensor_ptr);
+
+ std::vector<int32_t> inputs_vec{get_tensor_index(node->input()), padding_tensor_id};
+ std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))};
+ auto inputs = builder.CreateVector(inputs_vec);
+ auto outputs = builder.CreateVector(outputs_vec);
+ auto op_offset = CreateOperator(builder, op_idx, inputs, outputs);
+ gd._operators.push_back(op_offset);
+}
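+
+// For example, a hypothetical rank-2 padding with front/back pads {1, 1} on
+// axis 0 and {2, 2} on axis 1 produces a constant INT32 tensor of shape
+// {2, 2} holding {1, 1, 2, 2}, where row i is {front(i), back(i)}.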
+
+inline flatbuffers::Offset<flatbuffers::Vector<uint8_t>>
+CreateCOpCallOptions(flatbuffers::FlatBufferBuilder &fbb, locoex::COpCall *copCall)
+{
+ // read attrs in FlexBuffer format and pass them to FlatBuffer builder
+ flexbuffers::Builder flexbuf;
+ {
+ size_t map_start = flexbuf.StartMap();
+
+ // Note: among the attrs of COpCall, 'op' and 'name' won't be included in the circle file
+ auto names = copCall->attr_names();
+ for (auto name : names)
+ {
+ if (auto int_val = copCall->attr<locoex::COpAttrType::Int>(name))
+ flexbuf.Int(name.c_str(), int_val->val());
+ else if (auto float_val = copCall->attr<locoex::COpAttrType::Float>(name))
+ flexbuf.Float(name.c_str(), float_val->val());
+ else
+ // TODO Support more attribute types
+ INTERNAL_EXN_V("Unsupported dtype while writing flexbuffer for customop attr", name);
+ }
+
+ flexbuf.EndMap(map_start);
+ flexbuf.Finish();
+ }
+
+ auto offset = fbb.CreateVector(flexbuf.GetBuffer());
+
+ return offset;
+}
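+
+// For example, a hypothetical custom op with attributes alpha = 0.5f (Float)
+// and count = 3 (Int) is serialized as the FlexBuffer map
+// { "alpha": 0.5, "count": 3 }, which becomes the operator's custom_options.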
+
+void OperationExporter::visit(locoex::COpCall *call)
+{
+ // Register this custom op name in the circle Operator Codes table
+ uint32_t op_idx = gd.registerCustomOpcode(call->op());
+
+ std::vector<int32_t> inputs_vec;
+ {
+ inputs_vec.resize(call->arity());
+ for (uint32_t i = 0; i < call->arity(); i++)
+ inputs_vec[i] = get_tensor_index(call->arg(i));
+ }
+
+ std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(call))};
+
+ auto inputs = builder.CreateVector(inputs_vec);
+ auto outputs = builder.CreateVector(outputs_vec);
+
+ auto custom_options = CreateCOpCallOptions(builder, call);
+ auto op_offset = CreateOperator(builder, op_idx, inputs, outputs,
+ circle::BuiltinOptions_NONE, // builtin_options_type
+ 0, // built-in option
+ custom_options, // custom options
+ circle::CustomOptionsFormat_FLEXBUFFERS);
+
+ gd._operators.push_back(op_offset);
+}
+
+void exportNode(loco::Node *node, flatbuffers::FlatBufferBuilder &builder,
+ SerializedModelData &data)
+{
+ // TODO Use explicit tagging to prevent possible mistakes
+ auto isNoOp = [](loco::Node *node) {
+ if (node->arity() == 1)
+ {
+ assert(node->arg(0) != nullptr);
+ return get_tensor_index(node) == get_tensor_index(node->arg(0));
+ }
+ return false;
+ };
+
+ if (isNoOp(node))
+ {
+ // Skip if the given node was marked as NoOp (an op with no effect) earlier
+ return;
+ }
+
+ if (auto canonical_node = dynamic_cast<loco::CanonicalNode *>(node))
+ { // TODO Consider removing this later
+ OperationExporter exporter{builder, data};
+ canonical_node->accept(&exporter);
+ }
+ else if (auto tfl_node = dynamic_cast<locoex::TFLNode *>(node))
+ {
+ OperationExporter exporter{builder, data};
+ tfl_node->accept(&exporter);
+ }
+ else if (auto circle_node = dynamic_cast<locoex::CircleNode *>(node))
+ {
+ OperationExporter exporter{builder, data};
+ circle_node->accept(&exporter);
+ }
+ else if (auto cop_call = dynamic_cast<locoex::COpCall *>(node))
+ {
+ // COpCall is the only COpNode handled here; casting once avoids an
+ // unchecked null dereference
+ OperationExporter exporter{builder, data};
+ exporter.visit(cop_call);
+ }
+ else
+ {
+ INTERNAL_EXN("Node with unsupported dialect found");
+ }
+}
+
+} // namespace
+
+namespace exo
+{
+namespace circle_detail
+{
+
+void exportNodes(loco::Graph *g, FlatBufferBuilder &builder, SerializedModelData &gd)
+{
+ for (auto node : loco::postorder_traversal(loco::output_nodes(g)))
+ {
+ exportNode(node, builder, gd);
+ }
+}
+
+} // namespace circle_detail
+} // namespace exo
diff --git a/compiler/exo/src/Circle/CircleOperationExporter.h b/compiler/exo/src/Circle/CircleOperationExporter.h
new file mode 100644
index 000000000..19dadbfd1
--- /dev/null
+++ b/compiler/exo/src/Circle/CircleOperationExporter.h
@@ -0,0 +1,39 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __CIRCLE_OPERATION_EXPORTER_H__
+#define __CIRCLE_OPERATION_EXPORTER_H__
+
+#include "CircleExporterUtils.h"
+
+#include <loco/IR/Graph.h>
+
+namespace exo
+{
+namespace circle_detail
+{
+
+/**
+ * @brief create Operators corresponding to model nodes
+ * @param g graph whose nodes are exported
+ * @param builder FlatBufferBuilder used for serialization
+ * @param gd information about serialized parts of the model
+ */
+void exportNodes(loco::Graph *g, flatbuffers::FlatBufferBuilder &builder, SerializedModelData &gd);
+
+} // namespace circle_detail
+} // namespace exo
+
+#endif // __CIRCLE_OPERATION_EXPORTER_H__
diff --git a/compiler/exo/src/Circle/CircleTensorExporter.cpp b/compiler/exo/src/Circle/CircleTensorExporter.cpp
new file mode 100644
index 000000000..efceae55d
--- /dev/null
+++ b/compiler/exo/src/Circle/CircleTensorExporter.cpp
@@ -0,0 +1,261 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleTensorExporter.h"
+#include "CircleTypeInference.h"
+#include "ShapeInference.h"
+
+// TODO Fix include style
+#include "loco/IR/Algorithm.h"
+#include "loco/IR/CanonicalNode.h"
+#include "loco/IR/CanonicalNodeVisitor.h"
+#include "loco/IR/DataTypeTraits.h"
+
+#include "Dialect/IR/TFLNodes.h"
+
+#include <oops/InternalExn.h>
+
+using namespace circle;
+using namespace flatbuffers;
+
+namespace
+{
+
+using namespace exo;
+using namespace exo::circle_detail;
+
+class TFLTensorInfo
+{
+public:
+ TFLTensorInfo() = default;
+
+public:
+ void name(const std::string &name) { _name = name; }
+ const std::string &name(void) const { return _name; }
+
+public:
+ const circle::TensorType &dtype(void) const { return _dtype; }
+ void dtype(const circle::TensorType &dtype) { _dtype = dtype; }
+
+ const ShapeDescription &shape(void) const { return _shape; }
+ void shape(const ShapeDescription &shape) { _shape = shape; }
+
+public:
+ locoex::TFLConst *tfl_content(void) const { return _tfl_content; }
+ void tfl_content(locoex::TFLConst *c) { _tfl_content = c; }
+
+private:
+ std::string _name;
+
+ circle::TensorType _dtype;
+ ShapeDescription _shape;
+
+ // TODO Find a better design
+ loco::ConstGen *_content = nullptr; // TODO deprecate
+ locoex::TFLConst *_tfl_content = nullptr;
+};
+
+using TFLTensorContext = std::vector<TFLTensorInfo>;
+
+struct NoOpDetector final : public loco::CanonicalNodeMutableVisitor<bool>
+{
+ bool visit(loco::BiasEncode *) final
+ {
+ // BiasEncode is always a no-op
+ return true;
+ }
+
+ bool visit(loco::FilterEncode *node) final
+ {
+ auto encoder = dynamic_cast<loco::PermutingEncoder<loco::Domain::Filter> *>(node->encoder());
+ if (encoder != nullptr)
+ {
+ auto perm = encoder->perm();
+ return isNHWC(perm);
+ }
+ return false;
+ }
+
+ bool visit(loco::FeatureEncode *node) final
+ {
+ auto encoder = dynamic_cast<loco::PermutingEncoder<loco::Domain::Feature> *>(node->encoder());
+ if (encoder != nullptr)
+ {
+ auto perm = encoder->perm();
+ return isNHWC(perm);
+ }
+ return false;
+ }
+
+ bool visit(loco::FeatureDecode *node) final
+ {
+ auto decoder = dynamic_cast<loco::PermutingDecoder<loco::Domain::Feature> *>(node->decoder());
+ if (decoder != nullptr)
+ {
+ auto perm = decoder->perm();
+ return isNHWC(perm);
+ }
+ return false;
+ }
+
+ // Return false by default
+ bool visit(loco::Node *) final { return false; }
+};
+
+bool isNoOp(loco::Node *node)
+{
+ if (auto canonical_node = dynamic_cast<loco::CanonicalNode *>(node))
+ {
+ NoOpDetector d;
+ return canonical_node->accept(&d);
+ }
+ return false;
+}
+
+void allocateCircleTensor(loco::Node *node, TFLTensorContext &ctx)
+{
+ if (isNoOp(node))
+ {
+ assert(node->arity() == 1 && node->arg(0) != nullptr);
+ set_tensor_index(node, get_tensor_index(node->arg(0)));
+ return;
+ }
+
+ auto tensor_index = static_cast<TFLTensorIndex>(ctx.size());
+ // TODO Use Graph-level metadata for Input & Output
+ auto tensor_name = "t_" + std::to_string(tensor_index);
+
+ TFLTensorInfo tensor_info;
+
+ tensor_info.name(tensor_name);
+ tensor_info.dtype(TypeInference::get(node));
+ tensor_info.shape(ShapeInference::get(node));
+
+ tensor_info.tfl_content(dynamic_cast<locoex::TFLConst *>(node));
+
+ set_tensor_index(node, tensor_index);
+
+ ctx.emplace_back(tensor_info);
+}
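+
+// For example, a BiasEncode node is detected as a no-op above, so it reuses
+// its argument's tensor index and no new TFLTensorInfo entry is appended.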
+
+} // namespace
+
+namespace
+{
+
+flatbuffers::Offset<Vector<int32_t>> encodeShape(FlatBufferBuilder &builder,
+ const ShapeDescription &shape)
+{
+ assert(shape._rank_known && "unknown number of dimensions is not supported");
+ return builder.CreateVector(shape._dims);
+}
+
+flatbuffers::Offset<circle::Buffer> encodeOpBuffer(FlatBufferBuilder &builder)
+{
+ return CreateBuffer(builder);
+}
+
+template <typename NodeT>
+flatbuffers::Offset<circle::Buffer> encodeOpBuffer(FlatBufferBuilder &builder, NodeT *)
+{
+ return CreateBuffer(builder);
+}
+
+template <loco::DataType DT>
+flatbuffers::Offset<circle::Buffer> encodeOpBufferByDType(FlatBufferBuilder &builder,
+ locoex::TFLConst *c)
+{
+ using NativeType = typename loco::DataTypeImpl<DT>::Type;
+
+ std::vector<NativeType> raw_data;
+ const uint32_t size = c->size<DT>();
+ raw_data.reserve(size);
+ for (uint32_t i = 0; i < size; ++i)
+ {
+ raw_data.push_back(c->at<DT>(i));
+ }
+ const size_t raw_size = size * sizeof(NativeType);
+ auto array_offset = builder.CreateVector(reinterpret_cast<uint8_t *>(raw_data.data()), raw_size);
+ return CreateBuffer(builder, array_offset);
+}
+
+template <>
+flatbuffers::Offset<circle::Buffer> encodeOpBuffer(FlatBufferBuilder &builder, locoex::TFLConst *c)
+{
+ if (c->dtype() == loco::DataType::FLOAT32)
+ {
+ return encodeOpBufferByDType<loco::DataType::FLOAT32>(builder, c);
+ }
+ else if (c->dtype() == loco::DataType::S32)
+ {
+ return encodeOpBufferByDType<loco::DataType::S32>(builder, c);
+ }
+
+ INTERNAL_EXN_V("Unsupported datatype", oops::to_uint32(c->dtype()));
+}
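+
+// For example, a hypothetical FLOAT32 TFLConst holding {1.0f, 2.0f} is
+// serialized above as an 8-byte raw buffer (two 4-byte floats in host byte
+// order), which the owning tensor then references by buffer id.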
+
+} // namespace
+
+namespace exo
+{
+namespace circle_detail
+{
+
+void exportOpDefinedTensor(const TFLTensorInfo &info, FlatBufferBuilder &builder,
+ SerializedModelData &gd)
+{
+ // Create and register output tensor shape
+ auto shape_offset = encodeShape(builder, info.shape());
+
+ // encode and register output tensor buffer
+ auto buffer = info.tfl_content() == nullptr ? encodeOpBuffer(builder)
+ : encodeOpBuffer(builder, info.tfl_content());
+
+ auto buffer_id = static_cast<uint32_t>(gd._buffers.size());
+ gd._buffers.push_back(buffer);
+
+ auto name_offset = builder.CreateString(info.name());
+ auto tensor_offset = CreateTensor(builder, shape_offset, info.dtype(), buffer_id, name_offset,
+ /*quantization*/ 0, /*is_variable*/ false);
+ gd._tensors.push_back(tensor_offset);
+}
+
+void exportOpDefinedTensors(loco::Graph *g, FlatBufferBuilder &builder, SerializedModelData &gd)
+{
+ TFLTensorContext tensor_ctx;
+
+ for (auto node : loco::postorder_traversal(loco::output_nodes(g)))
+ {
+ allocateCircleTensor(node, tensor_ctx);
+ }
+
+ // add one empty buffer
+ // note: this follows TFLite
+ // note: there's a comment in tflite fbs file
+ // - Note the 0th entry of this array must be an empty buffer (sentinel).
+ // - This is a convention so that tensors without a buffer can provide 0 as
+ // - their buffer.
+ auto buffer = encodeOpBuffer(builder);
+ gd._buffers.push_back(buffer);
+
+ for (const auto &tensor_info : tensor_ctx)
+ {
+ exportOpDefinedTensor(tensor_info, builder, gd);
+ }
+}
+
+} // namespace circle_detail
+} // namespace exo
diff --git a/compiler/exo/src/Circle/CircleTensorExporter.h b/compiler/exo/src/Circle/CircleTensorExporter.h
new file mode 100644
index 000000000..39d8e1b86
--- /dev/null
+++ b/compiler/exo/src/Circle/CircleTensorExporter.h
@@ -0,0 +1,42 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __CIRCLE_TENSOR_EXPORTER_H__
+#define __CIRCLE_TENSOR_EXPORTER_H__
+
+#include "CircleExporterUtils.h"
+
+#include <loco/IR/Graph.h>
+
+#include <flatbuffers/flatbuffers.h>
+
+namespace exo
+{
+namespace circle_detail
+{
+
+/**
+ * @brief create Tensors corresponding to the results of all nodes in a graph
+ * @param g computational graph
+ * @param builder FlatBufferBuilder used for serialization
+ * @param gd information about serialized parts of the model
+ */
+void exportOpDefinedTensors(loco::Graph *g, flatbuffers::FlatBufferBuilder &builder,
+ SerializedModelData &gd);
+
+} // namespace circle_detail
+} // namespace exo
+
+#endif // __CIRCLE_TENSOR_EXPORTER_H__
diff --git a/compiler/exo/src/Circle/CircleTypeInference.cpp b/compiler/exo/src/Circle/CircleTypeInference.cpp
new file mode 100644
index 000000000..a1e92b884
--- /dev/null
+++ b/compiler/exo/src/Circle/CircleTypeInference.cpp
@@ -0,0 +1,85 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleTypeInference.h"
+
+#include "circle_schema_generated.h"
+
+#include "Dialect/Service/TFLTypeInferenceRule.h"
+#include "Dialect/IR/TFLDialect.h"
+
+#include <loco/IR/CanonicalNode.h>
+#include <loco/IR/CanonicalNodeVisitor.h>
+#include <loco/IR/CanonicalDialect.h>
+#include <loco/Service/TypeInference.h>
+
+#include <locoex/COpDialect.h>
+#include <locoex/Service/COpTypeInference.h>
+
+#include <oops/InternalExn.h>
+
+#include <stdex/Memory.h>
+
+#include <stdexcept>
+#include <type_traits>
+
+namespace
+{
+
+circle::TensorType translateLocoTypeToCircle(loco::DataType dtype)
+{
+ switch (dtype)
+ {
+ case loco::DataType::U8:
+ return circle::TensorType_UINT8;
+ // case loco::DataType::U16: unsupported
+ // case loco::DataType::U32: unsupported
+ // case loco::DataType::U64: unsupported
+ case loco::DataType::S8:
+ return circle::TensorType_INT8;
+ case loco::DataType::S16:
+ return circle::TensorType_INT16;
+ case loco::DataType::S32:
+ return circle::TensorType_INT32;
+ case loco::DataType::S64:
+ return circle::TensorType_INT64;
+ case loco::DataType::FLOAT16:
+ return circle::TensorType_FLOAT16;
+ case loco::DataType::FLOAT32:
+ return circle::TensorType_FLOAT32;
+ // case loco::DataType::FLOAT64: unsupported
+ default:
+ break;
+ }
+
+ INTERNAL_EXN_V("Invalid loco dtype", oops::to_uint32(dtype));
+}
+
+} // namespace
+
+namespace exo
+{
+namespace circle_detail
+{
+
+circle::TensorType TypeInference::get(loco::Node *node)
+{
+ assert(loco::dtype_known(node));
+ return translateLocoTypeToCircle(loco::dtype_get(node));
+}
+
+} // namespace circle_detail
+} // namespace exo
diff --git a/compiler/exo/src/Circle/CircleTypeInference.h b/compiler/exo/src/Circle/CircleTypeInference.h
new file mode 100644
index 000000000..9c1730233
--- /dev/null
+++ b/compiler/exo/src/Circle/CircleTypeInference.h
@@ -0,0 +1,45 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __CIRCLE_TYPE_INFERENCE_H__
+#define __CIRCLE_TYPE_INFERENCE_H__
+
+#include "CircleExporterUtils.h"
+
+#include <loco/IR/Nodes.h>
+
+namespace exo
+{
+namespace circle_detail
+{
+
+/**
+ * @brief Get the Circle TensorType of each node
+ *
+ * HOW TO USE
+ *
+ * TypeInference::get(g->nodes()->at(0));
+ * TypeInference::get(g->nodes()->at(...));
+ */
+struct TypeInference
+{
+ static circle::TensorType get(loco::Node *node);
+};
+
+} // namespace circle_detail
+} // namespace exo
+
+#endif // __CIRCLE_TYPE_INFERENCE_H__
diff --git a/compiler/exo/src/Conversion/AvgPool2DConverter.cpp b/compiler/exo/src/Conversion/AvgPool2DConverter.cpp
new file mode 100644
index 000000000..a95518ac6
--- /dev/null
+++ b/compiler/exo/src/Conversion/AvgPool2DConverter.cpp
@@ -0,0 +1,79 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "AvgPool2DConverter.h"
+
+#include "Dialect/IR/TFLNodes.h"
+
+#include "GraphBlock.h"
+#include "Check.h"
+
+#include <loco.h>
+
+namespace exo
+{
+/**
+ * @brief Converts loco::AvgPool2D to locoex::TFLAveragePool2D
+ *
+ * How it works: (note: ten->fea means input: tensor, output: feature)
+ *
+ * Before:
+ * Foo ---- FeatureEncode ---- AvgPool2D ---- FeatureDecode ---- Bar
+ * ten->ten ten->fea fea->fea fea->ten ten->ten
+ *
+ * After: AvgPool2D
+ * /
+ * Foo -- FeatureEncode - FeatureDecode - TFLAvgPool2D - FeatureEncode - FeatureDecode -- Bar
+ * ten->ten ten->fea fea->ten ten->ten ten->fea fea->ten ten->ten
+ *
+ * @note This method replaces AvgPool2D with "FeatureDecode -- TFLAvgPool2D -- FeatureEncode".
+ * Redundant nodes will be removed during transforms.
+ */
+bool AvgPool2DConverter::convert(loco::AvgPool2D *origin)
+{
+ auto *graph = origin->graph();
+
+ auto dec = make_feature_decode<FeatureLayout::NHWC>(origin->ifm());
+ auto tfl_average = graph->nodes()->create<locoex::TFLAveragePool2D>();
+ {
+ tfl_average->value(dec);
+
+ // set attributes
+ tfl_average->stride()->w(origin->stride()->horizontal());
+ tfl_average->stride()->h(origin->stride()->vertical());
+
+ tfl_average->filter()->w(origin->window()->horizontal());
+ tfl_average->filter()->h(origin->window()->vertical());
+
+ auto pad = origin->pad();
+ if (pad->bottom() == 0 && pad->top() == 0 && pad->left() == 0 && pad->right() == 0)
+ tfl_average->padding(locoex::Padding::VALID);
+ else
+ // TODO This is a necessary, but not sufficient, condition; a more rigorous check is required
+ tfl_average->padding(locoex::Padding::SAME);
+
+ tfl_average->fusedActivationFunction(locoex::FusedActFunc::NONE);
+ }
+ auto enc = make_feature_encode<FeatureLayout::NHWC>(tfl_average);
+
+ // replace canonical node
+ loco::replace(origin).with(enc);
+ origin->ifm(nullptr);
+
+ return true;
+}
+
+} // namespace exo
diff --git a/compiler/exo/src/Conversion/AvgPool2DConverter.h b/compiler/exo/src/Conversion/AvgPool2DConverter.h
new file mode 100644
index 000000000..f66d02eb6
--- /dev/null
+++ b/compiler/exo/src/Conversion/AvgPool2DConverter.h
@@ -0,0 +1,41 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __CONVERSION_AVGPOOL2D_CONVERTER__
+#define __CONVERSION_AVGPOOL2D_CONVERTER__
+
+#include "CanonicalNodeConverter.h"
+
+#include <loco.h>
+
+namespace exo
+{
+
+/**
+ * @brief Convert loco::AvgPool2D to locoex::TFLAveragePool2D
+ */
+class AvgPool2DConverter : public CanonicalNodeConverter<loco::AvgPool2D>
+{
+public:
+ const char *name(void) const final { return "exo::AvgPool2DConverter"; }
+
+public:
+ bool convert(loco::AvgPool2D *origin) final;
+};
+
+} // namespace exo
+
+#endif // __CONVERSION_AVGPOOL2D_CONVERTER__
diff --git a/compiler/exo/src/Conversion/CanonicalNodeConverter.cpp b/compiler/exo/src/Conversion/CanonicalNodeConverter.cpp
new file mode 100644
index 000000000..4daf905f8
--- /dev/null
+++ b/compiler/exo/src/Conversion/CanonicalNodeConverter.cpp
@@ -0,0 +1,19 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CanonicalNodeConverter.h"
+
+// This file exists to make sure that "CanonicalNodeConverter.h" compiles on its own
diff --git a/compiler/exo/src/Conversion/CanonicalNodeConverter.h b/compiler/exo/src/Conversion/CanonicalNodeConverter.h
new file mode 100644
index 000000000..76f73d888
--- /dev/null
+++ b/compiler/exo/src/Conversion/CanonicalNodeConverter.h
@@ -0,0 +1,71 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __CONVERSION_CANONICAL_NODE_CONVERTER_H__
+#define __CONVERSION_CANONICAL_NODE_CONVERTER_H__
+
+#include "Convert.h"
+
+#include <loco.h>
+#include <loco/IR/CanonicalDialect.h>
+#include <logo/Pass.h>
+
+namespace exo
+{
+
+/**
+ * @brief Class to convert a canonical node to a TFL node
+ *
+ * TODO Find a better name
+ */
+template <typename CanonicalType> class CanonicalNodeConverter : public logo::Pass
+{
+public:
+ virtual const char *name(void) const { return nullptr; }
+
+public:
+ bool run(loco::Graph *graph);
+
+protected:
+ virtual bool convert(CanonicalType *node) = 0;
+};
+
+template <typename CanonicalType>
+bool CanonicalNodeConverter<CanonicalType>::run(loco::Graph *graph)
+{
+ auto active_nodes = loco::active_nodes(loco::output_nodes(graph));
+ bool changed = false;
+
+ for (auto node : active_nodes)
+ {
+ // TODO Generalize this to all loco dialects
+ if (node->dialect() == loco::CanonicalDialect::get())
+ {
+ auto the_node = dynamic_cast<CanonicalType *>(node);
+ if (the_node != nullptr)
+ {
+ if (convert(the_node))
+ changed = true;
+ }
+ }
+ }
+
+ return changed;
+}
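+
+// A minimal usage sketch (illustrative only; "DummyReluConverter" is
+// hypothetical -- see ReluConverter.h for the real one):
+//
+// struct DummyReluConverter final : public CanonicalNodeConverter<loco::ReLU>
+// {
+//   const char *name(void) const final { return "exo::DummyReluConverter"; }
+//
+//   bool convert(loco::ReLU *origin) final
+//   {
+//     auto tfl_relu = origin->graph()->nodes()->create<locoex::TFLRelu>();
+//     tfl_relu->features(origin->input());
+//     loco::replace(origin).with(tfl_relu);
+//     return true;
+//   }
+// };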
+
+} // namespace exo
+
+#endif //__CONVERSION_CANONICAL_NODE_CONVERTER_H__
diff --git a/compiler/exo/src/Conversion/ConstGenConverter.cpp b/compiler/exo/src/Conversion/ConstGenConverter.cpp
new file mode 100644
index 000000000..b2e2b4bdb
--- /dev/null
+++ b/compiler/exo/src/Conversion/ConstGenConverter.cpp
@@ -0,0 +1,60 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConstGenConverter.h"
+
+#include "Dialect/IR/TFLNodes.h"
+#include "Check.h"
+
+#include <loco.h>
+
+#include <oops/InternalExn.h>
+
+namespace exo
+{
+
+bool ConstGenConverter::convert(loco::ConstGen *constgen)
+{
+ auto *graph = constgen->graph();
+
+ auto tfl_const = graph->nodes()->create<locoex::TFLConst>();
+ {
+ if (constgen->dtype() == loco::DataType::FLOAT32)
+ {
+ tfl_const->dtype(loco::DataType::FLOAT32);
+
+ tfl_const->rank(constgen->rank());
+ for (uint32_t axis = 0; axis < constgen->rank(); axis++)
+ tfl_const->dim(axis) = constgen->dim(axis);
+
+ auto size = constgen->size<loco::DataType::FLOAT32>();
+ tfl_const->size<loco::DataType::FLOAT32>(size);
+
+ for (uint32_t i = 0; i < size; ++i)
+ {
+ tfl_const->at<loco::DataType::FLOAT32>(i) = constgen->at<loco::DataType::FLOAT32>(i);
+ }
+ }
+ else
+ INTERNAL_EXN_V("Unsupported DataType", oops::to_uint32(constgen->dtype()));
+ }
+
+ loco::replace(constgen).with(tfl_const);
+
+ return true;
+}
+
+} // namespace exo
diff --git a/compiler/exo/src/Conversion/ConstGenConverter.h b/compiler/exo/src/Conversion/ConstGenConverter.h
new file mode 100644
index 000000000..613ccd0e6
--- /dev/null
+++ b/compiler/exo/src/Conversion/ConstGenConverter.h
@@ -0,0 +1,38 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __CONVERSION_CONSTGEN_CONVERTER_H__
+#define __CONVERSION_CONSTGEN_CONVERTER_H__
+
+#include "CanonicalNodeConverter.h"
+
+#include <loco.h>
+
+namespace exo
+{
+
+class ConstGenConverter : public CanonicalNodeConverter<loco::ConstGen>
+{
+public:
+ const char *name(void) const final { return "exo::ConstGenConverter"; }
+
+public:
+ bool convert(loco::ConstGen *constgen) final;
+};
+
+} // namespace exo
+
+#endif // __CONVERSION_CONSTGEN_CONVERTER_H__
diff --git a/compiler/exo/src/Conversion/ConstGenConverter.test.cpp b/compiler/exo/src/Conversion/ConstGenConverter.test.cpp
new file mode 100644
index 000000000..f7a577242
--- /dev/null
+++ b/compiler/exo/src/Conversion/ConstGenConverter.test.cpp
@@ -0,0 +1,65 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConstGenConverter.h"
+#include "ReluConverter.h"
+
+#include "Dialect/IR/TFLNodes.h"
+#include "TestGraph.h"
+#include "TestHelper.h"
+
+#include <loco.h>
+
+#include <gtest/gtest.h>
+
+TEST(TFLConstGenConverterTest, ConstGen_Relu)
+{
+ exo::test::ExampleGraph<exo::test::ExampleGraphType::ConstGen_ReLU> g;
+
+ // set constgen
+ {
+ g.constgen->dtype(loco::DataType::FLOAT32);
+ g.constgen->shape({2, 1});
+ g.constgen->size<loco::DataType::FLOAT32>(2);
+
+ g.constgen->at<loco::DataType::FLOAT32>(0) = 0.5;
+ g.constgen->at<loco::DataType::FLOAT32>(1) = -0.5;
+ }
+
+ // let's convert
+ {
+ exo::test::TypeShapeReadyPhase test_phase;
+
+ test_phase.add_pass<exo::ConstGenConverter>();
+ test_phase.add_pass<exo::ReluConverter>();
+
+ test_phase.run(g.graph());
+ }
+
+ auto tfl_const = exo::test::find_first_node_bytype<locoex::TFLConst>(g.graph());
+ auto tfl_relu = exo::test::find_first_node_bytype<locoex::TFLRelu>(g.graph());
+
+ ASSERT_TRUE(tfl_const != nullptr and tfl_relu != nullptr);
+ ASSERT_TRUE(tfl_relu->features() == tfl_const);
+
+ ASSERT_TRUE(tfl_const->rank() == g.constgen->rank());
+ ASSERT_TRUE(tfl_const->dim(0) == g.constgen->dim(0));
+ ASSERT_TRUE(tfl_const->dim(1) == g.constgen->dim(1));
+ ASSERT_TRUE(tfl_const->at<loco::DataType::FLOAT32>(0) ==
+ g.constgen->at<loco::DataType::FLOAT32>(0));
+ ASSERT_TRUE(tfl_const->at<loco::DataType::FLOAT32>(1) ==
+ g.constgen->at<loco::DataType::FLOAT32>(1));
+}
diff --git a/compiler/exo/src/Conversion/Conv2DConverter.cpp b/compiler/exo/src/Conversion/Conv2DConverter.cpp
new file mode 100644
index 000000000..c8120171d
--- /dev/null
+++ b/compiler/exo/src/Conversion/Conv2DConverter.cpp
@@ -0,0 +1,97 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Conv2DConverter.h"
+
+#include "Dialect/IR/TFLNodes.h"
+
+#include "GraphBlock.h"
+#include "Check.h"
+
+#include <loco.h>
+#include <loco/Service/TypeInference.h>
+#include <loco/Service/ShapeInference.h>
+
+namespace exo
+{
+/**
+ * @brief Converts loco::Conv2D to locoex::TFLConv2D
+ * @note Because TFLConv2D accepts input and filter of loco::Domain::Tensor,
+ * a loco::FeatureDecode and a loco::FilterDecode will be inserted as inputs
+ * to meet the domain invariant.
+ * Please refer to the comment in AvgPool2DConverter.
+ */
+bool Conv2DConverter::convert(loco::Conv2D *origin)
+{
+ auto *graph = origin->graph();
+
+ assert(origin->ifm());
+ assert(origin->ker());
+
+ auto tfl_conv2d = graph->nodes()->create<locoex::TFLConv2D>();
+ {
+ tfl_conv2d->stride()->w(origin->stride()->horizontal());
+ tfl_conv2d->stride()->h(origin->stride()->vertical());
+
+ auto pad = origin->pad();
+ if (pad->bottom() == 0 && pad->top() == 0 && pad->left() == 0 && pad->right() == 0)
+ tfl_conv2d->padding(locoex::Padding::VALID);
+ else
+ // TODO This is a necessary, but not sufficient, condition; a more rigorous check is required
+ tfl_conv2d->padding(locoex::Padding::SAME);
+
+ tfl_conv2d->fusedActivationFunction(locoex::FusedActFunc::NONE);
+ }
+
+ // let's create a new graph connection with tfl_conv2d
+ {
+ // input
+ auto feature_dec = make_feature_decode<FeatureLayout::NHWC>(origin->ifm());
+ tfl_conv2d->input(feature_dec);
+
+ // filter
+ auto filter_dec = make_filter_decode<FilterLayout::OHWI>(origin->ker());
+ tfl_conv2d->filter(filter_dec);
+
+ // bias
+ auto zero_const = graph->nodes()->create<locoex::TFLConst>();
+ {
+ assert(loco::shape_known(origin));
+ assert(loco::dtype_known(origin) && loco::dtype_get(origin) == loco::DataType::FLOAT32);
+
+ auto output_depth = loco::shape_get(origin->ker()).as<loco::FilterShape>().count();
+
+ zero_const->dtype(loco::DataType::FLOAT32);
+ zero_const->rank(1);
+ zero_const->dim(0) = output_depth;
+ zero_const->size<loco::DataType::FLOAT32>(output_depth.value());
+ for (uint32_t x = 0; x < output_depth.value(); x++)
+ zero_const->at<loco::DataType::FLOAT32>(x) = 0.0;
+ }
+ tfl_conv2d->bias(zero_const);
+
+ // output
+ auto feature_enc = make_feature_encode<FeatureLayout::NHWC>(tfl_conv2d);
+
+ // replace canonical node
+ loco::replace(origin).with(feature_enc);
+ origin->ifm(nullptr);
+ }
+
+ return true;
+}
+
+} // namespace exo
diff --git a/compiler/exo/src/Conversion/Conv2DConverter.h b/compiler/exo/src/Conversion/Conv2DConverter.h
new file mode 100644
index 000000000..95b3fbfae
--- /dev/null
+++ b/compiler/exo/src/Conversion/Conv2DConverter.h
@@ -0,0 +1,41 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __CONVERSION_CONV2D_CONVERTER__
+#define __CONVERSION_CONV2D_CONVERTER__
+
+#include "CanonicalNodeConverter.h"
+
+#include <loco.h>
+
+namespace exo
+{
+
+/**
+ * @brief Convert loco::Conv2D to locoex::TFLConv2D
+ */
+class Conv2DConverter : public CanonicalNodeConverter<loco::Conv2D>
+{
+public:
+ const char *name(void) const final { return "exo::Conv2DConverter"; }
+
+public:
+ bool convert(loco::Conv2D *origin) final;
+};
+
+} // namespace exo
+
+#endif // __CONVERSION_CONV2D_CONVERTER__
diff --git a/compiler/exo/src/Conversion/DepthwiseConv2DConverter.cpp b/compiler/exo/src/Conversion/DepthwiseConv2DConverter.cpp
new file mode 100644
index 000000000..5959fcc45
--- /dev/null
+++ b/compiler/exo/src/Conversion/DepthwiseConv2DConverter.cpp
@@ -0,0 +1,114 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "DepthwiseConv2DConverter.h"
+
+#include "Dialect/IR/TFLNodes.h"
+
+#include "GraphBlock.h"
+#include "Check.h"
+
+#include <loco.h>
+#include <loco/Service/TypeInference.h>
+#include <loco/Service/ShapeInference.h>
+
+namespace exo
+{
+
+bool DepthwiseConv2DConverter::convert(loco::DepthwiseConv2D *origin)
+{
+ // Filter shape is required
+ if (not loco::shape_known(origin->ker()))
+ return false;
+
+ auto filter_shape = loco::shape_get(origin->ker()).as<loco::DepthwiseFilterShape>();
+
+ if ((origin->ifm() == nullptr) or (origin->ker() == nullptr))
+ return false;
+
+ auto *graph = origin->graph();
+
+ auto tfl_dw_conv2d = graph->nodes()->create<locoex::TFLDepthwiseConv2D>();
+ {
+ tfl_dw_conv2d->stride()->w(origin->stride()->horizontal());
+ tfl_dw_conv2d->stride()->h(origin->stride()->vertical());
+
+ auto pad = origin->pad();
+ if (pad->bottom() == 0 && pad->top() == 0 && pad->left() == 0 && pad->right() == 0)
+ tfl_dw_conv2d->padding(locoex::Padding::VALID);
+ else
+      // TODO This is a necessary but not sufficient condition; a more rigorous check is required
+ tfl_dw_conv2d->padding(locoex::Padding::SAME);
+
+ tfl_dw_conv2d->fusedActivationFunction(locoex::FusedActFunc::NONE);
+
+ uint32_t multiplier = filter_shape.multiplier().value();
+    EXO_ASSERT(multiplier < std::numeric_limits<int32_t>::max(),
+               "Multiplier is too big; the cast below may cause unintended behavior")
+
+ tfl_dw_conv2d->depthMultiplier(static_cast<int32_t>(multiplier));
+ }
+
+ // let's create a new graph connection with tfl_dw_conv2d
+ {
+ // ifm --- feature_dec --- tfl_dw_conv2d
+ auto feature_dec = make_feature_decode<FeatureLayout::NHWC>(origin->ifm());
+ tfl_dw_conv2d->input(feature_dec);
+
+ // ker --- filter_dec(H x W x C x M) --- reshape(1 x H x W x CM) --- tfl_dw_conv2d
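+    // For example (hypothetical shapes): a 3x3 depthwise filter over 8 input
+    // channels with multiplier 2 has HWCM shape 3x3x8x2; the reshape below
+    // turns it into the 1x3x3x16 layout that TFLDepthwiseConv2D expects.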
+ auto filter_dec = make_dw_filter_decode<DepthwiseFilterLayout::HWCM>(origin->ker());
+
+ auto reshape = graph->nodes()->create<locoex::TFLReshape>();
+ reshape->tensor(filter_dec);
+
+ int32_t new_shape[4] = {
+ 1, static_cast<int32_t>(filter_shape.height().value()),
+ static_cast<int32_t>(filter_shape.width().value()),
+ static_cast<int32_t>(filter_shape.depth().value() * filter_shape.multiplier().value())};
+ locoex::set_new_shape(reshape, new_shape, 4);
+
+ tfl_dw_conv2d->filter(reshape);
+
+ // bias
+ auto zero_const = graph->nodes()->create<locoex::TFLConst>();
+ {
+ assert(loco::shape_known(origin));
+ assert(loco::dtype_known(origin) && loco::dtype_get(origin) == loco::DataType::FLOAT32);
+
+ // bias size is C * M
+ uint32_t bias_size = filter_shape.depth().value() * filter_shape.multiplier().value();
+
+ zero_const->dtype(loco::DataType::FLOAT32);
+ zero_const->rank(1);
+ zero_const->dim(0) = bias_size;
+ zero_const->size<loco::DataType::FLOAT32>(bias_size);
+ for (uint32_t x = 0; x < bias_size; x++)
+ zero_const->at<loco::DataType::FLOAT32>(x) = 0.0;
+ }
+ tfl_dw_conv2d->bias(zero_const);
+
+ // output
+ auto feature_enc = make_feature_encode<FeatureLayout::NHWC>(tfl_dw_conv2d);
+
+ // replace canonical node
+ loco::replace(origin).with(feature_enc);
+ origin->ifm(nullptr);
+ }
+
+ return true;
+}
+
+} // namespace exo
diff --git a/compiler/exo/src/Conversion/DepthwiseConv2DConverter.h b/compiler/exo/src/Conversion/DepthwiseConv2DConverter.h
new file mode 100644
index 000000000..57cc01e5e
--- /dev/null
+++ b/compiler/exo/src/Conversion/DepthwiseConv2DConverter.h
@@ -0,0 +1,61 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __CONVERSION_DEPTHWISECONV2D_CONVERTER__
+#define __CONVERSION_DEPTHWISECONV2D_CONVERTER__
+
+#include "CanonicalNodeConverter.h"
+
+#include <loco.h>
+
+namespace exo
+{
+
+/**
+ * @brief Convert loco::DepthwiseConv2D to locoex::TFLDepthwiseConv2D and auxiliary nodes
+ *
+ * <BEFORE>
+ *
+ * IFM -------- DepthwiseConv2D --- Out
+ * [Feature] / [Feature]
+ * /
+ * KER -------
+ * [DWFilter]
+ *
+ *
+ * <AFTER>
+ * TFLConst (bias) ---------------------------
+ * \
+ * IFM ------ FeatureDecode ------------------ TFLDepthwiseConv2D --- FeatureEncode --- Out
+ * [Feature] [Tensor] / [Tensor] [Feature]
+ * /
+ * KER ------- DepthwiseFilterDecode --- TFLReshape
+ * [DWFilter] [Tensor / H W C M] [Tensor / 1 H W CM]
+ *
+ */
+class DepthwiseConv2DConverter : public CanonicalNodeConverter<loco::DepthwiseConv2D>
+{
+public:
+ const char *name(void) const final { return "exo::DepthwiseConv2DConverter"; }
+
+public:
+ bool convert(loco::DepthwiseConv2D *origin) final;
+};
+
+} // namespace exo
+
+#endif // __CONVERSION_DEPTHWISECONV2D_CONVERTER__
diff --git a/compiler/exo/src/Conversion/EltwiseAddConverter.cpp b/compiler/exo/src/Conversion/EltwiseAddConverter.cpp
new file mode 100644
index 000000000..557f47944
--- /dev/null
+++ b/compiler/exo/src/Conversion/EltwiseAddConverter.cpp
@@ -0,0 +1,29 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "EltwiseAddConverter.h"
+
+#include "EltwiseBinaryConverter.h"
+
+namespace exo
+{
+
+bool EltwiseAddConverter::convert(loco::EltwiseAdd *origin)
+{
+ return EltwiseBinaryConvert<loco::EltwiseAdd, locoex::TFLAdd>(origin);
+}
+
+} // namespace exo
diff --git a/compiler/exo/src/Conversion/EltwiseAddConverter.h b/compiler/exo/src/Conversion/EltwiseAddConverter.h
new file mode 100644
index 000000000..97e1071b5
--- /dev/null
+++ b/compiler/exo/src/Conversion/EltwiseAddConverter.h
@@ -0,0 +1,41 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __CONVERSION_ELTWISEADD_CONVERTER_H__
+#define __CONVERSION_ELTWISEADD_CONVERTER_H__
+
+#include "CanonicalNodeConverter.h"
+
+#include <loco.h>
+
+namespace exo
+{
+
+/**
+ * @brief Convert loco::EltwiseAdd to TFLAdd
+ */
+class EltwiseAddConverter : public CanonicalNodeConverter<loco::EltwiseAdd>
+{
+public:
+ const char *name(void) const final { return "exo::EltwiseAddConverter"; }
+
+public:
+ bool convert(loco::EltwiseAdd *origin) final;
+};
+
+} // namespace exo
+
+#endif // __CONVERSION_ELTWISEADD_CONVERTER_H__
diff --git a/compiler/exo/src/Conversion/EltwiseBinaryConverter.h b/compiler/exo/src/Conversion/EltwiseBinaryConverter.h
new file mode 100644
index 000000000..095da9e5c
--- /dev/null
+++ b/compiler/exo/src/Conversion/EltwiseBinaryConverter.h
@@ -0,0 +1,110 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __CONVERSION_ELTWISEBINARY_CONVERTER_H__
+#define __CONVERSION_ELTWISEBINARY_CONVERTER_H__
+
+#include "GraphBlock.h"
+#include "Check.h"
+
+#include "Dialect/IR/TFLNodes.h"
+
+#include <loco/IR/Nodes.h>
+
+#include <loco/Service/ShapeInference.h>
+
+namespace
+{
+
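+// Note: exo::InputHandler (see GraphBlock.h) describes how operand links move
+// from the canonical node to its TFL replacement; exo::DomainConverter then
+// wraps the converted inputs and output with the Encode/Decode nodes needed to
+// keep loco domain invariants. This is a rough summary, not the full contract.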
+template <class ELTWISEBIN, class TFLBIN>
+class EltwiseBinInputHandler : public exo::InputHandler<ELTWISEBIN, TFLBIN>
+{
+public:
+ void handover(ELTWISEBIN *origin, TFLBIN *replacer) override
+ {
+ assert(origin && replacer);
+ replacer->x(origin->lhs());
+ replacer->y(origin->rhs());
+ }
+
+ std::vector<loco::Node *> getInputsToConvert(ELTWISEBIN *origin) override
+ {
+ assert(origin);
+ std::vector<loco::Node *> inputs({origin->lhs(), origin->rhs()});
+ return inputs;
+ }
+
+ void set(TFLBIN *replacer, std::vector<loco::Node *> &to) override
+ {
+ assert(to.size() == 2);
+
+ replacer->x(to.at(0));
+ replacer->y(to.at(1));
+ }
+
+ void nullify(ELTWISEBIN *origin) override
+ {
+ assert(origin);
+ origin->lhs(nullptr);
+ origin->rhs(nullptr);
+ }
+};
+
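+// init_fused_act_func is declared but deliberately left undefined: only TFL
+// binary nodes that carry a fused activation attribute get a specialization
+// below, so instantiating it with any other node type fails at link time.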
+template <class TFLBIN> void init_fused_act_func(TFLBIN *);
+
+template <> inline void init_fused_act_func(locoex::TFLAdd *node)
+{
+ node->fusedActivationFunction(locoex::FusedActFunc::NONE);
+}
+
+template <> inline void init_fused_act_func(locoex::TFLMul *node)
+{
+ node->fusedActivationFunction(locoex::FusedActFunc::NONE);
+}
+
+template <> inline void init_fused_act_func(locoex::TFLSub *node)
+{
+ node->fusedActivationFunction(locoex::FusedActFunc::NONE);
+}
+
+template <> inline void init_fused_act_func(locoex::TFLDiv *node)
+{
+ node->fusedActivationFunction(locoex::FusedActFunc::NONE);
+}
+
+} // namespace
+
+namespace exo
+{
+
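+/**
+ * @brief Convert a canonical element-wise binary node to its TFL counterpart
+ *
+ * A typical call site looks like this (see EltwiseAddConverter.cpp):
+ * @code
+ * bool EltwiseAddConverter::convert(loco::EltwiseAdd *origin)
+ * {
+ *   return EltwiseBinaryConvert<loco::EltwiseAdd, locoex::TFLAdd>(origin);
+ * }
+ * @endcode
+ */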
+template <class ELTWISEBIN, class TFLBIN> bool EltwiseBinaryConvert(ELTWISEBIN *origin)
+{
+ EltwiseBinInputHandler<ELTWISEBIN, TFLBIN> input_handler;
+ exo::DomainConverter<ELTWISEBIN, TFLBIN> domain_converter;
+
+ auto tfl_node = domain_converter.template convert<FeatureLayout::NHWC>(origin, input_handler);
+
+ if (tfl_node == nullptr)
+ return false;
+
+ init_fused_act_func(tfl_node);
+
+ return true;
+}
+
+} // namespace exo
+
+#endif // __CONVERSION_ELTWISEBINARY_CONVERTER_H__
diff --git a/compiler/exo/src/Conversion/EltwiseDivConverter.cpp b/compiler/exo/src/Conversion/EltwiseDivConverter.cpp
new file mode 100644
index 000000000..dc8eae461
--- /dev/null
+++ b/compiler/exo/src/Conversion/EltwiseDivConverter.cpp
@@ -0,0 +1,29 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "EltwiseDivConverter.h"
+
+#include "EltwiseBinaryConverter.h"
+
+namespace exo
+{
+
+bool EltwiseDivConverter::convert(loco::EltwiseDiv *origin)
+{
+ return EltwiseBinaryConvert<loco::EltwiseDiv, locoex::TFLDiv>(origin);
+}
+
+} // namespace exo
diff --git a/compiler/exo/src/Conversion/EltwiseDivConverter.h b/compiler/exo/src/Conversion/EltwiseDivConverter.h
new file mode 100644
index 000000000..06b2d685b
--- /dev/null
+++ b/compiler/exo/src/Conversion/EltwiseDivConverter.h
@@ -0,0 +1,41 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __CONVERSION_ELTWISEDIV_CONVERTER_H__
+#define __CONVERSION_ELTWISEDIV_CONVERTER_H__
+
+#include "CanonicalNodeConverter.h"
+
+#include <loco.h>
+
+namespace exo
+{
+
+/**
+ * @brief Convert loco::EltwiseDiv to TFLDiv
+ */
+class EltwiseDivConverter : public CanonicalNodeConverter<loco::EltwiseDiv>
+{
+public:
+ const char *name(void) const final { return "exo::EltwiseDivConverter"; }
+
+public:
+ bool convert(loco::EltwiseDiv *origin) final;
+};
+
+} // namespace exo
+
+#endif // __CONVERSION_ELTWISEDIV_CONVERTER_H__
diff --git a/compiler/exo/src/Conversion/EltwiseMaxConverter.cpp b/compiler/exo/src/Conversion/EltwiseMaxConverter.cpp
new file mode 100644
index 000000000..dd7d34440
--- /dev/null
+++ b/compiler/exo/src/Conversion/EltwiseMaxConverter.cpp
@@ -0,0 +1,75 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "EltwiseMaxConverter.h"
+
+#include "GraphBlock.h"
+#include "Check.h"
+
+#include "Dialect/IR/TFLNodes.h"
+
+#include <loco/Service/ShapeInference.h>
+
+namespace
+{
+
+class EltwiseMaxInputHandler : public exo::InputHandler<loco::EltwiseMax, locoex::TFLMaximum>
+{
+public:
+ void handover(loco::EltwiseMax *origin, locoex::TFLMaximum *replacer) override
+ {
+ replacer->x(origin->lhs());
+ replacer->y(origin->rhs());
+ }
+
+ std::vector<loco::Node *> getInputsToConvert(loco::EltwiseMax *origin) override
+ {
+ std::vector<loco::Node *> inputs({origin->lhs(), origin->rhs()});
+ return inputs;
+ }
+
+ void set(locoex::TFLMaximum *replacer, std::vector<loco::Node *> &to) override
+ {
+ assert(to.size() == 2);
+
+ replacer->x(to.at(0));
+ replacer->y(to.at(1));
+ }
+
+ void nullify(loco::EltwiseMax *origin) override
+ {
+ assert(origin);
+ origin->lhs(nullptr);
+ origin->rhs(nullptr);
+ }
+};
+
+} // namespace
+
+namespace exo
+{
+
+bool EltwiseMaxConverter::convert(loco::EltwiseMax *origin)
+{
+ EltwiseMaxInputHandler input_handler;
+ exo::DomainConverter<loco::EltwiseMax, locoex::TFLMaximum> domain_converter;
+
+ auto tfl_new = domain_converter.convert<FeatureLayout::NHWC>(origin, input_handler);
+
+ return (tfl_new != nullptr);
+}
+
+} // namespace exo
diff --git a/compiler/exo/src/Conversion/EltwiseMaxConverter.h b/compiler/exo/src/Conversion/EltwiseMaxConverter.h
new file mode 100644
index 000000000..708745419
--- /dev/null
+++ b/compiler/exo/src/Conversion/EltwiseMaxConverter.h
@@ -0,0 +1,41 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __CONVERSION_ELTWISEMAX_CONVERTER_H__
+#define __CONVERSION_ELTWISEMAX_CONVERTER_H__
+
+#include "CanonicalNodeConverter.h"
+
+#include <loco.h>
+
+namespace exo
+{
+
+/**
+ * @brief Convert loco::EltwiseMax to TFLMaximum
+ */
+class EltwiseMaxConverter : public CanonicalNodeConverter<loco::EltwiseMax>
+{
+public:
+ const char *name(void) const final { return "exo::EltwiseMaxConverter"; }
+
+public:
+ bool convert(loco::EltwiseMax *origin) final;
+};
+
+} // namespace exo
+
+#endif // __CONVERSION_ELTWISEMAX_CONVERTER_H__
diff --git a/compiler/exo/src/Conversion/EltwiseMulConverter.cpp b/compiler/exo/src/Conversion/EltwiseMulConverter.cpp
new file mode 100644
index 000000000..f7a4b8298
--- /dev/null
+++ b/compiler/exo/src/Conversion/EltwiseMulConverter.cpp
@@ -0,0 +1,29 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "EltwiseMulConverter.h"
+
+#include "EltwiseBinaryConverter.h"
+
+namespace exo
+{
+
+bool EltwiseMulConverter::convert(loco::EltwiseMul *origin)
+{
+ return EltwiseBinaryConvert<loco::EltwiseMul, locoex::TFLMul>(origin);
+}
+
+} // namespace exo
diff --git a/compiler/exo/src/Conversion/EltwiseMulConverter.h b/compiler/exo/src/Conversion/EltwiseMulConverter.h
new file mode 100644
index 000000000..4f73484c0
--- /dev/null
+++ b/compiler/exo/src/Conversion/EltwiseMulConverter.h
@@ -0,0 +1,41 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __CONVERSION_ELTWISEMUL_CONVERTER_H__
+#define __CONVERSION_ELTWISEMUL_CONVERTER_H__
+
+#include "CanonicalNodeConverter.h"
+
+#include <loco.h>
+
+namespace exo
+{
+
+/**
+ * @brief Convert loco::EltwiseMul to TFLMul
+ */
+class EltwiseMulConverter : public CanonicalNodeConverter<loco::EltwiseMul>
+{
+public:
+ const char *name(void) const final { return "exo::EltwiseMulConverter"; }
+
+public:
+ bool convert(loco::EltwiseMul *origin) final;
+};
+
+} // namespace exo
+
+#endif // __CONVERSION_ELTWISEMUL_CONVERTER_H__
diff --git a/compiler/exo/src/Conversion/EltwiseSqrtConverter.cpp b/compiler/exo/src/Conversion/EltwiseSqrtConverter.cpp
new file mode 100644
index 000000000..6dead7dc6
--- /dev/null
+++ b/compiler/exo/src/Conversion/EltwiseSqrtConverter.cpp
@@ -0,0 +1,68 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "EltwiseSqrtConverter.h"
+
+#include "GraphBlock.h"
+#include "Check.h"
+
+#include "Dialect/IR/TFLNodes.h"
+
+#include <loco/Service/ShapeInference.h>
+
+namespace
+{
+
+class EltwiseSqrtInputHandler : public exo::InputHandler<loco::EltwiseSqrt, locoex::TFLSqrt>
+{
+public:
+ void handover(loco::EltwiseSqrt *origin, locoex::TFLSqrt *replacer) override
+ {
+ replacer->x(origin->input());
+ }
+
+ std::vector<loco::Node *> getInputsToConvert(loco::EltwiseSqrt *origin) override
+ {
+ std::vector<loco::Node *> inputs({origin->input()});
+ return inputs;
+ }
+
+ void set(locoex::TFLSqrt *replacer, std::vector<loco::Node *> &to) override
+ {
+ assert(to.size() == 1);
+
+ replacer->x(to.at(0));
+ }
+
+ void nullify(loco::EltwiseSqrt *origin) override { origin->input(nullptr); }
+};
+
+} // namespace
+
+namespace exo
+{
+
+bool EltwiseSqrtConverter::convert(loco::EltwiseSqrt *origin)
+{
+ EltwiseSqrtInputHandler input_handler;
+ exo::DomainConverter<loco::EltwiseSqrt, locoex::TFLSqrt> domain_converter;
+
+ auto tfl_new = domain_converter.convert<FeatureLayout::NHWC>(origin, input_handler);
+
+ return (tfl_new != nullptr);
+}
+
+} // namespace exo
diff --git a/compiler/exo/src/Conversion/EltwiseSqrtConverter.h b/compiler/exo/src/Conversion/EltwiseSqrtConverter.h
new file mode 100644
index 000000000..5ee3185ff
--- /dev/null
+++ b/compiler/exo/src/Conversion/EltwiseSqrtConverter.h
@@ -0,0 +1,41 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __CONVERSION_ELTWISESQRT_CONVERTER_H__
+#define __CONVERSION_ELTWISESQRT_CONVERTER_H__
+
+#include "CanonicalNodeConverter.h"
+
+#include <loco.h>
+
+namespace exo
+{
+
+/**
+ * @brief Convert loco::EltwiseSqrt to TFLSqrt
+ */
+class EltwiseSqrtConverter : public CanonicalNodeConverter<loco::EltwiseSqrt>
+{
+public:
+ const char *name(void) const final { return "exo::EltwiseSqrtConverter"; }
+
+public:
+ bool convert(loco::EltwiseSqrt *origin) final;
+};
+
+} // namespace exo
+
+#endif // __CONVERSION_ELTWISESQRT_CONVERTER_H__
diff --git a/compiler/exo/src/Conversion/EltwiseSubConverter.cpp b/compiler/exo/src/Conversion/EltwiseSubConverter.cpp
new file mode 100644
index 000000000..5647c47a2
--- /dev/null
+++ b/compiler/exo/src/Conversion/EltwiseSubConverter.cpp
@@ -0,0 +1,29 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "EltwiseSubConverter.h"
+
+#include "EltwiseBinaryConverter.h"
+
+namespace exo
+{
+
+bool EltwiseSubConverter::convert(loco::EltwiseSub *origin)
+{
+ return EltwiseBinaryConvert<loco::EltwiseSub, locoex::TFLSub>(origin);
+}
+
+} // namespace exo
diff --git a/compiler/exo/src/Conversion/EltwiseSubConverter.h b/compiler/exo/src/Conversion/EltwiseSubConverter.h
new file mode 100644
index 000000000..d61b76ec0
--- /dev/null
+++ b/compiler/exo/src/Conversion/EltwiseSubConverter.h
@@ -0,0 +1,41 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __CONVERSION_ELTWISESUB_CONVERTER_H__
+#define __CONVERSION_ELTWISESUB_CONVERTER_H__
+
+#include "CanonicalNodeConverter.h"
+
+#include <loco.h>
+
+namespace exo
+{
+
+/**
+ * @brief Convert loco::EltwiseSub to TFLSub
+ */
+class EltwiseSubConverter : public CanonicalNodeConverter<loco::EltwiseSub>
+{
+public:
+ const char *name(void) const final { return "exo::EltwiseSubConverter"; }
+
+public:
+ bool convert(loco::EltwiseSub *origin) final;
+};
+
+} // namespace exo
+
+#endif // __CONVERSION_ELTWISESUB_CONVERTER_H__
diff --git a/compiler/exo/src/Conversion/FeatureBiasAddConverter.cpp b/compiler/exo/src/Conversion/FeatureBiasAddConverter.cpp
new file mode 100644
index 000000000..b9aaf140b
--- /dev/null
+++ b/compiler/exo/src/Conversion/FeatureBiasAddConverter.cpp
@@ -0,0 +1,91 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "FeatureBiasAddConverter.h"
+
+#include "Dialect/IR/TFLNodes.h"
+
+#include "GraphBlock.h"
+
+#include <loco.h>
+#include <loco/Service/ShapeInference.h>
+
+#include <cassert>
+
+namespace
+{
+
+inline void init_fused_act_func(locoex::TFLAdd *node)
+{
+ node->fusedActivationFunction(locoex::FusedActFunc::NONE);
+}
+
+} // namespace
+
+namespace exo
+{
+
+/**
+ * @brief Converts loco::FeatureBiasAdd to locoex::TFLAdd
+ *
+ * Before:
+ * Foo ---+
+ * |
+ * loco::FeatureBiasAdd - FeatureDecode - ...
+ * |
+ * Bar - BiasEncode --+
+ *
+ * After:
+ *
+ * Foo - loco::FeatureDecode --+ loco::FeatureBiasAdd
+ * |(x)
+ * TFLAdd -- loco::FeatureEncode - FeatureDecode - ...
+ * |(y)
+ * Bar - BiasEncode - loco::BiasDecode --+
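+ *
+ * @note Redundant FeatureEncode/FeatureDecode and BiasEncode/BiasDecode pairs
+ *       introduced here are expected to be removed by later transforms.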
+ */
+bool FeatureBiasAddConverter::convert(loco::FeatureBiasAdd *origin)
+{
+ auto *graph = origin->graph();
+
+ auto tfl_add = graph->nodes()->create<locoex::TFLAdd>();
+
+ // handling input x
+ assert(loco::shape_get(origin->value()).domain() == loco::Domain::Feature);
+
+ auto fea_dec = make_feature_decode<FeatureLayout::NHWC>(origin->value());
+ tfl_add->x(fea_dec);
+
+ // handling input y
+ auto bias_dec = graph->nodes()->create<loco::BiasDecode>();
+ assert(bias_dec != nullptr);
+
+ bias_dec->input(origin->bias());
+
+ tfl_add->y(bias_dec);
+
+ // fused activation function
+ init_fused_act_func(tfl_add);
+
+ // handling output
+ auto fea_enc = make_feature_encode<FeatureLayout::NHWC>(tfl_add);
+
+ loco::replace(origin).with(fea_enc);
+ origin->value(nullptr);
+
+ return true;
+}
+
+} // namespace exo
diff --git a/compiler/exo/src/Conversion/FeatureBiasAddConverter.h b/compiler/exo/src/Conversion/FeatureBiasAddConverter.h
new file mode 100644
index 000000000..5c4f10213
--- /dev/null
+++ b/compiler/exo/src/Conversion/FeatureBiasAddConverter.h
@@ -0,0 +1,38 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __CONVERSION_FEATUREBIASADD_CONVERTER__
+#define __CONVERSION_FEATUREBIASADD_CONVERTER__
+
+#include "CanonicalNodeConverter.h"
+
+#include <loco.h>
+
+namespace exo
+{
+
+class FeatureBiasAddConverter : public CanonicalNodeConverter<loco::FeatureBiasAdd>
+{
+public:
+ const char *name(void) const final { return "exo::TFLAddConverter"; }
+
+public:
+ bool convert(loco::FeatureBiasAdd *origin) final;
+};
+
+} // namespace exo
+
+#endif // __CONVERSION_FEATUREBIASADD_CONVERTER__
diff --git a/compiler/exo/src/Conversion/FeatureBiasAddConverter.test.cpp b/compiler/exo/src/Conversion/FeatureBiasAddConverter.test.cpp
new file mode 100644
index 000000000..f3c4a5f81
--- /dev/null
+++ b/compiler/exo/src/Conversion/FeatureBiasAddConverter.test.cpp
@@ -0,0 +1,102 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "FeatureBiasAddConverter.h"
+
+#include "GraphBlock.h"
+#include "Dialect/IR/TFLNodes.h"
+
+#include "TestGraph.h"
+#include "TestHelper.h"
+
+#include <loco.h>
+
+#include <gtest/gtest.h>
+
+TEST(FeatureBiasAddConverterTest, basic_test)
+{
+ exo::test::ExampleGraph<exo::test::ExampleGraphType::FeatureBiasAdd> g;
+
+ { // attrib setting
+ // pull
+ g.pull->dtype(loco::DataType::FLOAT32);
+ g.pull->shape({1, 2, 2, 3});
+
+ // bias value
+ g.constgen->dtype(loco::DataType::FLOAT32);
+ g.constgen->shape({3});
+ g.constgen->size<loco::DataType::FLOAT32>(3);
+
+ g.constgen->at<loco::DataType::FLOAT32>(0) = 0.5;
+ g.constgen->at<loco::DataType::FLOAT32>(1) = 1;
+ g.constgen->at<loco::DataType::FLOAT32>(2) = 1.5;
+ }
+
+ EXO_TEST_ASSERT_NODE_COUNT({g.push}, 7); // sanity check
+
+ // let's convert!!
+ {
+ exo::test::TypeShapeReadyPhase test_phase;
+
+ test_phase.add_pass<exo::FeatureBiasAddConverter>();
+
+ test_phase.run(g.graph());
+
+ /*
+ Expected:
+
+ Pull - FeatureEncoder - FeatureDecode - TFLAdd - FeatureEncode - FeatureDecode - Push
+ |
+ ConstGen - BiasEncode - BiasDecode ---+
+ */
+ }
+
+ // check surroundings
+ auto tfl_add = exo::test::find_first_node_bytype<locoex::TFLAdd>(g.graph());
+ {
+ ASSERT_TRUE(tfl_add != nullptr);
+
+ // input x and its pred
+ {
+ auto actual_fea_dec = dynamic_cast<loco::FeatureDecode *>(tfl_add->x());
+ ASSERT_TRUE(actual_fea_dec != nullptr);
+
+ auto actual_fea_enc = dynamic_cast<loco::FeatureEncode *>(actual_fea_dec->input());
+ ASSERT_TRUE(actual_fea_enc != nullptr);
+ ASSERT_TRUE(actual_fea_enc == g.fea_enc);
+ }
+
+ // input y and its pred
+ {
+ auto actual_bias_dec = dynamic_cast<loco::BiasDecode *>(tfl_add->y());
+ ASSERT_TRUE(actual_bias_dec != nullptr);
+
+ auto actual_bias_enc = dynamic_cast<loco::BiasEncode *>(actual_bias_dec->input());
+ ASSERT_TRUE(actual_bias_enc != nullptr);
+ ASSERT_TRUE(actual_bias_enc == g.bias_enc);
+ }
+
+ // output check
+ {
+ auto actual_fea_enc = exo::test::get_only_succ<loco::FeatureEncode>(tfl_add);
+ ASSERT_TRUE(actual_fea_enc != nullptr);
+
+ auto actual_fea_dec = exo::test::get_only_succ<loco::FeatureDecode>(actual_fea_enc);
+ ASSERT_TRUE(actual_fea_dec != nullptr);
+ ASSERT_TRUE(actual_fea_dec == g.fea_dec);
+ }
+ }
+}
diff --git a/compiler/exo/src/Conversion/MatMulConverter.cpp b/compiler/exo/src/Conversion/MatMulConverter.cpp
new file mode 100644
index 000000000..b1158b73d
--- /dev/null
+++ b/compiler/exo/src/Conversion/MatMulConverter.cpp
@@ -0,0 +1,103 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "MatMulConverter.h"
+
+#include "Dialect/IR/TFLNodes.h"
+
+#include "GraphBlock.h"
+#include "Check.h"
+
+#include <loco.h>
+#include <loco/Service/TypeInference.h>
+#include <loco/Service/ShapeInference.h>
+
+namespace exo
+{
+/**
+ * @brief Converts loco::MatMul to locoex::TFLFullyConnected
+ * @note Because TFLFullyConnected accepts plain tensors rather than
+ *       loco::Domain::Matrix, loco::MatrixDecode nodes are inserted before the
+ *       input and the weights to meet the domain invariant.
+ *
+ * How it works:
+ *
+ * Before:
+ * Foo1 ---- MatrixEncode ---- MatMul ---- MatrixDecode ---- Bar
+ * Foo2 ---- MatrixEncode ----/
+ *
+ * After:
+ *
+ * Foo1 - MatrixEncode - MatrixDecode - TFLFullyConnected - MatrixEncode - MatrixDecode - Bar
+ * Foo2 - MatrixEncode - MatrixDecode -/
+ *
+ * @note This method replaces MatMul with "- MatrixDecode - TFLFullyConnected - MatrixEncode -".
+ * - MatrixDecode -/
+ * Redundant nodes will be removed during transforms.
+ *
+ * @ref
+ * https://github.com/tensorflow/tensorflow/blob/v1.13.1/tensorflow/lite/kernels/internal/reference/fully_connected.h
+ */
+bool MatMulConverter::convert(loco::MatMul *origin)
+{
+ auto *graph = origin->graph();
+
+ assert(origin->lhs());
+ assert(origin->rhs());
+
+ auto tfl_fc = graph->nodes()->create<locoex::TFLFullyConnected>();
+ tfl_fc->fusedActivationFunction(locoex::FusedActFunc::NONE);
+
+ // let's create a new graph connection with tfl_fc
+ {
+ // input
+ auto lhs_matrix_dec = make_matrix_decode<MatrixLayout::HW>(origin->lhs());
+ tfl_fc->input(lhs_matrix_dec);
+
+ // weights (WH format on TFLite)
+ auto rhs_matrix_dec = make_matrix_decode<MatrixLayout::WH>(origin->rhs());
+ tfl_fc->weights(rhs_matrix_dec);
+
+ // bias
+ auto zero_const = graph->nodes()->create<locoex::TFLConst>();
+ { // TODO Create optimization pass which fuse additional Add into bias of Conv or FC
+ assert(loco::shape_known(origin));
+ assert(loco::dtype_known(origin) && loco::dtype_get(origin) == loco::DataType::FLOAT32);
+
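+      // For example (hypothetical shapes): multiplying a 4x16 lhs by a 16x8
+      // rhs yields an output depth of 8, i.e. the width of the rhs matrix.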
+ auto output_depth = loco::shape_get(origin->rhs()).as<loco::MatrixShape>().width();
+ // TODO Fix it with type inference
+ zero_const->dtype(loco::DataType::FLOAT32);
+ zero_const->rank(1);
+ zero_const->dim(0) = output_depth;
+ zero_const->size<loco::DataType::FLOAT32>(output_depth.value());
+ for (uint32_t x = 0; x < output_depth.value(); x++)
+ zero_const->at<loco::DataType::FLOAT32>(x) = 0.0;
+ }
+ tfl_fc->bias(zero_const);
+
+ // output
+ auto matrix_enc = make_matrix_encode<MatrixLayout::HW>(tfl_fc);
+
+ // replace canonical node
+ loco::replace(origin).with(matrix_enc);
+ origin->lhs(nullptr);
+ origin->rhs(nullptr);
+ }
+
+ return true;
+}
+
+} // namespace exo
diff --git a/compiler/exo/src/Conversion/MatMulConverter.h b/compiler/exo/src/Conversion/MatMulConverter.h
new file mode 100644
index 000000000..e64c4a0f2
--- /dev/null
+++ b/compiler/exo/src/Conversion/MatMulConverter.h
@@ -0,0 +1,41 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __CONVERSION_MATMUL_CONVERTER__
+#define __CONVERSION_MATMUL_CONVERTER__
+
+#include "CanonicalNodeConverter.h"
+
+#include <loco.h>
+
+namespace exo
+{
+
+/**
+ * @brief Convert loco::MatMul to locoex::TFLFullyConnected
+ */
+class MatMulConverter : public CanonicalNodeConverter<loco::MatMul>
+{
+public:
+ const char *name(void) const final { return "exo::MatMulConverter"; }
+
+public:
+ bool convert(loco::MatMul *origin) final;
+};
+
+} // namespace exo
+
+#endif // __CONVERSION_MATMUL_CONVERTER__
diff --git a/compiler/exo/src/Conversion/MaxPool2DConverter.cpp b/compiler/exo/src/Conversion/MaxPool2DConverter.cpp
new file mode 100644
index 000000000..67e5ab833
--- /dev/null
+++ b/compiler/exo/src/Conversion/MaxPool2DConverter.cpp
@@ -0,0 +1,67 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "MaxPool2DConverter.h"
+
+#include "Dialect/IR/TFLNodes.h"
+#include "GraphBlock.h"
+
+#include <loco.h>
+
+namespace exo
+{
+
+/**
+ * @brief Converts loco::MaxPool2D to locoex::TFLMaxPool2D
+ *
+ * @note This works similarly to AvgPool2DConverter; please refer to the
+ *       comment there.
+ */
+bool MaxPool2DConverter::convert(loco::MaxPool2D *origin)
+{
+ auto *graph = origin->graph();
+
+ auto dec = make_feature_decode<FeatureLayout::NHWC>(origin->ifm());
+ auto tfl_max = graph->nodes()->create<locoex::TFLMaxPool2D>();
+ {
+ tfl_max->value(dec);
+
+ // set attributes
+ tfl_max->stride()->w(origin->stride()->horizontal());
+ tfl_max->stride()->h(origin->stride()->vertical());
+
+ tfl_max->filter()->w(origin->window()->horizontal());
+ tfl_max->filter()->h(origin->window()->vertical());
+
+ auto pad = origin->pad();
+ if (pad->bottom() == 0 && pad->top() == 0 && pad->left() == 0 && pad->right() == 0)
+ tfl_max->padding(locoex::Padding::VALID);
+ else
+      // TODO This is a necessary but not sufficient condition; a more rigorous check is required
+ tfl_max->padding(locoex::Padding::SAME);
+
+ tfl_max->fusedActivationFunction(locoex::FusedActFunc::NONE);
+ }
+
+ auto enc = make_feature_encode<FeatureLayout::NHWC>(tfl_max);
+
+ loco::replace(origin).with(enc);
+ origin->ifm(nullptr);
+
+ return true;
+}
+
+} // namespace exo
diff --git a/compiler/exo/src/Conversion/MaxPool2DConverter.h b/compiler/exo/src/Conversion/MaxPool2DConverter.h
new file mode 100644
index 000000000..3f526d88f
--- /dev/null
+++ b/compiler/exo/src/Conversion/MaxPool2DConverter.h
@@ -0,0 +1,41 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __CONVERSION_MAXPOOL2D_CONVERTER__
+#define __CONVERSION_MAXPOOL2D_CONVERTER__
+
+#include "CanonicalNodeConverter.h"
+
+#include <loco.h>
+
+namespace exo
+{
+
+/**
+ * @brief Convert loco::MaxPool2D to locoex::TFLMaxPool2D
+ */
+class MaxPool2DConverter : public CanonicalNodeConverter<loco::MaxPool2D>
+{
+public:
+ const char *name(void) const final { return "exo::MaxPool2DConverter"; }
+
+public:
+ bool convert(loco::MaxPool2D *origin) final;
+};
+
+} // namespace exo
+
+#endif // __CONVERSION_MAXPOOL2D_CONVERTER__
diff --git a/compiler/exo/src/Conversion/Relu6Converter.cpp b/compiler/exo/src/Conversion/Relu6Converter.cpp
new file mode 100644
index 000000000..b694511f5
--- /dev/null
+++ b/compiler/exo/src/Conversion/Relu6Converter.cpp
@@ -0,0 +1,68 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Relu6Converter.h"
+
+#include "GraphBlock.h"
+#include "Check.h"
+
+#include "Dialect/IR/TFLNodes.h"
+
+#include <loco/Service/ShapeInference.h>
+
+namespace
+{
+
+class Relu6InputHandler : public exo::InputHandler<loco::ReLU6, locoex::TFLRelu6>
+{
+public:
+ void handover(loco::ReLU6 *origin, locoex::TFLRelu6 *replacer) override
+ {
+ replacer->features(origin->input());
+ }
+
+ std::vector<loco::Node *> getInputsToConvert(loco::ReLU6 *origin) override
+ {
+ std::vector<loco::Node *> inputs({origin->input()});
+ return inputs;
+ }
+
+ void set(locoex::TFLRelu6 *replacer, std::vector<loco::Node *> &to) override
+ {
+ assert(to.size() == 1);
+
+ replacer->features(to.at(0));
+ }
+
+ void nullify(loco::ReLU6 *origin) override { origin->input(nullptr); }
+};
+
+} // namespace
+
+namespace exo
+{
+
+bool Relu6Converter::convert(loco::ReLU6 *origin)
+{
+ Relu6InputHandler input_handler;
+ exo::DomainConverter<loco::ReLU6, locoex::TFLRelu6> domain_converter;
+
+ auto tfl_node = domain_converter.convert<FeatureLayout::NHWC>(origin, input_handler);
+
+ return (tfl_node != nullptr);
+}
+
+} // namespace exo
diff --git a/compiler/exo/src/Conversion/Relu6Converter.h b/compiler/exo/src/Conversion/Relu6Converter.h
new file mode 100644
index 000000000..d987b42d0
--- /dev/null
+++ b/compiler/exo/src/Conversion/Relu6Converter.h
@@ -0,0 +1,41 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __CONVERSION_RELU6_CONVERTER_H__
+#define __CONVERSION_RELU6_CONVERTER_H__
+
+#include "CanonicalNodeConverter.h"
+
+#include <loco.h>
+
+namespace exo
+{
+
+/**
+ * @brief Convert loco::ReLU6 to TFLRelu6
+ */
+class Relu6Converter : public CanonicalNodeConverter<loco::ReLU6>
+{
+public:
+ const char *name(void) const final { return "exo::Relu6Converter"; }
+
+public:
+ bool convert(loco::ReLU6 *origin) final;
+};
+
+} // namespace exo
+
+#endif // __CONVERSION_RELU6_CONVERTER_H__
diff --git a/compiler/exo/src/Conversion/ReluConverter.cpp b/compiler/exo/src/Conversion/ReluConverter.cpp
new file mode 100644
index 000000000..92adef94d
--- /dev/null
+++ b/compiler/exo/src/Conversion/ReluConverter.cpp
@@ -0,0 +1,68 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ReluConverter.h"
+
+#include "GraphBlock.h"
+#include "Check.h"
+
+#include "Dialect/IR/TFLNodes.h"
+
+#include <loco/Service/ShapeInference.h>
+
+namespace
+{
+
+class ReluInputHandler : public exo::InputHandler<loco::ReLU, locoex::TFLRelu>
+{
+public:
+ void handover(loco::ReLU *origin, locoex::TFLRelu *replacer) override
+ {
+ replacer->features(origin->input());
+ }
+
+ std::vector<loco::Node *> getInputsToConvert(loco::ReLU *origin) override
+ {
+ std::vector<loco::Node *> inputs({origin->input()});
+ return inputs;
+ }
+
+ void set(locoex::TFLRelu *replacer, std::vector<loco::Node *> &to) override
+ {
+ assert(to.size() == 1);
+
+ replacer->features(to.at(0));
+ }
+
+ void nullify(loco::ReLU *origin) override { origin->input(nullptr); }
+};
+
+} // namespace
+
+namespace exo
+{
+
+bool ReluConverter::convert(loco::ReLU *origin)
+{
+ ReluInputHandler input_handler;
+ exo::DomainConverter<loco::ReLU, locoex::TFLRelu> domain_converter;
+
+ auto tfl_node = domain_converter.convert<FeatureLayout::NHWC>(origin, input_handler);
+
+ return (tfl_node != nullptr);
+}
+
+} // namespace exo
diff --git a/compiler/exo/src/Conversion/ReluConverter.h b/compiler/exo/src/Conversion/ReluConverter.h
new file mode 100644
index 000000000..e1e82ae4b
--- /dev/null
+++ b/compiler/exo/src/Conversion/ReluConverter.h
@@ -0,0 +1,41 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __CONVERSION_RELU_CONVERTER_H__
+#define __CONVERSION_RELU_CONVERTER_H__
+
+#include "CanonicalNodeConverter.h"
+
+#include <loco.h>
+
+namespace exo
+{
+
+/**
+ * @brief Convert loco::ReLU to TFLRelu
+ */
+class ReluConverter : public CanonicalNodeConverter<loco::ReLU>
+{
+public:
+ const char *name(void) const final { return "exo::ReluConverter"; }
+
+public:
+ bool convert(loco::ReLU *origin) final;
+};
+
+} // namespace exo
+
+#endif // __CONVERSION_RELU_CONVERTER_H__
diff --git a/compiler/exo/src/Conversion/ReluConverter.test.cpp b/compiler/exo/src/Conversion/ReluConverter.test.cpp
new file mode 100644
index 000000000..f53d656b4
--- /dev/null
+++ b/compiler/exo/src/Conversion/ReluConverter.test.cpp
@@ -0,0 +1,97 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ReluConverter.h"
+
+#include "GraphBlock.h"
+#include "Dialect/IR/TFLNodes.h"
+
+#include "TestHelper.h"
+#include "TestGraph.h"
+
+#include <gtest/gtest.h>
+
+TEST(ReluConverterTest, relu_tensor_inout)
+{
+ exo::test::TestGraph graph;
+ {
+ auto tanh = graph.append<loco::Tanh>(graph.pull);
+ auto relu = graph.append<loco::ReLU>(tanh);
+ auto relu6 = graph.append<loco::ReLU6>(relu);
+ graph.complete();
+
+ auto pull = graph.pull;
+ {
+ pull->dtype(loco::DataType::FLOAT32);
+ pull->shape({2, 2});
+ }
+ }
+
+ // let's convert
+ exo::test::TypeShapeReadyPhase test_phase;
+ {
+ test_phase.add_pass<exo::ReluConverter>();
+ test_phase.run(graph.g.get());
+ }
+
+ loco::Node *node = exo::test::find_first_node_bytype<loco::Tanh>(graph.g.get());
+ ASSERT_TRUE(node != nullptr);
+ node = exo::test::get_only_succ<locoex::TFLRelu>(node);
+ ASSERT_TRUE(node != nullptr);
+ node = exo::test::get_only_succ<loco::ReLU6>(node);
+ ASSERT_TRUE(node != nullptr);
+}
+
+TEST(ReluConverterTest, relu_feature_inout)
+{
+ // g = Pull - FeatureEncode - Relu - FeatureDecode - Push
+ exo::test::TestGraph graph;
+ {
+ auto enc = exo::make_feature_encode<exo::FeatureLayout::NHWC>(graph.pull);
+ auto relu = graph.append<loco::ReLU>(enc);
+ auto dec = exo::make_feature_decode<exo::FeatureLayout::NHWC>(relu);
+ graph.complete(dec);
+ }
+
+ auto pull = graph.pull;
+ {
+ pull->dtype(loco::DataType::FLOAT32);
+ pull->shape({1, 2, 3, 4});
+ }
+
+ exo::test::TypeShapeReadyPhase test_phase;
+ {
+ test_phase.add_pass<exo::ReluConverter>();
+ test_phase.run(graph.g.get());
+ }
+
+ // now, g = Pull - FeatureEncode - FeatureDecode - TFLRelu - FeatureEncode - FeatureDecode - Push
+
+ // Check
+ EXO_TEST_ASSERT_NODE_COUNT({graph.push}, 7);
+
+ // Check [FeatureEncode - FeatureDecode - TFLRelu - FeatureEncode - FeatureDecode] chunk
+ loco::Node *node = exo::test::find_first_node_bytype<loco::FeatureEncode>(graph.g.get());
+ ASSERT_TRUE(node != nullptr);
+ node = exo::test::get_only_succ<loco::FeatureDecode>(node);
+ ASSERT_TRUE(node != nullptr);
+ node = exo::test::get_only_succ<locoex::TFLRelu>(node);
+ ASSERT_TRUE(node != nullptr);
+ node = exo::test::get_only_succ<loco::FeatureEncode>(node);
+ ASSERT_TRUE(node != nullptr);
+ node = exo::test::get_only_succ<loco::FeatureDecode>(node);
+ ASSERT_TRUE(node != nullptr);
+}
diff --git a/compiler/exo/src/Conversion/TensorBroadcastConverter.cpp b/compiler/exo/src/Conversion/TensorBroadcastConverter.cpp
new file mode 100644
index 000000000..532332742
--- /dev/null
+++ b/compiler/exo/src/Conversion/TensorBroadcastConverter.cpp
@@ -0,0 +1,189 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "TensorBroadcastConverter.h"
+
+#include "Dialect/IR/TFLDialect.h"
+#include "Dialect/IR/TFLNodeVisitor.h"
+
+#include <loco.h>
+#include <loco/IR/CanonicalDialect.h>
+#include <loco/IR/CanonicalNode.h>
+
+#include <set>
+
+namespace
+{
+
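+// Returns the loco::TensorBroadcast feeding either operand of a binary node,
+// or nullptr when neither operand comes from a TensorBroadcast.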
+template <class T> loco::TensorBroadcast *input_as_tbc(T *node)
+{
+ loco::TensorBroadcast *tbc = dynamic_cast<loco::TensorBroadcast *>(node->x());
+ if (tbc == nullptr)
+ tbc = dynamic_cast<loco::TensorBroadcast *>(node->y());
+
+ return tbc;
+}
+
+struct Collector final : public locoex::TFLNodeMutableVisitor<void>
+{
+ using NodePair = std::pair<loco::TensorBroadcast *, loco::Node *>;
+
+ void visit(locoex::TFLAdd *node) final
+ {
+ if (auto tbc = input_as_tbc<locoex::TFLAdd>(node))
+ {
+ NodePair pair(tbc, node);
+ candidates.insert(pair);
+ }
+ }
+
+ void visit(locoex::TFLDiv *node) final
+ {
+ if (auto tbc = input_as_tbc<locoex::TFLDiv>(node))
+ {
+ NodePair pair(tbc, node);
+ candidates.insert(pair);
+ }
+ }
+
+ void visit(locoex::TFLMul *node) final
+ {
+ if (auto tbc = input_as_tbc<locoex::TFLMul>(node))
+ {
+ NodePair pair(tbc, node);
+ candidates.insert(pair);
+ }
+ }
+
+ void visit(locoex::TFLSub *node) final
+ {
+ if (auto tbc = input_as_tbc<locoex::TFLSub>(node))
+ {
+ NodePair pair(tbc, node);
+ candidates.insert(pair);
+ }
+ }
+
+ void visit(locoex::TFLMaximum *node) final
+ {
+ if (auto tbc = input_as_tbc<locoex::TFLMaximum>(node))
+ {
+ NodePair pair(tbc, node);
+ candidates.insert(pair);
+ }
+ }
+
+ void visit(locoex::TFLNode *) final { return; }
+
+ std::set<NodePair> candidates;
+};
+
+bool mapping_condition(Collector::NodePair &)
+{
+ // TODO fill condition
+
+ return true;
+}
+
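+// Re-wires whichever operand (x or y) of the binary node points at the
+// TensorBroadcast so that it points at the broadcast's input instead, then
+// disconnects the broadcast from the graph.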
+template <class T> void jump_connection(loco::TensorBroadcast *tbc, T *tflnode)
+{
+ if (tflnode->x() == tbc)
+ tflnode->x(tbc->input());
+ else if (tflnode->y() == tbc)
+ tflnode->y(tbc->input());
+ else
+ assert(false);
+
+ tbc->input(nullptr);
+}
+
+} // namespace
+
+namespace exo
+{
+
+/**
+ * @brief Disconnects loco::TensorBroadcast from the graph when the following
+ *        node is one of the binary nodes (TFLAdd, TFLSub, TFLMul, TFLDiv,
+ *        TFLMaximum) and the mapping condition (TBA) is met
+ * @note
+ * Before:
+ * x --- TensorBroadcast --- TFLXXX --- output
+ * y ----------------------/
+ *
+ * After:
+ * --- TensorBroadcast ---
+ * x --- TFLXXX --- output
+ * y --/
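+ *
+ * The bypassed TensorBroadcast keeps no input link and is expected to be
+ * swept away later as a dead node.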
+ */
+bool TensorBroadcastConverter::run(loco::Graph *graph)
+{
+ Collector collector;
+
+ auto active_nodes = loco::active_nodes(loco::output_nodes(graph));
+
+ for (auto node : active_nodes)
+ {
+ if (node->dialect() == locoex::TFLDialect::get())
+ {
+      auto tfl_node = dynamic_cast<locoex::TFLNode *>(node);
+      assert(tfl_node != nullptr);
+      tfl_node->accept(&collector);
+ }
+ }
+
+ bool changed = false;
+
+ for (auto pair : collector.candidates)
+ {
+ if (mapping_condition(pair))
+ {
+ loco::TensorBroadcast *tensorbroadcast = pair.first;
+ if (auto tfladd = dynamic_cast<locoex::TFLAdd *>(pair.second))
+ {
+ jump_connection<locoex::TFLAdd>(tensorbroadcast, tfladd);
+ changed = true;
+ }
+ else if (auto tfldiv = dynamic_cast<locoex::TFLDiv *>(pair.second))
+ {
+ jump_connection<locoex::TFLDiv>(tensorbroadcast, tfldiv);
+ changed = true;
+ }
+ else if (auto tflmul = dynamic_cast<locoex::TFLMul *>(pair.second))
+ {
+ jump_connection<locoex::TFLMul>(tensorbroadcast, tflmul);
+ changed = true;
+ }
+ else if (auto tflsub = dynamic_cast<locoex::TFLSub *>(pair.second))
+ {
+ jump_connection<locoex::TFLSub>(tensorbroadcast, tflsub);
+ changed = true;
+ }
+ else if (auto tflmaximum = dynamic_cast<locoex::TFLMaximum *>(pair.second))
+ {
+ jump_connection<locoex::TFLMaximum>(tensorbroadcast, tflmaximum);
+ changed = true;
+ }
+ else
+ {
+ assert(false);
+ }
+ }
+ }
+
+ return changed;
+}
+
+} // namespace exo
diff --git a/compiler/exo/src/Conversion/TensorBroadcastConverter.h b/compiler/exo/src/Conversion/TensorBroadcastConverter.h
new file mode 100644
index 000000000..3cf79b0ba
--- /dev/null
+++ b/compiler/exo/src/Conversion/TensorBroadcastConverter.h
@@ -0,0 +1,40 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __TENSOR_BROADCAST_CONVERTER_H__
+#define __TENSOR_BROADCAST_CONVERTER_H__
+
+#include <loco.h>
+#include <logo/Pass.h>
+
+namespace exo
+{
+
+/**
+ * @brief Pass that bypasses loco::TensorBroadcast in front of broadcast-capable TFL binary nodes
+ */
+class TensorBroadcastConverter : public logo::Pass
+{
+public:
+ virtual const char *name(void) const { return "exo::TensorBroadcastConverter"; }
+
+public:
+ bool run(loco::Graph *graph);
+};
+
+} // namespace exo
+
+#endif // __TENSOR_BROADCAST_CONVERTER_H__
diff --git a/compiler/exo/src/Conversion/TensorConcatConverter.cpp b/compiler/exo/src/Conversion/TensorConcatConverter.cpp
new file mode 100644
index 000000000..1c36b11f8
--- /dev/null
+++ b/compiler/exo/src/Conversion/TensorConcatConverter.cpp
@@ -0,0 +1,66 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "TensorConcatConverter.h"
+
+#include "GraphBlock.h"
+#include "Check.h"
+
+#include "Dialect/IR/TFLNodes.h"
+
+#include <loco/Service/ShapeInference.h>
+
+namespace exo
+{
+/**
+ * @brief Converts loco::TensorConcat to locoex::TFLConcatenation
+ *
+ * Before:
+ *    input:0 ----- loco::TensorConcat ------- C
+ *    input:1 ----/
+ *
+ * After:
+ *    input:0 ----- locoex::TFLConcatenation --- C
+ *    input:1 ----/
+ *
+ *    loco::TensorConcat ---   (inputs reset to nullptr; removed as dead code)
+ *
+ */
+bool TensorConcatConverter::convert(loco::TensorConcat *origin)
+{
+  if (!loco::shape_known(origin))
+  {
+    return false;
+  }
+
+  // Query the shape only after the shape_known() guard above
+  assert(loco::shape_get(origin).domain() == loco::Domain::Tensor);
+
+ auto tfl_concat = origin->graph()->nodes()->create<locoex::TFLConcatenation>(2);
+ tfl_concat->values(0, origin->lhs());
+ tfl_concat->values(1, origin->rhs());
+ tfl_concat->axis(origin->axis());
+ tfl_concat->fusedActivationFunction(locoex::FusedActFunc::NONE);
+
+ loco::replace(origin).with(tfl_concat);
+
+ origin->lhs(nullptr);
+ origin->rhs(nullptr);
+
+ return true;
+}
+
+} // namespace exo
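A minimal test-style sketch of the conversion, in the spirit of the *.test.cpp files in this patch (graph wiring and shape inference are elided; all names are illustrative):

    #include "Conversion/TensorConcatConverter.h"

    #include <loco.h>

    #include <gtest/gtest.h>

    TEST(TensorConcatConverterSketch, basic)
    {
      auto g = loco::make_graph();
      auto concat = g->nodes()->create<loco::TensorConcat>();
      // ... connect lhs/rhs and run shape inference here ...

      exo::TensorConcatConverter conv;
      // convert() returns false until the shape of `concat` is known; once it
      // succeeds, `concat` is replaced by a 2-input locoex::TFLConcatenation
      // with fused activation NONE.
      conv.convert(concat);
    }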
diff --git a/compiler/exo/src/Conversion/TensorConcatConverter.h b/compiler/exo/src/Conversion/TensorConcatConverter.h
new file mode 100644
index 000000000..6b90f4731
--- /dev/null
+++ b/compiler/exo/src/Conversion/TensorConcatConverter.h
@@ -0,0 +1,41 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __CONVERSION_TENSORCONCAT_CONVERTER_H__
+#define __CONVERSION_TENSORCONCAT_CONVERTER_H__
+
+#include "CanonicalNodeConverter.h"
+
+#include <loco.h>
+
+namespace exo
+{
+
+/**
+ * @brief Convert loco::TensorConcat to TFLConcatenation
+ */
+class TensorConcatConverter : public CanonicalNodeConverter<loco::TensorConcat>
+{
+public:
+ const char *name(void) const final { return "exo::TensorConcatConverter"; }
+
+public:
+ bool convert(loco::TensorConcat *origin) final;
+};
+
+} // namespace exo
+
+#endif // __CONVERSION_TENSORCONCAT_CONVERTER_H__
diff --git a/compiler/exo/src/Conversion/TensorReduceConverter.cpp b/compiler/exo/src/Conversion/TensorReduceConverter.cpp
new file mode 100644
index 000000000..8fcb1682d
--- /dev/null
+++ b/compiler/exo/src/Conversion/TensorReduceConverter.cpp
@@ -0,0 +1,95 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "TensorReduceConverter.h"
+
+#include "Dialect/IR/TFLNodes.h"
+#include "Check.h"
+
+#include <oops/InternalExn.h>
+
+#include <loco.h>
+#include <loco/Service/ShapeInference.h>
+
+namespace
+{
+
+/**
+ * @brief Convert given TensorReduce as TFLMean
+ *
+ * <Before>
+ * In --- loco::TensorReduce --- Out(s)
+ *
+ * <After>
+ * In -------- locoex::TFLMean --- Out(s)
+ * /
+ * TFLConst ---
+ * (reduction indices)
+ */
+bool convert_as_mean(loco::TensorReduce *origin)
+{
+ EXO_ASSERT(origin->func() == loco::ReduceFunc::Mean, "func should be Mean for this helper");
+ EXO_ASSERT(origin->input(), "TensorReduce has no input");
+
+ auto *graph = origin->graph();
+
+  // Make reduction indices TFLConst node
+ auto reduction = graph->nodes()->create<locoex::TFLConst>();
+ {
+ auto input_rank = loco::shape_get(origin->input()).as<loco::TensorShape>().rank();
+
+ std::vector<int32_t> red_vec;
+ for (uint32_t axis = 0; axis < input_rank; ++axis)
+ if (origin->axes()->defined(axis))
+ red_vec.push_back(static_cast<int32_t>(axis));
+
+ const loco::DataType S32 = loco::DataType::S32;
+
+ reduction->dtype(S32);
+ reduction->rank(1);
+ reduction->dim(0) = red_vec.size();
+ reduction->size<S32>(red_vec.size());
+ for (uint32_t i = 0; i < red_vec.size(); ++i)
+ reduction->at<S32>(i) = red_vec.at(i);
+ }
+
+ // Make TFLMean node to replace
+ auto mean = graph->nodes()->create<locoex::TFLMean>();
+ mean->input(origin->input());
+ mean->reduction_indices(reduction);
+  mean->keep_dims(true); // Canonical TensorReduce always keeps dimensions
+
+ // replace canonical node
+ loco::replace(origin).with(mean);
+ origin->input(nullptr);
+
+ return true;
+}
+
+} // namespace
+
+namespace exo
+{
+
+bool TensorReduceConverter::convert(loco::TensorReduce *origin)
+{
+ if (origin->func() == loco::ReduceFunc::Mean)
+ return convert_as_mean(origin);
+ else
+ INTERNAL_EXN_V("Unsupported ReduceFunc", oops::to_uint32(origin->func()));
+}
+
+} // namespace exo
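A worked example of the constant built above, assuming a [N, H, W, C] input reduced over H and W (i.e. axes 1 and 2 are defined):

    // red_vec == {1, 2}, so the generated TFLConst satisfies:
    //   reduction->rank()     == 1
    //   reduction->dim(0)     == 2
    //   reduction->at<S32>(0) == 1   // H
    //   reduction->at<S32>(1) == 2   // W
    // With keep_dims == true, TFLMean maps [N, H, W, C] to [N, 1, 1, C].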
diff --git a/compiler/exo/src/Conversion/TensorReduceConverter.h b/compiler/exo/src/Conversion/TensorReduceConverter.h
new file mode 100644
index 000000000..dfd65ad2d
--- /dev/null
+++ b/compiler/exo/src/Conversion/TensorReduceConverter.h
@@ -0,0 +1,46 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __TENSOR_REDUCE_CONVERTER__
+#define __TENSOR_REDUCE_CONVERTER__
+
+#include "CanonicalNodeConverter.h"
+
+#include <loco.h>
+
+namespace exo
+{
+
+/**
+ * @brief Convert loco::TensorReduce to appropriate TFL reduce operation
+ * @note loco::TensorReduce always keeps dimensions
+ *
+ * Currently supported:
+ * - When loco::TensorReduce::func() == Mean, convert to TFLMean + TFLConst
+ * - TODO Support other cases
+ */
+class TensorReduceConverter : public CanonicalNodeConverter<loco::TensorReduce>
+{
+public:
+ const char *name(void) const final { return "exo::TensorReduceConverter"; }
+
+public:
+ bool convert(loco::TensorReduce *origin) final;
+};
+
+} // namespace exo
+
+#endif // __TENSOR_REDUCE_CONVERTER__
diff --git a/compiler/exo/src/Conversion/TensorTransposeConverter.cpp b/compiler/exo/src/Conversion/TensorTransposeConverter.cpp
new file mode 100644
index 000000000..25c27fe7e
--- /dev/null
+++ b/compiler/exo/src/Conversion/TensorTransposeConverter.cpp
@@ -0,0 +1,102 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "TensorTransposeConverter.h"
+
+#include "Dialect/IR/TFLNodes.h"
+
+#include <loco.h>
+#include <loco/Service/ShapeInference.h>
+
+#include <oops/InternalExn.h>
+
+#include <algorithm>
+#include <cassert>
+#include <vector>
+
+namespace
+{
+
+void validate_perm(loco::TensorTranspose *origin)
+{
+ // check perm values are correct
+ std::vector<uint32_t> base_perms; // such as {0, 1, 2, 3, ... }
+ std::vector<uint32_t> perms; // perm values in TensorTranspose
+
+ base_perms.resize(origin->perm()->size());
+ perms.resize(origin->perm()->size());
+ for (loco::TensorAxis x = 0; x < origin->perm()->size(); x++)
+ {
+ base_perms[x] = x;
+ perms[x] = origin->perm()->axis(x);
+ }
+
+ if (!std::is_permutation(base_perms.begin(), base_perms.end(), perms.begin()))
+ INTERNAL_EXN("wrong perm value");
+}
+
+} // namespace
+
+namespace exo
+{
+/**
+ * @brief Converts loco::TensorTranspose to locoex::TFLTranspose
+ */
+bool TensorTransposeConverter::convert(loco::TensorTranspose *origin)
+{
+ auto *graph = origin->graph();
+
+ auto tfl_transpose = graph->nodes()->create<locoex::TFLTranspose>();
+ {
+ // validation
+ {
+ assert(origin->input() != nullptr);
+
+ auto input_rank = loco::shape_get(origin->input()).as<loco::TensorShape>().rank();
+ if (input_rank != origin->perm()->size())
+ INTERNAL_EXN_V("perm size should be same with input rank",
+ oops::to_uint32(origin->perm()->size()));
+
+ validate_perm(origin);
+ }
+
+ tfl_transpose->a(origin->input());
+
+ // perm : set TFLConst
+ auto perm_const = graph->nodes()->create<locoex::TFLConst>();
+ {
+ perm_const->dtype(loco::DataType::S32);
+ perm_const->rank(1);
+ perm_const->dim(0) = origin->perm()->size();
+ perm_const->size<loco::DataType::S32>(origin->perm()->size());
+
+ // add perm values into perm TFLConst
+ for (loco::TensorAxis x = 0; x < origin->perm()->size(); x++)
+ {
+ perm_const->at<loco::DataType::S32>(x) = origin->perm()->axis(x);
+ }
+ }
+ tfl_transpose->perm(perm_const);
+ }
+
+ // replace canonical node
+ loco::replace(origin).with(tfl_transpose);
+ origin->input(nullptr);
+
+ return true;
+}
+
+} // namespace exo
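The validation rule above, restated as a self-contained sketch: a perm vector is valid exactly when it is a permutation of {0, 1, ..., rank-1}.

    #include <algorithm>
    #include <cstdint>
    #include <numeric>
    #include <vector>

    bool is_valid_perm(const std::vector<uint32_t> &perm)
    {
      std::vector<uint32_t> base(perm.size());
      std::iota(base.begin(), base.end(), 0u); // {0, 1, 2, ...}
      return std::is_permutation(base.begin(), base.end(), perm.begin());
    }

    // is_valid_perm({0, 2, 1, 3}) == true; is_valid_perm({0, 2, 2, 3}) == false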
diff --git a/compiler/exo/src/Conversion/TensorTransposeConverter.h b/compiler/exo/src/Conversion/TensorTransposeConverter.h
new file mode 100644
index 000000000..9b61ff38d
--- /dev/null
+++ b/compiler/exo/src/Conversion/TensorTransposeConverter.h
@@ -0,0 +1,41 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __CONVERSION_TENSORTRANSPOSE_CONVERTER__
+#define __CONVERSION_TENSORTRANSPOSE_CONVERTER__
+
+#include "CanonicalNodeConverter.h"
+
+#include <loco.h>
+
+namespace exo
+{
+
+/**
+ * @brief Convert loco::TensorTranspose to locoex::TFLTranspose
+ */
+class TensorTransposeConverter : public CanonicalNodeConverter<loco::TensorTranspose>
+{
+public:
+ const char *name(void) const final { return "exo::TensorTransposeConverter"; }
+
+public:
+ bool convert(loco::TensorTranspose *origin) final;
+};
+
+} // namespace exo
+
+#endif // __CONVERSION_TENSORTRANSPOSE_CONVERTER__
diff --git a/compiler/exo/src/Conversion/TransposedConv2DConverter.cpp b/compiler/exo/src/Conversion/TransposedConv2DConverter.cpp
new file mode 100644
index 000000000..c03b64f48
--- /dev/null
+++ b/compiler/exo/src/Conversion/TransposedConv2DConverter.cpp
@@ -0,0 +1,92 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "TransposedConv2DConverter.h"
+
+#include "Dialect/IR/TFLNodes.h"
+
+#include "GraphBlock.h"
+
+#include <loco.h>
+#include <loco/Service/ShapeInference.h>
+
+namespace exo
+{
+
+bool TransposedConv2DConverter::convert(loco::TransposedConv2D *origin)
+{
+  // The inferred shape of origin is required to build the inputSizes constant below
+ if (not loco::shape_known(origin))
+ return false;
+
+ if ((origin->ifm() == nullptr) or (origin->ker() == nullptr))
+ return false;
+
+ auto *graph = origin->graph();
+
+ auto tfl_tr_conv = graph->nodes()->create<locoex::TFLTransposeConv>();
+ {
+ tfl_tr_conv->stride()->w(origin->stride()->horizontal());
+ tfl_tr_conv->stride()->h(origin->stride()->vertical());
+
+ auto pad = origin->pad();
+ if (pad->bottom() == 0 && pad->top() == 0 && pad->left() == 0 && pad->right() == 0)
+ tfl_tr_conv->padding(locoex::Padding::VALID);
+ else
+      // TODO This is a necessary but not sufficient condition; a more rigorous check is required
+ tfl_tr_conv->padding(locoex::Padding::SAME);
+ }
+
+ // let's create a new graph connection with tfl_tr_conv
+ {
+ // Make inputSizes from shape of origin
+ auto input_sizes_const = graph->nodes()->create<locoex::TFLConst>();
+ auto origin_shape = loco::shape_get(origin).as<loco::FeatureShape>();
+
+ const loco::DataType S32 = loco::DataType::S32;
+
+ input_sizes_const->dtype(S32);
+ input_sizes_const->rank(1);
+ input_sizes_const->dim(0) = 4;
+ input_sizes_const->size<S32>(4);
+    // Note that NHWC is the layout for inputSizes, as determined by the TFLite format
+ input_sizes_const->at<S32>(0) = origin_shape.count().value(); // N
+ input_sizes_const->at<S32>(1) = origin_shape.height().value(); // H
+ input_sizes_const->at<S32>(2) = origin_shape.width().value(); // W
+ input_sizes_const->at<S32>(3) = origin_shape.depth().value(); // C
+
+ tfl_tr_conv->inputSizes(input_sizes_const);
+
+ // filter
+ auto filter_dec = make_filter_decode<FilterLayout::OHWI>(origin->ker());
+ tfl_tr_conv->filter(filter_dec);
+
+ // outBackprop
+ auto feature_dec = make_feature_decode<FeatureLayout::NHWC>(origin->ifm());
+ tfl_tr_conv->outBackprop(feature_dec);
+
+ // output
+ auto feature_enc = make_feature_encode<FeatureLayout::NHWC>(tfl_tr_conv);
+
+ // replace canonical node
+ loco::replace(origin).with(feature_enc);
+ origin->ifm(nullptr);
+ }
+
+ return true;
+}
+
+} // namespace exo
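A worked example of the inputSizes constant: if the inferred output FeatureShape of origin has count=1, height=16, width=16, depth=64, then

    // input_sizes_const is a 1-D S32 tensor of 4 elements holding
    //   [1, 16, 16, 64]   // N, H, W, C
    // which TFLTransposeConv consumes as the shape of its output tensor.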
diff --git a/compiler/exo/src/Conversion/TransposedConv2DConverter.h b/compiler/exo/src/Conversion/TransposedConv2DConverter.h
new file mode 100644
index 000000000..f51e0a5bc
--- /dev/null
+++ b/compiler/exo/src/Conversion/TransposedConv2DConverter.h
@@ -0,0 +1,62 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __CONVERSION_TRANSPOSEDCONV2D_CONVERTER__
+#define __CONVERSION_TRANSPOSEDCONV2D_CONVERTER__
+
+#include "CanonicalNodeConverter.h"
+
+#include <loco.h>
+
+namespace exo
+{
+
+/**
+ * @brief Convert loco::TransposedConv2D to locoex::TFLTransposeConv and auxiliary
+ *
+ *
+ * <BEFORE>
+ *
+ * IFM ------- TransposedConv2D --- OFM
+ * (Feature) / (Feature)
+ * /
+ * KER ------
+ * (Filter)
+ *
+ *
+ * <AFTER>
+ *
+ * out_backprop : IFM ------- FeatureDecode --- TFLTransposeConv --- FeatureEncode --- OFM
+ * [Feature] [Tensor] / / [Tensor] [Feature]
+ * / /
+ * filter: KER ------- FilterDecode --- /
+ * [Filter] [Tensor] /
+ * /
+ * input_sizes : TFLConst (new) ------------
+ * [Tensor]
+ */
+class TransposedConv2DConverter : public CanonicalNodeConverter<loco::TransposedConv2D>
+{
+public:
+ const char *name(void) const final { return "exo::TransposedConv2DConverter"; }
+
+public:
+ bool convert(loco::TransposedConv2D *origin) final;
+};
+
+} // namespace exo
+
+#endif // __CONVERSION_TRANSPOSEDCONV2D_CONVERTER__
diff --git a/compiler/exo/src/Conversions.h b/compiler/exo/src/Conversions.h
new file mode 100644
index 000000000..8eb4ed2e4
--- /dev/null
+++ b/compiler/exo/src/Conversions.h
@@ -0,0 +1,46 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __CONVERSIONS_H__
+#define __CONVERSIONS_H__
+
+#include "Conversion/AvgPool2DConverter.h"
+#include "Conversion/ConstGenConverter.h"
+#include "Conversion/Conv2DConverter.h"
+#include "Conversion/DepthwiseConv2DConverter.h"
+// TODO loco::DepthwiseFilterEncode
+#include "Conversion/EltwiseAddConverter.h"
+#include "Conversion/EltwiseDivConverter.h"
+#include "Conversion/EltwiseMaxConverter.h"
+#include "Conversion/EltwiseMulConverter.h"
+#include "Conversion/EltwiseSqrtConverter.h"
+#include "Conversion/EltwiseSubConverter.h"
+#include "Conversion/FeatureBiasAddConverter.h"
+// TODO loco::FixedReshape
+#include "Conversion/MatMulConverter.h"
+#include "Conversion/MaxPool2DConverter.h"
+#include "Conversion/ReluConverter.h"
+#include "Conversion/Relu6Converter.h"
+// TODO loco::Tanh
+#include "Conversion/TensorConcatConverter.h"
+// TODO loco::TensorBiasAdd
+#include "Conversion/TensorBroadcastConverter.h"
+#include "Conversion/TensorReduceConverter.h"
+// TODO loco::TensorSoftmax
+#include "Conversion/TensorTransposeConverter.h"
+#include "Conversion/TransposedConv2DConverter.h"
+
+#endif // __CONVERSIONS_H__
diff --git a/compiler/exo/src/Convert.cpp b/compiler/exo/src/Convert.cpp
new file mode 100644
index 000000000..45f0481f4
--- /dev/null
+++ b/compiler/exo/src/Convert.cpp
@@ -0,0 +1,97 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Convert.h"
+
+#include "Conversions.h"
+#include "Pass/ShapeInferencePass.h"
+#include "Pass/TypeInferencePass.h"
+#include "ProgressReporter.h"
+#include "Knob.h"
+
+#include <loco.h>
+#include <loco/Service/ShapeInference.h>
+#include <loco/Service/CanonicalShapeInferenceRule.h>
+#include <loco/Service/TypeInference.h>
+
+#include <logo/SimplifyDomainConversionPass.h>
+#include <logo/RemoveDeadNodePass.h>
+#include <logo/RemoveForwardNodePass.h>
+
+#include <logo/Phase.h>
+#include <stdex/Memory.h>
+
+namespace exo
+{
+
+void convert_to_TFLNodes(loco::Graph *graph)
+{
+  // Shape and type inference must be run before conversion
+ loco::CanonicalShapeInferenceRule shape_rule;
+ loco::apply(&shape_rule).to(graph);
+
+ loco::CanonicalTypeInferenceRule type_rule;
+ loco::apply(&type_rule).to(graph);
+
+ logo::Phase phase;
+ {
+ // prepare type and shape before conversion
+ phase.emplace_back(stdex::make_unique<TypeInferencePass>());
+ phase.emplace_back(stdex::make_unique<ShapeInferencePass>());
+
+ // Add converters for canonical nodes. Note: Not all loco canonical nodes are listed.
+ phase.emplace_back(stdex::make_unique<AvgPool2DConverter>());
+ phase.emplace_back(stdex::make_unique<ConstGenConverter>());
+ phase.emplace_back(stdex::make_unique<Conv2DConverter>());
+ phase.emplace_back(stdex::make_unique<DepthwiseConv2DConverter>());
+ // TODO loco::DepthwiseFilterEncode
+ phase.emplace_back(stdex::make_unique<EltwiseAddConverter>());
+ phase.emplace_back(stdex::make_unique<EltwiseDivConverter>());
+ phase.emplace_back(stdex::make_unique<EltwiseMaxConverter>());
+ phase.emplace_back(stdex::make_unique<EltwiseMulConverter>());
+ phase.emplace_back(stdex::make_unique<EltwiseSqrtConverter>());
+ phase.emplace_back(stdex::make_unique<EltwiseSubConverter>());
+ phase.emplace_back(stdex::make_unique<FeatureBiasAddConverter>());
+ // TODO loco::FixedReshape
+ phase.emplace_back(stdex::make_unique<MatMulConverter>());
+ phase.emplace_back(stdex::make_unique<MaxPool2DConverter>());
+ phase.emplace_back(stdex::make_unique<ReluConverter>());
+ phase.emplace_back(stdex::make_unique<Relu6Converter>());
+ // TODO loco::Tanh
+ phase.emplace_back(stdex::make_unique<TensorConcatConverter>());
+ // TODO loco::TensorBiasAdd
+ phase.emplace_back(stdex::make_unique<TensorBroadcastConverter>());
+ phase.emplace_back(stdex::make_unique<TensorReduceConverter>());
+ // TODO loco::TensorSoftmax
+ phase.emplace_back(stdex::make_unique<TensorTransposeConverter>());
+ phase.emplace_back(stdex::make_unique<TransposedConv2DConverter>());
+
+ // Add optimization below
+ phase.emplace_back(stdex::make_unique<logo::SimplifyDomainConversionPass>());
+ phase.emplace_back(stdex::make_unique<logo::RemoveForwardNodePass>());
+ phase.emplace_back(stdex::make_unique<logo::RemoveDeadNodePass>());
+ }
+
+ logo::PhaseRunner<logo::PhaseStrategy::Restart> phase_runner{graph};
+
+ ProgressReporter prog(graph, logo::PhaseStrategy::Restart);
+ phase_runner.attach(&prog);
+ phase_runner.run(phase);
+
+  // TODO Assert that all canonical nodes have been converted to TFL nodes
+}
+
+} // namespace exo
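A minimal usage sketch (the graph is assumed to be a canonical loco graph built elsewhere):

    #include "Convert.h"

    void lower_to_tfl(loco::Graph *g)
    {
      // Replaces every supported canonical node with its locoex::TFL*
      // counterpart; the Restart strategy re-runs the phase until no pass
      // reports a change, and logo::RemoveDeadNodePass sweeps dead nodes.
      exo::convert_to_TFLNodes(g);
    }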
diff --git a/compiler/exo/src/Convert.h b/compiler/exo/src/Convert.h
new file mode 100644
index 000000000..7038f9cf7
--- /dev/null
+++ b/compiler/exo/src/Convert.h
@@ -0,0 +1,29 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __CONVERT_H__
+#define __CONVERT_H__
+
+#include <loco.h>
+
+namespace exo
+{
+
+void convert_to_TFLNodes(loco::Graph *graph);
+
+} // namespace exo
+
+#endif // __CONVERT_H__
diff --git a/compiler/exo/src/Dialect/IR/CircleDialect.cpp b/compiler/exo/src/Dialect/IR/CircleDialect.cpp
new file mode 100644
index 000000000..ecd43b0a3
--- /dev/null
+++ b/compiler/exo/src/Dialect/IR/CircleDialect.cpp
@@ -0,0 +1,28 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleDialect.h"
+
+namespace locoex
+{
+
+loco::Dialect *CircleDialect::get(void)
+{
+ static CircleDialect d;
+ return &d;
+}
+
+} // namespace locoex
diff --git a/compiler/exo/src/Dialect/IR/CircleDialect.h b/compiler/exo/src/Dialect/IR/CircleDialect.h
new file mode 100644
index 000000000..9857d9e6d
--- /dev/null
+++ b/compiler/exo/src/Dialect/IR/CircleDialect.h
@@ -0,0 +1,40 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LOCOEX_IR_CIRCLEDIALECT_H__
+#define __LOCOEX_IR_CIRCLEDIALECT_H__
+
+#include <loco/IR/Dialect.h>
+
+namespace locoex
+{
+
+class CircleDialect final : public loco::Dialect
+{
+private:
+ CircleDialect() = default;
+
+public:
+ CircleDialect(const CircleDialect &) = delete;
+ CircleDialect(CircleDialect &&) = delete;
+
+public:
+ static loco::Dialect *get(void);
+};
+
+} // namespace locoex
+
+#endif // __LOCOEX_IR_CIRCLEDIALECT_H__
diff --git a/compiler/exo/src/Dialect/IR/CircleDialect.test.cpp b/compiler/exo/src/Dialect/IR/CircleDialect.test.cpp
new file mode 100644
index 000000000..6132eb361
--- /dev/null
+++ b/compiler/exo/src/Dialect/IR/CircleDialect.test.cpp
@@ -0,0 +1,31 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleDialect.h"
+
+#include <gtest/gtest.h>
+
+TEST(CircleDialectTest, get)
+{
+ using locoex::CircleDialect;
+
+ auto d = CircleDialect::get();
+
+  // get() SHOULD return a valid (non-null) pointer
+ ASSERT_NE(d, nullptr);
+ // The return value SHOULD be stable across multiple invocations
+ ASSERT_EQ(d, CircleDialect::get());
+}
diff --git a/compiler/exo/src/Dialect/IR/CircleNode.cpp b/compiler/exo/src/Dialect/IR/CircleNode.cpp
new file mode 100644
index 000000000..cdcd434ea
--- /dev/null
+++ b/compiler/exo/src/Dialect/IR/CircleNode.cpp
@@ -0,0 +1,26 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleNode.h"
+
+#include "CircleDialect.h"
+
+namespace locoex
+{
+
+const loco::Dialect *CircleNode::dialect(void) const { return CircleDialect::get(); }
+
+} // namespace locoex
diff --git a/compiler/exo/src/Dialect/IR/CircleNode.h b/compiler/exo/src/Dialect/IR/CircleNode.h
new file mode 100644
index 000000000..1ae9d38bd
--- /dev/null
+++ b/compiler/exo/src/Dialect/IR/CircleNode.h
@@ -0,0 +1,23 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LOCOEX_IR_CIRCLENODE_H__
+#define __LOCOEX_IR_CIRCLENODE_H__
+
+#include "CircleNodeDecl.h"
+#include "CircleNodeImpl.h"
+
+#endif // __LOCOEX_IR_CIRCLENODE_H__
diff --git a/compiler/exo/src/Dialect/IR/CircleNodeDecl.h b/compiler/exo/src/Dialect/IR/CircleNodeDecl.h
new file mode 100644
index 000000000..358b1f0ce
--- /dev/null
+++ b/compiler/exo/src/Dialect/IR/CircleNodeDecl.h
@@ -0,0 +1,50 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LOCOEX_IR_CIRCLENODEDECL_H__
+#define __LOCOEX_IR_CIRCLENODEDECL_H__
+
+#include <loco/IR/Node.h>
+#include <loco/IR/Dialect.h>
+
+#include "CircleOpcode.h"
+#include "CircleNodeVisitor.forward.h"
+
+namespace locoex
+{
+
+struct CircleNode : public loco::Node
+{
+ virtual ~CircleNode() = default;
+
+ const loco::Dialect *dialect(void) const final;
+ virtual CircleOpcode opcode(void) const = 0;
+
+ template <typename T> T accept(CircleNodeVisitorBase<T> *) const;
+ template <typename T> T accept(CircleNodeMutableVisitorBase<T> *);
+};
+
+template <CircleOpcode Code> struct CircleNodeImpl : public CircleNode
+{
+ virtual ~CircleNodeImpl() = default;
+
+ uint32_t opnum(void) const final { return static_cast<uint32_t>(Code); }
+ CircleOpcode opcode(void) const final { return Code; }
+};
+
+} // namespace locoex
+
+#endif // __LOCOEX_IR_CIRCLENODEDECL_H__
diff --git a/compiler/exo/src/Dialect/IR/CircleNodeImpl.h b/compiler/exo/src/Dialect/IR/CircleNodeImpl.h
new file mode 100644
index 000000000..d9f487111
--- /dev/null
+++ b/compiler/exo/src/Dialect/IR/CircleNodeImpl.h
@@ -0,0 +1,70 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LOCOEX_IR_CIRCLENODEIMPL_H__
+#define __LOCOEX_IR_CIRCLENODEIMPL_H__
+
+#include "CircleNodes.h"
+#include "CircleNodeVisitor.h"
+
+#include <oops/InternalExn.h>
+
+#include <cassert>
+
+namespace locoex
+{
+
+template <typename T> T CircleNode::accept(CircleNodeVisitorBase<T> *v) const
+{
+ switch (this->opcode())
+ {
+#define CIRCLE_NODE(OPCODE, CLASS) \
+ \
+ case CircleOpcode::OPCODE: \
+ return v->visit(dynamic_cast<const CLASS *>(this));
+
+#include "CircleNodes.lst"
+#undef CIRCLE_NODE
+
+ default:
+ break;
+ }
+
+ INTERNAL_EXN("CircleNode::accept(CircleNodeVisitorBase) not handled");
+}
+
+template <typename T> T CircleNode::accept(CircleNodeMutableVisitorBase<T> *v)
+{
+ switch (this->opcode())
+ {
+#define CIRCLE_NODE(OPCODE, CLASS) \
+ \
+ case CircleOpcode::OPCODE: \
+ return v->visit(dynamic_cast<CLASS *>(this));
+
+#include "CircleNodes.lst"
+#undef CIRCLE_NODE
+
+ default:
+ break;
+ }
+
+ INTERNAL_EXN("CircleNode::accept(CircleNodeMutableVisitorBase) not handled");
+}
+
+} // namespace locoex
+
+#endif // __LOCOEX_IR_CIRCLENODEIMPL_H__
diff --git a/compiler/exo/src/Dialect/IR/CircleNodeVisitor.forward.h b/compiler/exo/src/Dialect/IR/CircleNodeVisitor.forward.h
new file mode 100644
index 000000000..8ae28abf3
--- /dev/null
+++ b/compiler/exo/src/Dialect/IR/CircleNodeVisitor.forward.h
@@ -0,0 +1,30 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LOCOEX_IR_CIRCLENODE_VISITOR_FORWARD_H__
+#define __LOCOEX_IR_CIRCLENODE_VISITOR_FORWARD_H__
+
+namespace locoex
+{
+
+// NOTE These forward declarations SHOULD BE aligned with Node declarations in
+// "CircleNodeVisitor.h"
+template <typename T> struct CircleNodeVisitorBase;
+template <typename T> struct CircleNodeMutableVisitorBase;
+
+} // namespace locoex
+
+#endif // __LOCOEX_IR_CIRCLENODE_VISITOR_FORWARD_H__
diff --git a/compiler/exo/src/Dialect/IR/CircleNodeVisitor.h b/compiler/exo/src/Dialect/IR/CircleNodeVisitor.h
new file mode 100644
index 000000000..fc70c9ebc
--- /dev/null
+++ b/compiler/exo/src/Dialect/IR/CircleNodeVisitor.h
@@ -0,0 +1,86 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LOCOEX_IR_CIRCLENODE_VISITOR_H__
+#define __LOCOEX_IR_CIRCLENODE_VISITOR_H__
+
+#include "CircleNode.h"
+#include "CircleNodes.h"
+
+#include <oops/InternalExn.h>
+
+namespace locoex
+{
+
+/**
+ * DO NOT use this class. Use CircleNodeVisitor instead.
+ */
+template <typename T> struct CircleNodeVisitorBase
+{
+ virtual ~CircleNodeVisitorBase() = default;
+
+#define CIRCLE_NODE(OPCODE, Circle_CLASS) virtual T visit(const Circle_CLASS *) = 0;
+
+#include "CircleNodes.lst"
+#undef CIRCLE_NODE
+};
+
+template <typename T> struct CircleNodeVisitor : public CircleNodeVisitorBase<T>
+{
+ virtual ~CircleNodeVisitor() = default;
+
+#define CIRCLE_NODE(OPCODE, Circle_CLASS) \
+ \
+ virtual T visit(const Circle_CLASS *node) { return visit(static_cast<const CircleNode *>(node)); }
+
+#include "CircleNodes.lst"
+#undef CIRCLE_NODE
+
+ /// @brief Default fallback
+  virtual T visit(const CircleNode *) { INTERNAL_EXN("CircleNodeVisitor: NYI node"); }
+};
+
+/**
+ * DO NOT use this class. Use CircleNodeMutableVisitor instead.
+ */
+template <typename T> struct CircleNodeMutableVisitorBase
+{
+ virtual ~CircleNodeMutableVisitorBase() = default;
+
+#define CIRCLE_NODE(OPCODE, Circle_CLASS) virtual T visit(Circle_CLASS *) = 0;
+
+#include "CircleNodes.lst"
+#undef CIRCLE_NODE
+};
+
+template <typename T> struct CircleNodeMutableVisitor : public CircleNodeMutableVisitorBase<T>
+{
+ virtual ~CircleNodeMutableVisitor() = default;
+
+#define CIRCLE_NODE(OPCODE, Circle_CLASS) \
+ \
+ virtual T visit(Circle_CLASS *node) { return visit(static_cast<CircleNode *>(node)); }
+
+#include "CircleNodes.lst"
+#undef CIRCLE_NODE
+
+ /// @brief Default fallback
+  virtual T visit(CircleNode *) { INTERNAL_EXN("CircleNodeMutableVisitor: NYI node"); }
+};
+
+} // namespace locoex
+
+#endif // __LOCOEX_IR_CIRCLENODE_VISITOR_H__
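A small sketch of the visitor in use: a read-only visitor that flags INSTANCE_NORM nodes and overrides the throwing fallback (class name is illustrative):

    #include "Dialect/IR/CircleNode.h"
    #include "Dialect/IR/CircleNodeVisitor.h"

    struct IsInstanceNorm final : public locoex::CircleNodeVisitor<bool>
    {
      bool visit(const locoex::CircleInstanceNorm *) final { return true; }
      bool visit(const locoex::CircleNode *) final { return false; } // fallback
    };

    // Usage: for a locoex::CircleNode *n,
    //   IsInstanceNorm q;
    //   bool hit = n->accept(&q);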
diff --git a/compiler/exo/src/Dialect/IR/CircleNodes.cpp b/compiler/exo/src/Dialect/IR/CircleNodes.cpp
new file mode 100644
index 000000000..bba59ff4d
--- /dev/null
+++ b/compiler/exo/src/Dialect/IR/CircleNodes.cpp
@@ -0,0 +1,18 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// This is to validate CircleNodes.h
+#include "CircleNodes.h"
diff --git a/compiler/exo/src/Dialect/IR/CircleNodes.h b/compiler/exo/src/Dialect/IR/CircleNodes.h
new file mode 100644
index 000000000..7be093103
--- /dev/null
+++ b/compiler/exo/src/Dialect/IR/CircleNodes.h
@@ -0,0 +1,79 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LOCOEX_IR_CIRCLENODES_H__
+#define __LOCOEX_IR_CIRCLENODES_H__
+
+#include "CircleNodeDecl.h"
+#include "CircleOpcode.h"
+
+#include "FusedActFunc.h"
+#include "NodeMixins.h" // FixedArityNode
+
+#include <loco/IR/Node.h>
+
+namespace locoex
+{
+
+/// @brief enumeration of mixin class
+enum class CircleNodeTrait
+{
+ FusedActFunc,
+};
+
+template <CircleNodeTrait T> class CircleNodeMixin;
+
+template <> class CircleNodeMixin<CircleNodeTrait::FusedActFunc>
+{
+public:
+ CircleNodeMixin() = default;
+
+public:
+ FusedActFunc fusedActivationFunction() const { return _fused_act_fun; }
+ void fusedActivationFunction(FusedActFunc fused_act_fun) { _fused_act_fun = fused_act_fun; }
+
+private:
+ FusedActFunc _fused_act_fun = FusedActFunc::UNDEFINED;
+};
+
+/**
+ * @brief INSTANCE_NORM in circle
+ */
+class CircleInstanceNorm final
+ : public FixedArityNode<3, CircleNodeImpl<CircleOpcode::INSTANCE_NORM>>,
+ public CircleNodeMixin<CircleNodeTrait::FusedActFunc>
+{
+public:
+  /// @note Currently only FLOAT32 is supported as the input type
+ loco::Node *input(void) const { return at(0)->node(); }
+ void input(loco::Node *node) { at(0)->node(node); }
+
+ loco::Node *gamma(void) const { return at(1)->node(); }
+ void gamma(loco::Node *node) { at(1)->node(node); }
+
+ loco::Node *beta(void) const { return at(2)->node(); }
+ void beta(loco::Node *node) { at(2)->node(node); }
+
+ float epsilon() const { return _epsilon; }
+ void epsilon(float epsilon) { _epsilon = epsilon; }
+
+private:
+ float _epsilon = 1e-05;
+};
+
+} // namespace locoex
+
+#endif // __LOCOEX_IR_CIRCLENODES_H__
diff --git a/compiler/exo/src/Dialect/IR/CircleNodes.lst b/compiler/exo/src/Dialect/IR/CircleNodes.lst
new file mode 100644
index 000000000..96baf2917
--- /dev/null
+++ b/compiler/exo/src/Dialect/IR/CircleNodes.lst
@@ -0,0 +1,8 @@
+#ifndef CIRCLE_NODE
+#error "Define CIRCLE_NODE"
+#endif // CIRCLE_NODE
+
+//
+// PLEASE SORT NODE DECLS IN ALPHABETICAL ORDER
+//
+CIRCLE_NODE(INSTANCE_NORM, locoex::CircleInstanceNorm)
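For reference, this single entry drives every X-macro expansion in the headers above. Under CircleOpcode.h's definition of CIRCLE_NODE it expands to:

    enum class CircleOpcode
    {
      INSTANCE_NORM,
    };

    // ...and in CircleNodeVisitorBase<T> to:
    //   virtual T visit(const locoex::CircleInstanceNorm *) = 0;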
diff --git a/compiler/exo/src/Dialect/IR/CircleNodes.test.cpp b/compiler/exo/src/Dialect/IR/CircleNodes.test.cpp
new file mode 100644
index 000000000..b63e7ccae
--- /dev/null
+++ b/compiler/exo/src/Dialect/IR/CircleNodes.test.cpp
@@ -0,0 +1,36 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleNodes.h"
+
+#include "CircleDialect.h"
+#include "CircleOpcode.h"
+
+#include <gtest/gtest.h>
+
+TEST(CircleInstanceNormTest, constructor)
+{
+ locoex::CircleInstanceNorm instance_norm;
+
+ ASSERT_EQ(instance_norm.dialect(), locoex::CircleDialect::get());
+ ASSERT_EQ(instance_norm.opcode(), locoex::CircleOpcode::INSTANCE_NORM);
+
+ ASSERT_EQ(instance_norm.input(), nullptr);
+ ASSERT_EQ(instance_norm.gamma(), nullptr);
+ ASSERT_EQ(instance_norm.beta(), nullptr);
+ ASSERT_FLOAT_EQ(instance_norm.epsilon(), 1e-05);
+ ASSERT_EQ(instance_norm.fusedActivationFunction(), locoex::FusedActFunc::UNDEFINED);
+}
diff --git a/compiler/exo/src/Dialect/IR/CircleOpcode.h b/compiler/exo/src/Dialect/IR/CircleOpcode.h
new file mode 100644
index 000000000..264304049
--- /dev/null
+++ b/compiler/exo/src/Dialect/IR/CircleOpcode.h
@@ -0,0 +1,32 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LOCOEX_IR_CIRCLEOPCODE_H__
+#define __LOCOEX_IR_CIRCLEOPCODE_H__
+
+namespace locoex
+{
+
+enum class CircleOpcode
+{
+#define CIRCLE_NODE(OPCODE, CLASS) OPCODE,
+#include "CircleNodes.lst"
+#undef CIRCLE_NODE
+};
+
+} // namespace locoex
+
+#endif // __LOCOEX_IR_CIRCLEOPCODE_H__
diff --git a/compiler/exo/src/Dialect/IR/FusedActFunc.h b/compiler/exo/src/Dialect/IR/FusedActFunc.h
new file mode 100644
index 000000000..b73a0799e
--- /dev/null
+++ b/compiler/exo/src/Dialect/IR/FusedActFunc.h
@@ -0,0 +1,35 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __DIALECT_IR_FUSEDACTFUNC_H__
+#define __DIALECT_IR_FUSEDACTFUNC_H__
+
+namespace locoex
+{
+
+// TODO Divide into TFL and Circle versions when they take different approaches
+enum class FusedActFunc
+{
+ UNDEFINED, // This is not defined by TFLite or Circle. This was added to
+             // prevent programming errors.
+ NONE,
+ RELU,
+ RELU6
+};
+
+} // namespace locoex
+
+#endif // __DIALECT_IR_FUSEDACTFUNC_H__
diff --git a/compiler/exo/src/Dialect/IR/NodeMixins.cpp b/compiler/exo/src/Dialect/IR/NodeMixins.cpp
new file mode 100644
index 000000000..cdfe0d8d1
--- /dev/null
+++ b/compiler/exo/src/Dialect/IR/NodeMixins.cpp
@@ -0,0 +1,18 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// This is to validate NodeMixins.h
+#include "NodeMixins.h"
diff --git a/compiler/exo/src/Dialect/IR/NodeMixins.h b/compiler/exo/src/Dialect/IR/NodeMixins.h
new file mode 100644
index 000000000..c35daebc6
--- /dev/null
+++ b/compiler/exo/src/Dialect/IR/NodeMixins.h
@@ -0,0 +1,66 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __DIALECT_IR_NODEMIXINS_H__
+#define __DIALECT_IR_NODEMIXINS_H__
+
+#include <loco/IR/Node.h>
+
+// for std::array, std::unique_ptr, and uint32_t used below
+#include <array>
+#include <cstdint>
+#include <memory>
+
+namespace locoex
+{
+
+/**
+ * @brief Nodes with the fixed number of inputs
+ *
+ * TODO Deprecate this class and use loco::FixedArity instead
+ */
+template <unsigned N, typename Base> class FixedArityNode : public Base
+{
+public:
+ FixedArityNode()
+ {
+ for (uint32_t n = 0; n < N; ++n)
+ {
+ _args[n] = std::unique_ptr<loco::Use>(new loco::Use{this});
+ }
+ }
+
+ virtual ~FixedArityNode() = default;
+
+public:
+ unsigned arity(void) const final { return N; }
+
+ loco::Node *arg(uint32_t n) const final { return _args.at(n)->node(); }
+
+ void drop(void) final
+ {
+ for (uint32_t n = 0; n < N; ++n)
+ {
+ _args.at(n)->node(nullptr);
+ }
+ }
+
+protected:
+ // This API allows inherited classes to access "_args" field.
+ loco::Use *at(unsigned n) const { return _args.at(n).get(); }
+
+private:
+ std::array<std::unique_ptr<loco::Use>, N> _args;
+};
+
+} // namespace locoex
+
+#endif // __DIALECT_IR_NODEMIXINS_H__
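A sketch of how a two-input node would use this mixin; locoex::TFLAdd in this patch is declared in essentially this way (TFLOpcode::ADD is assumed to be an entry in TFLNodes.lst):

    class TFLAddLike final
        : public locoex::FixedArityNode<2, locoex::TFLNodeImpl<locoex::TFLOpcode::ADD>>
    {
    public:
      loco::Node *x(void) const { return at(0)->node(); }
      void x(loco::Node *node) { at(0)->node(node); }

      loco::Node *y(void) const { return at(1)->node(); }
      void y(loco::Node *node) { at(1)->node(node); }
    };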
diff --git a/compiler/exo/src/Dialect/IR/TFLDialect.cpp b/compiler/exo/src/Dialect/IR/TFLDialect.cpp
new file mode 100644
index 000000000..8cbf9a364
--- /dev/null
+++ b/compiler/exo/src/Dialect/IR/TFLDialect.cpp
@@ -0,0 +1,28 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "TFLDialect.h"
+
+namespace locoex
+{
+
+loco::Dialect *TFLDialect::get(void)
+{
+ static TFLDialect d;
+ return &d;
+}
+
+} // namespace locoex
diff --git a/compiler/exo/src/Dialect/IR/TFLDialect.h b/compiler/exo/src/Dialect/IR/TFLDialect.h
new file mode 100644
index 000000000..96463a9f9
--- /dev/null
+++ b/compiler/exo/src/Dialect/IR/TFLDialect.h
@@ -0,0 +1,40 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LOCOEX_IR_TFLDIALECT_H__
+#define __LOCOEX_IR_TFLDIALECT_H__
+
+#include <loco/IR/Dialect.h>
+
+namespace locoex
+{
+
+class TFLDialect final : public loco::Dialect
+{
+private:
+ TFLDialect() = default;
+
+public:
+ TFLDialect(const TFLDialect &) = delete;
+ TFLDialect(TFLDialect &&) = delete;
+
+public:
+ static loco::Dialect *get(void);
+};
+
+} // namespace locoex
+
+#endif // __LOCOEX_IR_TFLDIALECT_H__
diff --git a/compiler/exo/src/Dialect/IR/TFLDialect.test.cpp b/compiler/exo/src/Dialect/IR/TFLDialect.test.cpp
new file mode 100644
index 000000000..136721e2d
--- /dev/null
+++ b/compiler/exo/src/Dialect/IR/TFLDialect.test.cpp
@@ -0,0 +1,31 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "TFLDialect.h"
+
+#include <gtest/gtest.h>
+
+TEST(TFLDialectTest, get)
+{
+ using locoex::TFLDialect;
+
+ auto d = TFLDialect::get();
+
+  // get() SHOULD return a valid (non-null) pointer
+ ASSERT_NE(d, nullptr);
+ // The return value SHOULD be stable across multiple invocations
+ ASSERT_EQ(d, TFLDialect::get());
+}
diff --git a/compiler/exo/src/Dialect/IR/TFLNode.cpp b/compiler/exo/src/Dialect/IR/TFLNode.cpp
new file mode 100644
index 000000000..82d5f1eba
--- /dev/null
+++ b/compiler/exo/src/Dialect/IR/TFLNode.cpp
@@ -0,0 +1,26 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "TFLNode.h"
+
+#include "TFLDialect.h"
+
+namespace locoex
+{
+
+const loco::Dialect *TFLNode::dialect(void) const { return TFLDialect::get(); }
+
+} // namespace locoex
diff --git a/compiler/exo/src/Dialect/IR/TFLNode.h b/compiler/exo/src/Dialect/IR/TFLNode.h
new file mode 100644
index 000000000..eff69b1a5
--- /dev/null
+++ b/compiler/exo/src/Dialect/IR/TFLNode.h
@@ -0,0 +1,23 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LOCOEX_IR_TFLNODE_H__
+#define __LOCOEX_IR_TFLNODE_H__
+
+#include "TFLNodeDecl.h"
+#include "TFLNodeImpl.h"
+
+#endif // __LOCOEX_IR_TFLNODE_H__
diff --git a/compiler/exo/src/Dialect/IR/TFLNodeDecl.h b/compiler/exo/src/Dialect/IR/TFLNodeDecl.h
new file mode 100644
index 000000000..d13900ab3
--- /dev/null
+++ b/compiler/exo/src/Dialect/IR/TFLNodeDecl.h
@@ -0,0 +1,50 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LOCOEX_IR_TFLNODEDECL_H__
+#define __LOCOEX_IR_TFLNODEDECL_H__
+
+#include <loco/IR/Node.h>
+#include <loco/IR/Dialect.h>
+
+#include "TFLOpcode.h"
+#include "TFLNodeVisitor.forward.h"
+
+namespace locoex
+{
+
+struct TFLNode : public loco::Node
+{
+ virtual ~TFLNode() = default;
+
+ const loco::Dialect *dialect(void) const final;
+ virtual TFLOpcode opcode(void) const = 0;
+
+ template <typename T> T accept(TFLNodeVisitorBase<T> *) const;
+ template <typename T> T accept(TFLNodeMutableVisitorBase<T> *);
+};
+
+template <TFLOpcode Code> struct TFLNodeImpl : public TFLNode
+{
+ virtual ~TFLNodeImpl() = default;
+
+ uint32_t opnum(void) const final { return static_cast<uint32_t>(Code); }
+ TFLOpcode opcode(void) const final { return Code; }
+};
+
+} // namespace locoex
+
+#endif // __LOCOEX_IR_TFLNODEDECL_H__
diff --git a/compiler/exo/src/Dialect/IR/TFLNodeImpl.h b/compiler/exo/src/Dialect/IR/TFLNodeImpl.h
new file mode 100644
index 000000000..63388279a
--- /dev/null
+++ b/compiler/exo/src/Dialect/IR/TFLNodeImpl.h
@@ -0,0 +1,70 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LOCOEX_IR_TFLNODEIMPL_H__
+#define __LOCOEX_IR_TFLNODEIMPL_H__
+
+#include "TFLNodes.h"
+#include "TFLNodeVisitor.h"
+
+#include <oops/InternalExn.h>
+
+#include <cassert>
+
+namespace locoex
+{
+
+template <typename T> T TFLNode::accept(TFLNodeVisitorBase<T> *v) const
+{
+ switch (this->opcode())
+ {
+#define TFL_NODE(OPCODE, CLASS) \
+ \
+ case TFLOpcode::OPCODE: \
+ return v->visit(dynamic_cast<const CLASS *>(this));
+
+#include "TFLNodes.lst"
+#undef TFL_NODE
+
+ default:
+ break;
+ }
+
+ INTERNAL_EXN("TFLNode::accept(TFLNodeVisitorBase) not handled");
+}
+
+template <typename T> T TFLNode::accept(TFLNodeMutableVisitorBase<T> *v)
+{
+ switch (this->opcode())
+ {
+#define TFL_NODE(OPCODE, CLASS) \
+ \
+ case TFLOpcode::OPCODE: \
+ return v->visit(dynamic_cast<CLASS *>(this));
+
+#include "TFLNodes.lst"
+#undef TFL_NODE
+
+ default:
+ break;
+ }
+
+ INTERNAL_EXN("TFLNode::accept(TFLNodeMutableVisitorBase) not handled");
+}
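+
+// For reference, each TFL_NODE(OPCODE, CLASS) entry in "TFLNodes.lst" expands
+// to one case of the switches above; e.g. the ADD entry becomes (illustrative
+// expansion, not literal source):
+//
+//   case TFLOpcode::ADD:
+//     return v->visit(dynamic_cast<locoex::TFLAdd *>(this));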
+
+} // namespace locoex
+
+#endif // __LOCOEX_IR_TFLNODEIMPL_H__
diff --git a/compiler/exo/src/Dialect/IR/TFLNodeVisitor.forward.h b/compiler/exo/src/Dialect/IR/TFLNodeVisitor.forward.h
new file mode 100644
index 000000000..e98057bc3
--- /dev/null
+++ b/compiler/exo/src/Dialect/IR/TFLNodeVisitor.forward.h
@@ -0,0 +1,30 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LOCOEX_IR_TFLNODE_VISITOR_FORWARD_H__
+#define __LOCOEX_IR_TFLNODE_VISITOR_FORWARD_H__
+
+namespace locoex
+{
+
+// NOTE These forward declarations SHOULD BE aligned with the Node declarations in
+// "TFLNodeVisitor.h"
+template <typename T> struct TFLNodeVisitorBase;
+template <typename T> struct TFLNodeMutableVisitorBase;
+
+} // namespace locoex
+
+#endif // __LOCOEX_IR_TFLNODE_VISITOR_FORWARD_H__
diff --git a/compiler/exo/src/Dialect/IR/TFLNodeVisitor.h b/compiler/exo/src/Dialect/IR/TFLNodeVisitor.h
new file mode 100644
index 000000000..e1f5959c0
--- /dev/null
+++ b/compiler/exo/src/Dialect/IR/TFLNodeVisitor.h
@@ -0,0 +1,86 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LOCOEX_IR_TFLNODE_VISITOR_H__
+#define __LOCOEX_IR_TFLNODE_VISITOR_H__
+
+#include "TFLNode.h"
+#include "TFLNodes.h"
+
+#include <oops/InternalExn.h>
+
+namespace locoex
+{
+
+/**
+ * DO NOT use this class. Use TFLNodeVisitor instead.
+ */
+template <typename T> struct TFLNodeVisitorBase
+{
+ virtual ~TFLNodeVisitorBase() = default;
+
+#define TFL_NODE(OPCODE, TFL_CLASS) virtual T visit(const TFL_CLASS *) = 0;
+
+#include "TFLNodes.lst"
+#undef TFL_NODE
+};
+
+template <typename T> struct TFLNodeVisitor : public TFLNodeVisitorBase<T>
+{
+ virtual ~TFLNodeVisitor() = default;
+
+#define TFL_NODE(OPCODE, TFL_CLASS) \
+ \
+ virtual T visit(const TFL_CLASS *node) { return visit(static_cast<const TFLNode *>(node)); }
+
+#include "TFLNodes.lst"
+#undef TFL_NODE
+
+ /// @brief Default fallback
+ virtual T visit(const TFLNode *) { INTERNAL_EXN("TFLNodeVisitor: NYI node"); }
+};
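+
+// A minimal usage sketch (illustrative only; "OpCounter" is a hypothetical
+// client, not part of this header): override the overloads you care about and
+// let the default fallback handle the rest.
+//
+//   struct OpCounter final : public locoex::TFLNodeVisitor<uint32_t>
+//   {
+//     uint32_t visit(const locoex::TFLAdd *) final { return 1; }
+//     uint32_t visit(const locoex::TFLNode *) final { return 0; } // fallback
+//   };
+//
+//   OpCounter counter;
+//   uint32_t n = tfl_node->accept(&counter); // dispatches on opcode()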
+
+/**
+ * DO NOT use this class. Use TFLNodeMutableVisitor instead.
+ */
+template <typename T> struct TFLNodeMutableVisitorBase
+{
+ virtual ~TFLNodeMutableVisitorBase() = default;
+
+#define TFL_NODE(OPCODE, TFL_CLASS) virtual T visit(TFL_CLASS *) = 0;
+
+#include "TFLNodes.lst"
+#undef TFL_NODE
+};
+
+template <typename T> struct TFLNodeMutableVisitor : public TFLNodeMutableVisitorBase<T>
+{
+ virtual ~TFLNodeMutableVisitor() = default;
+
+#define TFL_NODE(OPCODE, TFL_CLASS) \
+ \
+ virtual T visit(TFL_CLASS *node) { return visit(static_cast<TFLNode *>(node)); }
+
+#include "TFLNodes.lst"
+#undef TFL_NODE
+
+ /// @brief Default fallback
+ virtual T visit(TFLNode *) { INTERNAL_EXN("TFLNodeMutableVisitor: NYI node"); }
+};
+
+} // namespace locoex
+
+#endif // __LOCOEX_IR_TFLNODE_VISITOR_H__
diff --git a/compiler/exo/src/Dialect/IR/TFLNodes.cpp b/compiler/exo/src/Dialect/IR/TFLNodes.cpp
new file mode 100644
index 000000000..f385ce0d9
--- /dev/null
+++ b/compiler/exo/src/Dialect/IR/TFLNodes.cpp
@@ -0,0 +1,91 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "TFLNodes.h"
+
+#include "Check.h"
+
+#include <loco.h>
+
+#include <cassert>
+
+namespace locoex
+{
+
+template <loco::DataType DT> uint32_t TFLConst::size(void) const
+{
+ assert(dtype() == DT);
+ assert(_data.size() % sizeof(typename loco::DataTypeImpl<DT>::Type) == 0);
+ return _data.size() / sizeof(typename loco::DataTypeImpl<DT>::Type);
+}
+
+template <loco::DataType DT> void TFLConst::size(uint32_t l)
+{
+ assert(dtype() == DT);
+ _data.resize(l * sizeof(typename loco::DataTypeImpl<DT>::Type));
+}
+
+template <loco::DataType DT>
+const typename loco::DataTypeImpl<DT>::Type &TFLConst::at(uint32_t n) const
+{
+ assert(dtype() == DT);
+ assert(n < size<DT>());
+ return *(reinterpret_cast<const typename loco::DataTypeImpl<DT>::Type *>(_data.data()) + n);
+}
+
+template <loco::DataType DT> typename loco::DataTypeImpl<DT>::Type &TFLConst::at(uint32_t n)
+{
+ assert(dtype() == DT);
+ assert(n < size<DT>());
+ return *(reinterpret_cast<typename loco::DataTypeImpl<DT>::Type *>(_data.data()) + n);
+}
+
+#define INSTANTIATE(DT) \
+ template uint32_t TFLConst::size<DT>(void) const; \
+ template void TFLConst::size<DT>(uint32_t); \
+ template const typename loco::DataTypeImpl<DT>::Type &TFLConst::at<DT>(uint32_t) const; \
+ template typename loco::DataTypeImpl<DT>::Type &TFLConst::at<DT>(uint32_t);
+
+INSTANTIATE(loco::DataType::S32);
+INSTANTIATE(loco::DataType::FLOAT32);
+
+#undef INSTANTIATE
+
+void set_new_shape(locoex::TFLReshape *node, int32_t *base, uint32_t size)
+{
+ // Check that the node does not already have either kind of new-shape info
+ EXO_ASSERT(node->shape() == nullptr, "node already has shape input");
+ EXO_ASSERT(node->newShape()->rank() == 0, "node already has newShape attribute");
+
+ const loco::DataType S32 = loco::DataType::S32;
+
+ // Set 2nd input as TFLConst
+ auto const_shape_node = node->graph()->nodes()->create<locoex::TFLConst>();
+ const_shape_node->rank(1);
+ const_shape_node->dim(0) = size;
+ const_shape_node->dtype(S32);
+ const_shape_node->size<S32>(size);
+ for (uint32_t axis = 0; axis < size; ++axis)
+ const_shape_node->at<S32>(axis) = base[axis];
+ node->shape(const_shape_node);
+
+ // Set newShape attribute
+ node->newShape()->rank(size);
+ for (uint32_t axis = 0; axis < size; ++axis)
+ node->newShape()->dim(axis) = base[axis];
+}
+
+} // namespace locoex
diff --git a/compiler/exo/src/Dialect/IR/TFLNodes.h b/compiler/exo/src/Dialect/IR/TFLNodes.h
new file mode 100644
index 000000000..5f521a0a6
--- /dev/null
+++ b/compiler/exo/src/Dialect/IR/TFLNodes.h
@@ -0,0 +1,551 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LOCOEX_IR_TFLNODES_H__
+#define __LOCOEX_IR_TFLNODES_H__
+
+#include "TFLNodeDecl.h"
+#include "TFLOpcode.h"
+
+#include "FusedActFunc.h"
+#include "NodeMixins.h"
+
+#include <loco/IR/Node.h>
+#include <loco/IR/NodeMixins.h>
+#include <loco/IR/DataTypeTraits.h>
+
+#include <locoex/VariadicArityNode.h>
+
+#include <array>
+
+namespace locoex
+{
+
+enum class Padding
+{
+ UNDEFINED, // This is not defined by TFLite. This was added to prevent programming errors.
+ SAME,
+ VALID,
+};
+
+class Filter final
+{
+public:
+ Filter() : _w(1), _h(1) {}
+
+ int32_t w() const { return _w; }
+ void w(int32_t w) { _w = w; }
+
+ int32_t h() const { return _h; }
+ void h(int32_t h) { _h = h; }
+
+private:
+ int32_t _w;
+ int32_t _h;
+};
+
+class Stride final
+{
+public:
+ Stride() : _w(1), _h(1) {}
+
+ int32_t w() const { return _w; }
+ void w(int32_t w) { _w = w; }
+
+ int32_t h() const { return _h; }
+ void h(int32_t h) { _h = h; }
+
+private:
+ int32_t _w;
+ int32_t _h;
+};
+
+/// @brief Enumeration of mixin traits
+enum class TFLNodeTrait
+{
+ FusedActFunc,
+ Bias
+};
+
+template <TFLNodeTrait T> class TFLNodeMixin;
+
+template <> class TFLNodeMixin<TFLNodeTrait::FusedActFunc>
+{
+public:
+ TFLNodeMixin() = default;
+
+public:
+ FusedActFunc fusedActivationFunction() const { return _fused_act_fun; }
+ void fusedActivationFunction(FusedActFunc fused_act_fun) { _fused_act_fun = fused_act_fun; }
+
+private:
+ FusedActFunc _fused_act_fun = FusedActFunc::UNDEFINED;
+};
+
+/**
+ * @brief Mixin class for nodes that have a bias input
+ */
+template <> class TFLNodeMixin<TFLNodeTrait::Bias>
+{
+public:
+ TFLNodeMixin() = default;
+
+public:
+ virtual loco::Node *bias(void) const = 0; ///< @brief Get the input for bias
+ virtual void bias(loco::Node *node) = 0; ///< @brief Set the input for bias
+};
+
+/**
+ * @brief ADD in TensorFlow Lite
+ */
+class TFLAdd final : public FixedArityNode<2, TFLNodeImpl<TFLOpcode::ADD>>,
+ public TFLNodeMixin<TFLNodeTrait::FusedActFunc>
+{
+public:
+ loco::Node *x(void) const { return at(0)->node(); }
+ void x(loco::Node *node) { at(0)->node(node); }
+
+ loco::Node *y(void) const { return at(1)->node(); }
+ void y(loco::Node *node) { at(1)->node(node); }
+};
+
+/**
+ * @brief AVERAGE_POOL_2D in TensorFlow Lite
+ */
+class TFLAveragePool2D final : public FixedArityNode<1, TFLNodeImpl<TFLOpcode::AVERAGE_POOL_2D>>,
+ public TFLNodeMixin<TFLNodeTrait::FusedActFunc>
+{
+public:
+ TFLAveragePool2D() : _padding(Padding::UNDEFINED) { /* empty */ }
+
+public:
+ loco::Node *value(void) const { return at(0)->node(); }
+ void value(loco::Node *node) { at(0)->node(node); }
+
+ Padding padding() const { return _padding; }
+ void padding(Padding padding) { _padding = padding; }
+
+ const Filter *filter(void) const { return &_filter; }
+ Filter *filter(void) { return &_filter; }
+
+ const Stride *stride(void) const { return &_stride; }
+ Stride *stride(void) { return &_stride; }
+
+private:
+ Padding _padding;
+ Stride _stride;
+ Filter _filter;
+};
+
+/**
+ * @brief CONCATENATION in TensorFlow Lite
+ */
+class TFLConcatenation final : public VariadicArityNode<TFLNodeImpl<TFLOpcode::CONCATENATION>>,
+ public TFLNodeMixin<TFLNodeTrait::FusedActFunc>
+{
+public:
+ TFLConcatenation(uint32_t arity) : VariadicArityNode<TFLNodeImpl<TFLOpcode::CONCATENATION>>(arity)
+ {
+ // TODO Support when arity is 0
+ assert(arity >= 1);
+ }
+
+public:
+ uint32_t numValues(void) const { return arity(); }
+
+public:
+ Node *values(uint32_t index) const
+ {
+ assert(index < numValues());
+ return at(index)->node();
+ }
+ void values(uint32_t index, Node *node)
+ {
+ assert(index < numValues());
+ at(index)->node(node);
+ }
+
+public:
+ uint32_t axis(void) const { return _axis; }
+ void axis(uint32_t axis) { _axis = axis; }
+
+private:
+ uint32_t _axis = 0; // default-initialized to avoid an indeterminate read
+};
+
+/**
+ * @brief Class to build tensor data
+ * @note This will not be exported as a specific op
+ */
+class TFLConst final : public FixedArityNode<0, TFLNodeImpl<TFLOpcode::CONST>>,
+ public loco::NodeMixin<loco::NodeTrait::DataType>,
+ public loco::NodeMixin<loco::NodeTrait::TensorShape>
+{
+public:
+ TFLConst() = default;
+
+public:
+ template <loco::DataType DT> uint32_t size(void) const;
+ template <loco::DataType DT> void size(uint32_t size);
+ template <loco::DataType DT> const typename loco::DataTypeImpl<DT>::Type &at(uint32_t n) const;
+ template <loco::DataType DT> typename loco::DataTypeImpl<DT>::Type &at(uint32_t n);
+
+private:
+ std::vector<uint8_t> _data;
+};
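+
+// A minimal usage sketch (illustrative; "g" is an assumed loco::Graph *),
+// mirroring how set_new_shape() in TFLNodes.cpp builds its shape constant:
+//
+//   auto c = g->nodes()->create<locoex::TFLConst>();
+//   c->dtype(loco::DataType::FLOAT32);
+//   c->rank(1);
+//   c->dim(0) = 3;
+//   c->size<loco::DataType::FLOAT32>(3); // allocates the backing buffer
+//   c->at<loco::DataType::FLOAT32>(0) = 1.0f;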
+
+/**
+ * @brief CONV_2D in TensorFlow Lite
+ */
+class TFLConv2D final : public FixedArityNode<3, TFLNodeImpl<TFLOpcode::CONV_2D>>,
+ public TFLNodeMixin<TFLNodeTrait::FusedActFunc>,
+ public TFLNodeMixin<TFLNodeTrait::Bias>
+{
+public:
+ loco::Node *input(void) const { return at(0)->node(); }
+ void input(loco::Node *node) { at(0)->node(node); }
+
+ loco::Node *filter(void) const { return at(1)->node(); }
+ void filter(loco::Node *node) { at(1)->node(node); }
+
+ loco::Node *bias(void) const override { return at(2)->node(); }
+ void bias(loco::Node *node) override { at(2)->node(node); }
+
+public:
+ Padding padding() const { return _padding; }
+ void padding(Padding padding) { _padding = padding; }
+
+ const Stride *stride(void) const { return &_stride; }
+ Stride *stride(void) { return &_stride; }
+
+private:
+ Padding _padding = Padding::UNDEFINED;
+ Stride _stride;
+};
+
+/**
+ * @brief DEPTHWISE_CONV_2D in TensorFlow Lite
+ */
+class TFLDepthwiseConv2D final
+ : public FixedArityNode<3, TFLNodeImpl<TFLOpcode::DEPTHWISE_CONV_2D>>,
+ public TFLNodeMixin<TFLNodeTrait::FusedActFunc>,
+ public TFLNodeMixin<TFLNodeTrait::Bias>
+{
+public:
+ loco::Node *input(void) const { return at(0)->node(); }
+ void input(loco::Node *node) { at(0)->node(node); }
+
+ loco::Node *filter(void) const { return at(1)->node(); }
+ void filter(loco::Node *node) { at(1)->node(node); }
+
+ loco::Node *bias(void) const override { return at(2)->node(); }
+ void bias(loco::Node *node) override { at(2)->node(node); }
+
+public:
+ Padding padding() const { return _padding; }
+ void padding(Padding padding) { _padding = padding; }
+
+ const Stride *stride(void) const { return &_stride; }
+ Stride *stride(void) { return &_stride; }
+
+ int32_t depthMultiplier(void) const { return _depth_multiplier; }
+ void depthMultiplier(int32_t arg) { _depth_multiplier = arg; }
+
+private:
+ Padding _padding = Padding::UNDEFINED;
+ Stride _stride;
+ int32_t _depth_multiplier = 0;
+};
+
+/**
+ * @brief DIV in TensorFlow Lite
+ */
+class TFLDiv final : public FixedArityNode<2, TFLNodeImpl<TFLOpcode::DIV>>,
+ public TFLNodeMixin<TFLNodeTrait::FusedActFunc>
+{
+public:
+ TFLDiv() = default;
+
+public:
+ loco::Node *x(void) const { return at(0)->node(); }
+ void x(loco::Node *node) { at(0)->node(node); }
+
+ loco::Node *y(void) const { return at(1)->node(); }
+ void y(loco::Node *node) { at(1)->node(node); }
+};
+
+/**
+ * @brief FULLY_CONNECTED in TensorFlow Lite
+ */
+class TFLFullyConnected final : public FixedArityNode<3, TFLNodeImpl<TFLOpcode::FULLY_CONNECTED>>,
+ public TFLNodeMixin<TFLNodeTrait::FusedActFunc>,
+ public TFLNodeMixin<TFLNodeTrait::Bias>
+{
+public:
+ loco::Node *input(void) const { return at(0)->node(); }
+ void input(loco::Node *node) { at(0)->node(node); }
+
+ loco::Node *weights(void) const { return at(1)->node(); }
+ void weights(loco::Node *node) { at(1)->node(node); }
+
+ loco::Node *bias(void) const override { return at(2)->node(); }
+ void bias(loco::Node *node) override { at(2)->node(node); }
+};
+
+/**
+ * @brief MAXIMUM in TensorFlow Lite
+ */
+class TFLMaximum final : public FixedArityNode<2, TFLNodeImpl<TFLOpcode::MAXIMUM>>
+{
+public:
+ loco::Node *x(void) const { return at(0)->node(); }
+ void x(loco::Node *node) { at(0)->node(node); }
+
+ loco::Node *y(void) const { return at(1)->node(); }
+ void y(loco::Node *node) { at(1)->node(node); }
+};
+
+/**
+ * @brief MAX_POOL_2D in TensorFlow Lite
+ */
+class TFLMaxPool2D final : public FixedArityNode<1, TFLNodeImpl<TFLOpcode::MAX_POOL_2D>>,
+ public TFLNodeMixin<TFLNodeTrait::FusedActFunc>
+{
+public:
+ TFLMaxPool2D() : _padding(Padding::UNDEFINED) { /* empty */ }
+
+public:
+ loco::Node *value(void) const { return at(0)->node(); }
+ void value(loco::Node *node) { at(0)->node(node); }
+
+ Padding padding() const { return _padding; }
+ void padding(Padding padding) { _padding = padding; }
+
+ const Filter *filter(void) const { return &_filter; }
+ Filter *filter(void) { return &_filter; }
+
+ const Stride *stride(void) const { return &_stride; }
+ Stride *stride(void) { return &_stride; }
+
+private:
+ Padding _padding;
+ Stride _stride;
+ Filter _filter;
+};
+
+class TFLMean final : public FixedArityNode<2, TFLNodeImpl<TFLOpcode::MEAN>>
+{
+public:
+ loco::Node *input(void) const { return at(0)->node(); }
+ void input(loco::Node *node) { at(0)->node(node); }
+
+ loco::Node *reduction_indices(void) const { return at(1)->node(); }
+ void reduction_indices(loco::Node *node) { at(1)->node(node); }
+
+public:
+ bool keep_dims(void) const { return _keep_dims; }
+ void keep_dims(bool keep_dims) { _keep_dims = keep_dims; }
+
+private:
+ bool _keep_dims = false;
+};
+
+/**
+ * @brief MUL in TensorFlow Lite
+ */
+class TFLMul final : public FixedArityNode<2, TFLNodeImpl<TFLOpcode::MUL>>,
+ public TFLNodeMixin<TFLNodeTrait::FusedActFunc>
+{
+public:
+ loco::Node *x(void) const { return at(0)->node(); }
+ void x(loco::Node *node) { at(0)->node(node); }
+
+ loco::Node *y(void) const { return at(1)->node(); }
+ void y(loco::Node *node) { at(1)->node(node); }
+};
+
+class TFLRelu final : public FixedArityNode<1, TFLNodeImpl<TFLOpcode::RELU>>
+{
+public:
+ TFLRelu() = default;
+
+public:
+ loco::Node *features(void) const { return at(0)->node(); }
+ void features(loco::Node *node) { at(0)->node(node); }
+};
+
+class TFLRelu6 final : public FixedArityNode<1, TFLNodeImpl<TFLOpcode::RELU6>>
+{
+public:
+ TFLRelu6() = default;
+
+public:
+ loco::Node *features(void) const { return at(0)->node(); }
+ void features(loco::Node *node) { at(0)->node(node); }
+};
+
+class TFLReshape final : public FixedArityNode<2, TFLNodeImpl<TFLOpcode::RESHAPE>>
+{
+public:
+ TFLReshape() = default;
+
+public:
+ loco::Node *tensor(void) const { return at(0)->node(); }
+ void tensor(loco::Node *node) { at(0)->node(node); }
+
+ // TODO Make this input optional so that the loco system does not emit an
+ //      error when this input is null
+ loco::Node *shape(void) const { return at(1)->node(); }
+ void shape(loco::Node *node) { at(1)->node(node); }
+
+public:
+ class Shape
+ {
+ public:
+ uint32_t rank(void) const { return _shape.size(); }
+ void rank(uint32_t rank) { _shape.resize(rank); }
+
+ int32_t dim(uint32_t n) const { return _shape.at(n); }
+ int32_t &dim(uint32_t n) { return _shape.at(n); }
+
+ private:
+ std::vector<int32_t> _shape;
+ };
+
+ const Shape *newShape(void) const { return &_new_shape; }
+ Shape *newShape(void) { return &_new_shape; }
+
+private:
+ Shape _new_shape;
+};
+
+/**
+ * @brief Set TFLReshape's 2nd input to a new TFLConst and fill its newShape
+ *        attribute with the same value
+ * @note Shape inference for TFLReshape forces them to be the same
+ * TODO find better place for this helper
+ */
+void set_new_shape(locoex::TFLReshape *node, int32_t *base, uint32_t size);
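+
+// A minimal usage sketch (illustrative; "reshape" is an assumed TFLReshape *):
+//
+//   int32_t dims[2] = {1, 100};
+//   locoex::set_new_shape(reshape, dims, 2);
+//   // reshape->shape() is now a rank-1 S32 TFLConst holding {1, 100},
+//   // and reshape->newShape() carries the same two dimensions.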
+
+class TFLRsqrt final : public FixedArityNode<1, TFLNodeImpl<TFLOpcode::RSQRT>>
+{
+public:
+ TFLRsqrt() = default;
+
+public:
+ loco::Node *x(void) const { return at(0)->node(); }
+ void x(loco::Node *node) { at(0)->node(node); }
+};
+
+// TODO TFLSoftmax
+
+class TFLSqrt final : public FixedArityNode<1, TFLNodeImpl<TFLOpcode::SQRT>>
+{
+public:
+ TFLSqrt() = default;
+
+public:
+ loco::Node *x(void) const { return at(0)->node(); }
+ void x(loco::Node *node) { at(0)->node(node); }
+};
+
+class TFLSquaredDifference final
+ : public FixedArityNode<2, TFLNodeImpl<TFLOpcode::SQUARED_DIFFERENCE>>
+{
+public:
+ TFLSquaredDifference() = default;
+
+public:
+ loco::Node *x(void) const { return at(0)->node(); }
+ void x(loco::Node *node) { at(0)->node(node); }
+
+ loco::Node *y(void) const { return at(1)->node(); }
+ void y(loco::Node *node) { at(1)->node(node); }
+};
+
+/**
+ * @brief SUB in TensorFlow Lite
+ */
+class TFLSub final : public FixedArityNode<2, TFLNodeImpl<TFLOpcode::SUB>>,
+ public TFLNodeMixin<TFLNodeTrait::FusedActFunc>
+{
+public:
+ TFLSub() = default;
+
+public:
+ loco::Node *x(void) const { return at(0)->node(); }
+ void x(loco::Node *node) { at(0)->node(node); }
+
+ loco::Node *y(void) const { return at(1)->node(); }
+ void y(loco::Node *node) { at(1)->node(node); }
+};
+
+// TODO TFLTanh
+
+/**
+ * @brief TRANSPOSE in TensorFlow Lite
+ */
+class TFLTranspose final : public FixedArityNode<2, TFLNodeImpl<TFLOpcode::TRANSPOSE>>
+{
+public:
+ TFLTranspose() = default;
+
+public:
+ /// @brief Get the input node to transpose
+ loco::Node *a(void) const { return at(0)->node(); }
+
+ /// @brief Set the input node to transpose
+ void a(loco::Node *node) { at(0)->node(node); }
+
+ loco::Node *perm(void) const { return at(1)->node(); }
+ void perm(loco::Node *node) { at(1)->node(node); }
+};
+
+/**
+ * @brief TRANSPOSE_CONV in TensorFlow Lite
+ *
+ * @note Argument node function names are from TensorFlow, so referring to 'in'
+ *       and 'out' actually means 'out' and 'in' of this node.
+ */
+class TFLTransposeConv final : public FixedArityNode<3, TFLNodeImpl<TFLOpcode::TRANSPOSE_CONV>>
+{
+public:
+ loco::Node *inputSizes(void) const { return at(0)->node(); }
+ void inputSizes(Node *node) { at(0)->node(node); }
+
+ loco::Node *filter(void) const { return at(1)->node(); }
+ void filter(Node *node) { at(1)->node(node); }
+
+ loco::Node *outBackprop(void) const { return at(2)->node(); }
+ void outBackprop(Node *node) { at(2)->node(node); }
+
+public:
+ const Padding &padding(void) const { return _padding; }
+ void padding(const Padding &padding) { _padding = padding; }
+
+ const Stride *stride(void) const { return &_stride; }
+ Stride *stride(void) { return &_stride; }
+
+private:
+ Padding _padding = Padding::UNDEFINED; // match the default of the other conv nodes
+ Stride _stride;
+};
+
+// TODO define more children of TFLNode
+
+} // namespace locoex
+
+#endif // __LOCOEX_IR_TFLNODES_H__
diff --git a/compiler/exo/src/Dialect/IR/TFLNodes.lst b/compiler/exo/src/Dialect/IR/TFLNodes.lst
new file mode 100644
index 000000000..225e2be3b
--- /dev/null
+++ b/compiler/exo/src/Dialect/IR/TFLNodes.lst
@@ -0,0 +1,30 @@
+#ifndef TFL_NODE
+#error "Define TFL_NODE"
+#endif // TFL_NODE
+
+//
+// PLEASE SORT NODE DECLS IN ALPHABETICAL ORDER
+//
+TFL_NODE(ADD, locoex::TFLAdd)
+TFL_NODE(AVERAGE_POOL_2D, locoex::TFLAveragePool2D)
+TFL_NODE(CONCATENATION, locoex::TFLConcatenation)
+TFL_NODE(CONST, locoex::TFLConst)
+TFL_NODE(CONV_2D, locoex::TFLConv2D)
+TFL_NODE(DEPTHWISE_CONV_2D, locoex::TFLDepthwiseConv2D)
+TFL_NODE(DIV, locoex::TFLDiv)
+TFL_NODE(FULLY_CONNECTED, locoex::TFLFullyConnected)
+TFL_NODE(MAXIMUM, locoex::TFLMaximum)
+TFL_NODE(MAX_POOL_2D, locoex::TFLMaxPool2D)
+TFL_NODE(MEAN, locoex::TFLMean)
+TFL_NODE(MUL, locoex::TFLMul)
+TFL_NODE(RELU, locoex::TFLRelu)
+TFL_NODE(RELU6, locoex::TFLRelu6)
+TFL_NODE(RESHAPE, locoex::TFLReshape)
+TFL_NODE(RSQRT, locoex::TFLRsqrt)
+// TODO TFLSoftmax
+TFL_NODE(SQRT, locoex::TFLSqrt)
+TFL_NODE(SQUARED_DIFFERENCE, locoex::TFLSquaredDifference)
+TFL_NODE(SUB, locoex::TFLSub)
+// TODO TFLTanh
+TFL_NODE(TRANSPOSE, locoex::TFLTranspose)
+TFL_NODE(TRANSPOSE_CONV, locoex::TFLTransposeConv)
diff --git a/compiler/exo/src/Dialect/IR/TFLNodes.test.cpp b/compiler/exo/src/Dialect/IR/TFLNodes.test.cpp
new file mode 100644
index 000000000..09c5c83a0
--- /dev/null
+++ b/compiler/exo/src/Dialect/IR/TFLNodes.test.cpp
@@ -0,0 +1,159 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "TFLNodes.h"
+
+#include "TFLDialect.h"
+#include "TFLOpcode.h"
+
+#include <gtest/gtest.h>
+
+TEST(TFLAddTest, constructor)
+{
+ locoex::TFLAdd add_node;
+
+ ASSERT_EQ(add_node.dialect(), locoex::TFLDialect::get());
+ ASSERT_EQ(add_node.opcode(), locoex::TFLOpcode::ADD);
+
+ ASSERT_EQ(add_node.x(), nullptr);
+ ASSERT_EQ(add_node.y(), nullptr);
+}
+
+// TODO TFLAveragePool2D
+
+TEST(TFLConcatTest, constructor)
+{
+ locoex::TFLConcatenation concat_node(3);
+
+ ASSERT_EQ(concat_node.dialect(), locoex::TFLDialect::get());
+ ASSERT_EQ(concat_node.opcode(), locoex::TFLOpcode::CONCATENATION);
+
+ ASSERT_EQ(concat_node.numValues(), 3);
+ ASSERT_EQ(concat_node.values(0), nullptr);
+ ASSERT_EQ(concat_node.values(1), nullptr);
+ ASSERT_EQ(concat_node.values(2), nullptr);
+ ASSERT_EQ(concat_node.fusedActivationFunction(), locoex::FusedActFunc::UNDEFINED);
+}
+
+// TODO TFLConv2D
+
+TEST(TFLDepthwiseConv2DTest, constructor)
+{
+ locoex::TFLDepthwiseConv2D dw_conv2d_node;
+
+ ASSERT_EQ(dw_conv2d_node.dialect(), locoex::TFLDialect::get());
+ ASSERT_EQ(dw_conv2d_node.opcode(), locoex::TFLOpcode::DEPTHWISE_CONV_2D);
+
+ ASSERT_EQ(dw_conv2d_node.input(), nullptr);
+ ASSERT_EQ(dw_conv2d_node.filter(), nullptr);
+ ASSERT_EQ(dw_conv2d_node.bias(), nullptr);
+ ASSERT_EQ(dw_conv2d_node.padding(), locoex::Padding::UNDEFINED);
+ ASSERT_EQ(dw_conv2d_node.stride()->h(), 1);
+ ASSERT_EQ(dw_conv2d_node.stride()->w(), 1);
+ ASSERT_EQ(dw_conv2d_node.depthMultiplier(), 0);
+ ASSERT_EQ(dw_conv2d_node.fusedActivationFunction(), locoex::FusedActFunc::UNDEFINED);
+}
+
+TEST(TFLDivTest, constructor)
+{
+ locoex::TFLDiv div_node;
+
+ ASSERT_EQ(div_node.dialect(), locoex::TFLDialect::get());
+ ASSERT_EQ(div_node.opcode(), locoex::TFLOpcode::DIV);
+
+ ASSERT_EQ(div_node.x(), nullptr);
+ ASSERT_EQ(div_node.y(), nullptr);
+}
+
+// TODO TFLMaxPool2D
+
+TEST(TFLMulTest, constructor)
+{
+ locoex::TFLMul mul_node;
+
+ ASSERT_EQ(mul_node.dialect(), locoex::TFLDialect::get());
+ ASSERT_EQ(mul_node.opcode(), locoex::TFLOpcode::MUL);
+
+ ASSERT_EQ(mul_node.x(), nullptr);
+ ASSERT_EQ(mul_node.y(), nullptr);
+}
+
+TEST(TFLReluTest, constructor)
+{
+ locoex::TFLRelu relu_node;
+
+ ASSERT_EQ(relu_node.dialect(), locoex::TFLDialect::get());
+ ASSERT_EQ(relu_node.opcode(), locoex::TFLOpcode::RELU);
+
+ ASSERT_EQ(relu_node.features(), nullptr);
+}
+
+// TODO TFLRelu6
+
+TEST(TFLReshapeTest, constructor)
+{
+ locoex::TFLReshape reshape;
+
+ ASSERT_EQ(reshape.dialect(), locoex::TFLDialect::get());
+ ASSERT_EQ(reshape.opcode(), locoex::TFLOpcode::RESHAPE);
+
+ ASSERT_EQ(reshape.tensor(), nullptr);
+ ASSERT_EQ(reshape.shape(), nullptr);
+ ASSERT_EQ(reshape.newShape()->rank(), 0);
+}
+
+TEST(TFLReshapeTest, alloc_new_shape)
+{
+ locoex::TFLReshape reshape;
+
+ reshape.newShape()->rank(2);
+ ASSERT_EQ(reshape.newShape()->rank(), 2);
+
+ reshape.newShape()->dim(0) = 0;
+ reshape.newShape()->dim(1) = 1;
+
+ auto &const_reshape = const_cast<const locoex::TFLReshape &>(reshape);
+ ASSERT_EQ(const_reshape.newShape()->dim(0), 0);
+ ASSERT_EQ(const_reshape.newShape()->dim(1), 1);
+}
+
+// TODO TFLSoftmax
+
+// TODO TFLSqrt
+
+TEST(TFLSubTest, constructor)
+{
+ locoex::TFLSub sub_node;
+
+ ASSERT_EQ(sub_node.dialect(), locoex::TFLDialect::get());
+ ASSERT_EQ(sub_node.opcode(), locoex::TFLOpcode::SUB);
+
+ ASSERT_EQ(sub_node.x(), nullptr);
+ ASSERT_EQ(sub_node.y(), nullptr);
+}
+
+// TODO TFLTanh
+
+TEST(TFLTransposeTest, constructor)
+{
+ locoex::TFLTranspose tr_node;
+
+ ASSERT_EQ(tr_node.dialect(), locoex::TFLDialect::get());
+ ASSERT_EQ(tr_node.opcode(), locoex::TFLOpcode::TRANSPOSE);
+
+ ASSERT_EQ(tr_node.a(), nullptr);
+ ASSERT_EQ(tr_node.perm(), nullptr);
+}
diff --git a/compiler/exo/src/Dialect/IR/TFLOpcode.h b/compiler/exo/src/Dialect/IR/TFLOpcode.h
new file mode 100644
index 000000000..0c0ab64bd
--- /dev/null
+++ b/compiler/exo/src/Dialect/IR/TFLOpcode.h
@@ -0,0 +1,32 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LOCOEX_IR_TFLOPCODE_H__
+#define __LOCOEX_IR_TFLOPCODE_H__
+
+namespace locoex
+{
+
+enum class TFLOpcode
+{
+#define TFL_NODE(OPCODE, CLASS) OPCODE,
+#include "TFLNodes.lst"
+#undef TFL_NODE
+};
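+
+// With the current "TFLNodes.lst", this expands to enumerators such as
+// ADD, AVERAGE_POOL_2D, CONCATENATION, CONST, CONV_2D, ... (illustrative,
+// in list order).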
+
+} // namespace locoex
+
+#endif // __LOCOEX_IR_TFLOPCODE_H__
diff --git a/compiler/exo/src/Dialect/Service/CircleShapeInferenceRule.cpp b/compiler/exo/src/Dialect/Service/CircleShapeInferenceRule.cpp
new file mode 100644
index 000000000..2e71aa000
--- /dev/null
+++ b/compiler/exo/src/Dialect/Service/CircleShapeInferenceRule.cpp
@@ -0,0 +1,67 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleShapeInferenceRule.h"
+
+#include "Dialect/IR/CircleNodes.h"
+#include "Dialect/IR/CircleDialect.h"
+#include "Dialect/IR/CircleNodeVisitor.h"
+
+#include "Check.h"
+
+#include <cassert>
+
+namespace
+{
+
+/**
+ * @brief Class to infer the shape of CircleNode
+ *
+ * @note All CircleNode's inputs and outputs are always loco::Domain::Tensor
+ */
+class ShapeInferenceAlgorithm final : public locoex::CircleNodeVisitor<loco::NodeShape>
+{
+public:
+ loco::NodeShape visit(const locoex::CircleInstanceNorm *node) final
+ {
+ auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
+
+ return loco::NodeShape{input_shape};
+ }
+};
+
+} // namespace
+
+namespace locoex
+{
+
+bool CircleShapeInferenceRule::recognize(const loco::Dialect *d) const
+{
+ return CircleDialect::get() == d;
+}
+
+bool CircleShapeInferenceRule::infer(const loco::Node *node, loco::NodeShape &shape) const
+{
+ assert(node->dialect() == CircleDialect::get());
+ assert(dynamic_cast<const CircleNode *>(node) != nullptr);
+
+ ShapeInferenceAlgorithm alg;
+ shape = dynamic_cast<const CircleNode *>(node)->accept(&alg);
+
+ return true;
+}
+
+} // namespace locoex
diff --git a/compiler/exo/src/Dialect/Service/CircleShapeInferenceRule.h b/compiler/exo/src/Dialect/Service/CircleShapeInferenceRule.h
new file mode 100644
index 000000000..92f23c9dd
--- /dev/null
+++ b/compiler/exo/src/Dialect/Service/CircleShapeInferenceRule.h
@@ -0,0 +1,33 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LOCOEX_SERVICE_CIRCLESHAPE_INFERENCE_RULE_H__
+#define __LOCOEX_SERVICE_CIRCLESHAPE_INFERENCE_RULE_H__
+
+#include <loco/Service/ShapeInference.h>
+
+namespace locoex
+{
+
+struct CircleShapeInferenceRule final : public loco::ShapeInferenceRule
+{
+ bool recognize(const loco::Dialect *) const final;
+ bool infer(const loco::Node *, loco::NodeShape &) const final;
+};
+
+} // namespace locoex
+
+#endif // __LOCOEX_SERVICE_CIRCLESHAPE_INFERENCE_RULE_H__
diff --git a/compiler/exo/src/Dialect/Service/CircleTypeInferenceRule.cpp b/compiler/exo/src/Dialect/Service/CircleTypeInferenceRule.cpp
new file mode 100644
index 000000000..6bc95a1b5
--- /dev/null
+++ b/compiler/exo/src/Dialect/Service/CircleTypeInferenceRule.cpp
@@ -0,0 +1,58 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleTypeInferenceRule.h"
+
+#include "Dialect/IR/CircleDialect.h"
+#include "Dialect/IR/CircleNodeVisitor.h"
+#include "Dialect/IR/CircleNodes.h"
+
+#include <cassert>
+
+namespace
+{
+
+struct TypeInferenceAlgorithm final : public locoex::CircleNodeVisitor<loco::DataType>
+{
+ loco::DataType visit(const locoex::CircleInstanceNorm *node) final
+ {
+ return loco::dtype_get(node->input());
+ }
+};
+
+} // namespace
+
+namespace locoex
+{
+
+bool CircleTypeInferenceRule::recognize(const loco::Dialect *d) const
+{
+ return CircleDialect::get() == d;
+}
+
+bool CircleTypeInferenceRule::infer(const loco::Node *node, loco::DataType &dtype) const
+{
+ assert(node->dialect() == CircleDialect::get());
+ assert(dynamic_cast<const CircleNode *>(node) != nullptr);
+
+ TypeInferenceAlgorithm alg;
+
+ dtype = dynamic_cast<const CircleNode *>(node)->accept(&alg);
+ assert(dtype != loco::DataType::Unknown);
+
+ return true;
+}
+
+} // namespace locoex
diff --git a/compiler/exo/src/Dialect/Service/CircleTypeInferenceRule.h b/compiler/exo/src/Dialect/Service/CircleTypeInferenceRule.h
new file mode 100644
index 000000000..c073dfc54
--- /dev/null
+++ b/compiler/exo/src/Dialect/Service/CircleTypeInferenceRule.h
@@ -0,0 +1,36 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LOCOEX_SERVICE_CIRCLETYPE_INFERENCE_RULE_H__
+#define __LOCOEX_SERVICE_CIRCLETYPE_INFERENCE_RULE_H__
+
+#include <loco/Service/TypeInference.h>
+
+namespace locoex
+{
+
+/**
+ * @brief Type Inference Rule for CircleDialect
+ */
+struct CircleTypeInferenceRule final : public loco::TypeInferenceRule
+{
+ bool recognize(const loco::Dialect *) const final;
+ bool infer(const loco::Node *, loco::DataType &) const final;
+};
+
+} // namespace locoex
+
+#endif // __LOCOEX_SERVICE_CIRCLETYPE_INFERENCE_RULE_H__
diff --git a/compiler/exo/src/Dialect/Service/TFLShapeInferenceRule.cpp b/compiler/exo/src/Dialect/Service/TFLShapeInferenceRule.cpp
new file mode 100644
index 000000000..f4bb10364
--- /dev/null
+++ b/compiler/exo/src/Dialect/Service/TFLShapeInferenceRule.cpp
@@ -0,0 +1,627 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "TFLShapeInferenceRule.h"
+
+#include "Dialect/IR/TFLNodes.h"
+#include "Dialect/IR/TFLDialect.h"
+#include "Dialect/IR/TFLNodeVisitor.h"
+
+#include "Check.h"
+
+#include <oops/InternalExn.h>
+
+#include <algorithm>
+#include <cassert>
+#include <stdexcept>
+
+namespace
+{
+
+// Call this for TFLAvgPool2D and TFLMaxPool2D only
+template <class Pool2DType> loco::NodeShape infer_pool_2d_shape(const Pool2DType *node)
+{
+ EXO_ASSERT(loco::shape_known(node->value()), "Shape must be known");
+
+ auto ifm_shape = loco::shape_get(node->value()).template as<loco::TensorShape>();
+ assert(ifm_shape.rank() == 4);
+
+ uint32_t input_height = ifm_shape.dim(1).value();
+ uint32_t input_width = ifm_shape.dim(2).value();
+ uint32_t stride_height = node->stride()->h();
+ uint32_t stride_width = node->stride()->w();
+ uint32_t window_height = node->filter()->h();
+ uint32_t window_width = node->filter()->w();
+ uint32_t dilation_height = 1; // dilation for TFLAvgPool2D and TFLMaxPool2D is 1
+ uint32_t dilation_width = 1;
+ uint32_t effective_window_height = dilation_height * (window_height - 1) + 1;
+ uint32_t effective_window_width = dilation_width * (window_width - 1) + 1;
+
+ uint32_t output_height = 0;
+ uint32_t output_width = 0;
+
+ if (node->padding() == locoex::Padding::VALID)
+ {
+ output_height = (input_height + stride_height - effective_window_height) / stride_height;
+ output_width = (input_width + stride_width - effective_window_width) / stride_width;
+ }
+ else if (node->padding() == locoex::Padding::SAME)
+ {
+ output_height = (input_height + stride_height - 1) / stride_height;
+ output_width = (input_width + stride_width - 1) / stride_width;
+ }
+ else
+ EXO_ASSERT(false, "Wrong padding type");
+
+ loco::TensorShape ofm_shape;
+ ofm_shape.rank(4);
+ ofm_shape.dim(0) = ifm_shape.dim(0);
+ ofm_shape.dim(1) = output_height;
+ ofm_shape.dim(2) = output_width;
+ ofm_shape.dim(3) = ifm_shape.dim(3);
+
+ return loco::NodeShape{ofm_shape};
+}
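+
+// Worked example: with ifm H x W = 4 x 3, window 2 x 2, and stride 2 x 2,
+//   VALID: H = (4 + 2 - 2) / 2 = 2, W = (3 + 2 - 2) / 2 = 1 -> 2 x 1
+//   SAME : H = (4 + 2 - 1) / 2 = 2, W = (3 + 2 - 1) / 2 = 2 -> 2 x 2
+// (the VALID case matches the avgpool2d test in TFLShapeInferenceRule.test.cpp)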
+
+/**
+ * @brief Create a higher-rank TensorShape following NumPy broadcasting semantics
+ *
+ * HOW TO USE:
+ *
+ *   auto expanded_tensor_shape = TensorShapeExpander(tensor_shape).to(N);
+ */
+class TensorShapeExpander
+{
+public:
+ TensorShapeExpander(const loco::TensorShape &shape) : _shape{shape}
+ {
+ // DO NOTHING
+ }
+
+public:
+ loco::TensorShape to(uint32_t output_rank)
+ {
+ auto const &input_shape = _shape;
+ uint32_t const input_rank = input_shape.rank();
+
+ assert(input_rank <= output_rank && "Cannot shrink rank");
+ uint32_t const axis_shift = output_rank - input_rank;
+
+ loco::TensorShape output_shape;
+
+ output_shape.rank(output_rank);
+ for (uint32_t axis = 0; axis < output_rank; ++axis)
+ {
+ output_shape.dim(axis) = (axis < axis_shift) ? 1 : input_shape.dim(axis - axis_shift);
+ }
+
+ return output_shape;
+ }
+
+private:
+ const loco::TensorShape _shape;
+};
+
+/**
+ * @brief Expand shapes x and y to the same rank, right-aligned and filled with 1s
+ */
+void expand_rank(loco::TensorShape &x, loco::TensorShape &y)
+{
+ auto x_rank = x.rank();
+ auto y_rank = y.rank();
+
+ if (x_rank == y_rank)
+ return;
+
+ TensorShapeExpander x_exp(x);
+ TensorShapeExpander y_exp(y);
+
+ auto xy_rank = std::max(x_rank, y_rank);
+
+ x = x_rank > y_rank ? x : x_exp.to(xy_rank);
+ y = y_rank > x_rank ? y : y_exp.to(xy_rank);
+}
+
+/**
+ * @brief Return the element-wise expanded shape of inputs x and y, which must have the same rank
+ */
+loco::TensorShape expand_dimension(const loco::TensorShape &x, const loco::TensorShape &y)
+{
+ assert(x.rank() == y.rank());
+
+ auto rank = x.rank();
+
+ loco::TensorShape output_shape;
+
+ output_shape.rank(rank);
+ for (uint32_t axis = 0; axis < rank; ++axis)
+ {
+ assert(x.dim(axis).known() && y.dim(axis).known());
+
+ auto x_dim = x.dim(axis).value();
+ auto y_dim = y.dim(axis).value();
+
+ // each dimension of x and y should be the same, or one of them must be 1
+ if (!((x_dim == y_dim) || (x_dim == 1 || y_dim == 1)))
+ INTERNAL_EXN("Cannot produce expand_dimension of two shapes");
+
+ output_shape.dim(axis) = std::max(x_dim, y_dim);
+ }
+
+ return output_shape;
+}
+
+loco::TensorShape broadcast_shape(const loco::TensorShape &x, const loco::TensorShape &y)
+{
+ auto x_match = x;
+ auto y_match = y;
+
+ expand_rank(x_match, y_match);
+
+ auto output_shape = expand_dimension(x_match, y_match);
+
+ return output_shape;
+}
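+
+// Worked example: x = (3, 1, 5) and y = (4, 1); expand_rank turns y into
+// (1, 4, 1), then expand_dimension yields (3, 4, 5), per NumPy broadcasting.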
+
+/**
+ * @brief Class to infer the shape of TFLNode
+ *
+ * @note All TFLNode's inputs and outputs are always loco::Domain::Tensor
+ */
+class ShapeInferenceAlgorithm final : public locoex::TFLNodeVisitor<loco::NodeShape>
+{
+public:
+ loco::NodeShape visit(const locoex::TFLAdd *node) final
+ {
+ auto x_shape = loco::shape_get(node->x()).as<loco::TensorShape>();
+ auto y_shape = loco::shape_get(node->y()).as<loco::TensorShape>();
+
+ auto output_shape = broadcast_shape(x_shape, y_shape);
+
+ return loco::NodeShape{output_shape};
+ }
+
+ loco::NodeShape visit(const locoex::TFLAveragePool2D *node) final
+ {
+ return infer_pool_2d_shape(node);
+ }
+
+ loco::NodeShape visit(const locoex::TFLConcatenation *node) final
+ {
+ // TODO Support when TFLConcatenation has 0 input
+ assert(node->numValues() > 0);
+
+ auto axis = node->axis();
+ auto first_shape = loco::shape_get(node->values(0)).as<loco::TensorShape>();
+
+ loco::TensorShape output_shape;
+
+ output_shape.rank(first_shape.rank());
+ for (uint32_t i = 0; i < output_shape.rank(); ++i)
+ output_shape.dim(i) = first_shape.dim(i);
+
+ for (uint32_t i = 1; i < node->numValues(); ++i)
+ {
+ auto input_shape = loco::shape_get(node->values(i)).as<loco::TensorShape>();
+
+ for (uint32_t j = 0; j < output_shape.rank(); ++j)
+ {
+ if (j == axis)
+ output_shape.dim(j) = output_shape.dim(j).value() + input_shape.dim(j).value();
+ else
+ assert(output_shape.dim(j) == input_shape.dim(j));
+ }
+ }
+
+ return loco::NodeShape{output_shape};
+ }
+
+ loco::NodeShape visit(const locoex::TFLConst *node) final
+ {
+ loco::TensorShape shape;
+
+ shape.rank(node->rank());
+ for (uint32_t axis = 0; axis < node->rank(); axis++)
+ shape.dim(axis) = node->dim(axis);
+
+ return loco::NodeShape{shape};
+ }
+
+ loco::NodeShape visit(const locoex::TFLConv2D *node) final
+ {
+ auto ifm_shape = loco::shape_get(node->input()).as<loco::TensorShape>(); // in NHWC
+ auto ker_shape = loco::shape_get(node->filter()).as<loco::TensorShape>(); // in OHWI
+
+ assert(ifm_shape.rank() == 4);
+ assert(ker_shape.rank() == 4);
+ assert(ifm_shape.dim(3) == ker_shape.dim(3));
+
+ uint32_t input_height = ifm_shape.dim(1).value();
+ uint32_t input_width = ifm_shape.dim(2).value();
+ uint32_t stride_height = node->stride()->h();
+ uint32_t stride_width = node->stride()->w();
+ uint32_t ker_height = ker_shape.dim(1).value();
+ uint32_t ker_width = ker_shape.dim(2).value();
+ uint32_t dilation_height = 1;
+ uint32_t dilation_width = 1;
+ uint32_t effective_ker_height = dilation_height * (ker_height - 1) + 1;
+ uint32_t effective_ker_width = dilation_width * (ker_width - 1) + 1;
+
+ uint32_t output_height = 0;
+ uint32_t output_width = 0;
+
+ if (node->padding() == locoex::Padding::VALID)
+ {
+ output_height = (input_height + stride_height - effective_ker_height) / stride_height;
+ output_width = (input_width + stride_width - effective_ker_width) / stride_width;
+ }
+ else if (node->padding() == locoex::Padding::SAME)
+ {
+ output_height = (input_height + stride_height - 1) / stride_height;
+ output_width = (input_width + stride_width - 1) / stride_width;
+ }
+ else
+ EXO_ASSERT(false, "Wrong padding type");
+
+ loco::TensorShape ofm_shape;
+ ofm_shape.rank(4);
+ ofm_shape.dim(0) = ifm_shape.dim(0);
+ ofm_shape.dim(1) = output_height;
+ ofm_shape.dim(2) = output_width;
+ ofm_shape.dim(3) = ker_shape.dim(0);
+
+ return loco::NodeShape{ofm_shape};
+ }
+
+ loco::NodeShape visit(const locoex::TFLDepthwiseConv2D *node) final
+ {
+ auto ifm_shape = loco::shape_get(node->input()).as<loco::TensorShape>(); // in NHWC
+ auto ker_shape = loco::shape_get(node->filter()).as<loco::TensorShape>(); // in 1 H W CM
+
+ assert(ifm_shape.rank() == 4);
+ assert(ker_shape.rank() == 4);
+ assert(ker_shape.dim(0).value() == 1);
+
+ uint32_t input_height = ifm_shape.dim(1).value();
+ uint32_t input_width = ifm_shape.dim(2).value();
+ uint32_t stride_height = node->stride()->h();
+ uint32_t stride_width = node->stride()->w();
+ uint32_t ker_height = ker_shape.dim(1).value();
+ uint32_t ker_width = ker_shape.dim(2).value();
+ uint32_t dilation_height = 1;
+ uint32_t dilation_width = 1;
+ uint32_t effective_ker_height = dilation_height * (ker_height - 1) + 1;
+ uint32_t effective_ker_width = dilation_width * (ker_width - 1) + 1;
+
+ uint32_t output_height = 0;
+ uint32_t output_width = 0;
+
+ if (node->padding() == locoex::Padding::VALID)
+ {
+ output_height = (input_height + stride_height - effective_ker_height) / stride_height;
+ output_width = (input_width + stride_width - effective_ker_width) / stride_width;
+ }
+ else if (node->padding() == locoex::Padding::SAME)
+ {
+ output_height = (input_height + stride_height - 1) / stride_height;
+ output_width = (input_width + stride_width - 1) / stride_width;
+ }
+ else
+ EXO_ASSERT(false, "Wrong padding type");
+
+ loco::TensorShape ofm_shape;
+ ofm_shape.rank(4);
+ ofm_shape.dim(0) = ifm_shape.dim(0);
+ ofm_shape.dim(1) = output_height;
+ ofm_shape.dim(2) = output_width;
+ ofm_shape.dim(3) = ker_shape.dim(3);
+
+ return loco::NodeShape{ofm_shape};
+ }
+
+ loco::NodeShape visit(const locoex::TFLDiv *node) final
+ {
+ auto x_shape = loco::shape_get(node->x()).as<loco::TensorShape>();
+ auto y_shape = loco::shape_get(node->y()).as<loco::TensorShape>();
+
+ auto output_shape = broadcast_shape(x_shape, y_shape);
+
+ return loco::NodeShape{output_shape};
+ }
+
+ loco::NodeShape visit(const locoex::TFLFullyConnected *node) final
+ {
+ auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
+ auto weights_shape = loco::shape_get(node->weights()).as<loco::TensorShape>();
+
+ // Check that the shapes are compatible for matrix multiplication
+ EXO_ASSERT(input_shape.rank() == 2, "NYI for input shape rank > 2");
+ EXO_ASSERT(weights_shape.rank() == 2, "Incompatible weights rank for fully connected");
+ EXO_ASSERT(input_shape.dim(1) == weights_shape.dim(1),
+ "Incompatible shapes for fully connected");
+
+ loco::TensorShape out_shape;
+ out_shape.rank(2);
+
+ out_shape.dim(0) = input_shape.dim(0);
+ out_shape.dim(1) = weights_shape.dim(0);
+
+ return loco::NodeShape{out_shape};
+ }
+
+ loco::NodeShape visit(const locoex::TFLMaximum *node) final
+ {
+ auto x_shape = loco::shape_get(node->x()).as<loco::TensorShape>();
+ auto y_shape = loco::shape_get(node->y()).as<loco::TensorShape>();
+
+ auto output_shape = broadcast_shape(x_shape, y_shape);
+
+ return loco::NodeShape{output_shape};
+ }
+
+ loco::NodeShape visit(const locoex::TFLMaxPool2D *node) final
+ {
+ return infer_pool_2d_shape(node);
+ }
+
+ loco::NodeShape visit(const locoex::TFLMean *node) final
+ {
+ const loco::DataType S32 = loco::DataType::S32;
+
+ auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
+ auto reduction_indices = dynamic_cast<locoex::TFLConst *>(node->reduction_indices());
+
+ { // Exceptions
+ // TODO support non-const case
+ EXO_ASSERT(reduction_indices, "Only support constant reduction_indices");
+ // TODO support other data type
+ EXO_ASSERT(reduction_indices->dtype() == S32, "Only support int32");
+ }
+
+ std::vector<int32_t> reduction_values;
+
+ for (uint32_t i = 0; i < reduction_indices->size<S32>(); ++i)
+ {
+ int32_t axis = reduction_indices->at<S32>(i);
+ if (axis < 0)
+ axis += input_shape.rank();
+ if (not(0 <= axis and axis < static_cast<int32_t>(input_shape.rank())))
+ INTERNAL_EXN_V("Invalid reduction axis for MEAN", oops::to_uint32(axis));
+ reduction_values.push_back(axis);
+ }
+
+ loco::TensorShape output_shape;
+
+ if (node->keep_dims())
+ {
+ output_shape.rank(input_shape.rank());
+ for (uint32_t i = 0; i < input_shape.rank(); ++i)
+ output_shape.dim(i) = input_shape.dim(i);
+ for (uint32_t i = 0; i < reduction_values.size(); ++i)
+ output_shape.dim(reduction_values.at(i)) = 1;
+ }
+ else
+ {
+ std::vector<bool> check_reduce(input_shape.rank(), false);
+ for (uint32_t i = 0; i < reduction_values.size(); ++i)
+ check_reduce.at(reduction_values.at(i)) = true;
+
+ uint32_t reduce_cnt = 0;
+ for (uint32_t i = 0; i < check_reduce.size(); ++i)
+ if (check_reduce.at(i))
+ ++reduce_cnt;
+
+ output_shape.rank(input_shape.rank() - reduce_cnt);
+ for (uint32_t i = 0, j = 0; i < check_reduce.size(); ++i)
+ if (check_reduce.at(i) == false)
+ output_shape.dim(j++) = input_shape.dim(i); // keep the surviving input dimension, not its index
+ }
+
+ return loco::NodeShape{output_shape};
+ }
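+
+ // Worked example for the rule above: input (2, 3, 4) reduced over axis 1
+ // yields (2, 1, 4) when keep_dims is true, and (2, 4) otherwise.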
+
+ loco::NodeShape visit(const locoex::TFLMul *node) final
+ {
+ auto x_shape = loco::shape_get(node->x()).as<loco::TensorShape>();
+ auto y_shape = loco::shape_get(node->y()).as<loco::TensorShape>();
+
+ auto output_shape = broadcast_shape(x_shape, y_shape);
+
+ return loco::NodeShape{output_shape};
+ }
+
+ loco::NodeShape visit(const locoex::TFLRelu *node) final
+ {
+ auto input_shape = loco::shape_get(node->features()).as<loco::TensorShape>();
+
+ return loco::NodeShape{input_shape};
+ }
+
+ loco::NodeShape visit(const locoex::TFLRelu6 *node) final
+ {
+ auto input_shape = loco::shape_get(node->features()).as<loco::TensorShape>();
+
+ return loco::NodeShape{input_shape};
+ }
+
+ /**
+   * @note TFLReshape has new-shape info in two places: the 2nd input and the attribute.
+   *       This shape inference forces both to exist and to match each other.
+   *       When this condition is satisfied, it returns the inferred shape.
+ *
+ * TODO Change this policy when not appropriate
+ */
+ loco::NodeShape visit(const locoex::TFLReshape *node) final
+ {
+ const loco::DataType S32 = loco::DataType::S32;
+
+ loco::TensorShape shape_by_input;
+ {
+ EXO_ASSERT(node->shape(), "2nd input shape() should not be nullptr");
+
+ // Only support node's shape() is TFLConst with S32
+ // TODO support other node with other types
+ auto const_shape_node = dynamic_cast<locoex::TFLConst *>(node->shape());
+ EXO_ASSERT(const_shape_node, "Only support TFLConst for shape of TFLReshape");
+ EXO_ASSERT(const_shape_node->dtype() == S32, "Only support int32 TFLConst");
+
+ if (const_shape_node->rank() != 1)
+ INTERNAL_EXN_V("Only support rank 1 TFLConst", oops::to_uint32(const_shape_node->rank()));
+
+ shape_by_input.rank(const_shape_node->dim(0).value());
+
+ for (uint32_t axis = 0; axis < shape_by_input.rank(); ++axis)
+ {
+ EXO_ASSERT(const_shape_node->at<S32>(axis) > 0, "Dimension should be > 0");
+ shape_by_input.dim(axis) = const_shape_node->at<S32>(axis);
+ }
+ }
+
+ loco::TensorShape shape_by_attr;
+ {
+ shape_by_attr.rank(node->newShape()->rank());
+
+ for (uint32_t axis = 0; axis < shape_by_attr.rank(); ++axis)
+ {
+ EXO_ASSERT(node->newShape()->dim(axis) > 0, "Dimension should be > 0");
+ shape_by_attr.dim(axis) = node->newShape()->dim(axis);
+ }
+ }
+
+ EXO_ASSERT(shape_by_input == shape_by_attr,
+            "Warning: the two new-shape sources mismatch for TFLReshape");
+
+ return loco::NodeShape{shape_by_input};
+ }
+
+ loco::NodeShape visit(const locoex::TFLRsqrt *node) final
+ {
+ auto input_shape = loco::shape_get(node->x()).as<loco::TensorShape>();
+
+ return loco::NodeShape{input_shape};
+ }
+
+ // TODO TFLSoftmax
+
+ loco::NodeShape visit(const locoex::TFLSqrt *node) final
+ {
+ auto input_shape = loco::shape_get(node->x()).as<loco::TensorShape>();
+
+ return loco::NodeShape{input_shape};
+ }
+
+ loco::NodeShape visit(const locoex::TFLSquaredDifference *node) final
+ {
+ auto x_shape = loco::shape_get(node->x()).as<loco::TensorShape>();
+ auto y_shape = loco::shape_get(node->y()).as<loco::TensorShape>();
+
+ auto output_shape = broadcast_shape(x_shape, y_shape);
+
+ return loco::NodeShape{output_shape};
+ }
+
+ loco::NodeShape visit(const locoex::TFLSub *node) final
+ {
+ auto x_shape = loco::shape_get(node->x()).as<loco::TensorShape>();
+ auto y_shape = loco::shape_get(node->y()).as<loco::TensorShape>();
+
+ auto output_shape = broadcast_shape(x_shape, y_shape);
+
+ return loco::NodeShape{output_shape};
+ }
+
+ // TODO TFLTanh
+
+ /// @brief Returns output shape of transpose. Use loco::ConstGen and locoex::TFLConst for ConstT.
+ template <class ConstT>
+ loco::TensorShape output_shape_of_transpose(loco::TensorShape input_shape,
+ const ConstT *perm_node)
+ {
+ loco::TensorShape output_shape;
+ output_shape.rank(input_shape.rank());
+
+ assert(perm_node->dtype() == loco::DataType::S32);
+ assert(input_shape.rank() == perm_node->template size<loco::DataType::S32>());
+
+ for (uint32_t out_axis = 0; out_axis < output_shape.rank(); out_axis++)
+ {
+   // TensorFlow semantics: output.dim(i) == input.dim(perm[i])
+   auto in_axis = perm_node->template at<loco::DataType::S32>(out_axis);
+   output_shape.dim(out_axis) = input_shape.dim(in_axis);
+ }
+
+ return output_shape;
+ }
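+
+ // Worked example (TensorFlow semantics, output.dim(i) == input.dim(perm[i])):
+ // input (1, 2, 3, 4) with perm {0, 3, 1, 2} yields output (1, 4, 2, 3).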
+
+ loco::NodeShape visit(const locoex::TFLTranspose *node) final
+ {
+ auto input_shape = loco::shape_get(node->a()).as<loco::TensorShape>();
+
+ auto canon_perm = dynamic_cast<loco::ConstGen *>(node->perm());
+ auto tfl_perm = dynamic_cast<locoex::TFLConst *>(node->perm());
+
+ if (canon_perm)
+ {
+ return loco::NodeShape{output_shape_of_transpose(input_shape, canon_perm)};
+ }
+ else if (tfl_perm)
+ {
+ return loco::NodeShape{output_shape_of_transpose(input_shape, tfl_perm)};
+ }
+ else
+ INTERNAL_EXN("perm of TFLTranspose should be either ConstGen or TFLConst");
+ }
+
+ loco::NodeShape visit(const locoex::TFLTransposeConv *node) final
+ {
+ // TransposeConv's output shape is written in its 'inputSizes' argument
+ auto input_sizes_const = dynamic_cast<locoex::TFLConst *>(node->inputSizes());
+ EXO_ASSERT(input_sizes_const, "Only support when TFLTransposeConv's inputSizes is TFLConst")
+ EXO_ASSERT(input_sizes_const->dtype() == loco::DataType::S32, "Only support S32 dtype")
+ EXO_ASSERT(input_sizes_const->rank() == 1 && input_sizes_const->dim(0).value() == 4,
+ "Only support rank 1 with 4 entries")
+
+ loco::TensorShape shape;
+
+ shape.rank(4);
+ for (uint32_t axis = 0; axis < 4; ++axis)
+ shape.dim(axis) = input_sizes_const->at<loco::DataType::S32>(axis);
+
+ return loco::NodeShape{shape};
+ }
+};
+
+} // namespace
+
+namespace locoex
+{
+
+bool TFLShapeInferenceRule::recognize(const loco::Dialect *d) const
+{
+ return TFLDialect::get() == d;
+}
+
+bool TFLShapeInferenceRule::infer(const loco::Node *node, loco::NodeShape &shape) const
+{
+ assert(node->dialect() == TFLDialect::get());
+ assert(dynamic_cast<const TFLNode *>(node) != nullptr);
+
+ ShapeInferenceAlgorithm alg;
+ shape = dynamic_cast<const TFLNode *>(node)->accept(&alg);
+
+ return true;
+}
+
+} // namespace locoex
diff --git a/compiler/exo/src/Dialect/Service/TFLShapeInferenceRule.h b/compiler/exo/src/Dialect/Service/TFLShapeInferenceRule.h
new file mode 100644
index 000000000..434a145cc
--- /dev/null
+++ b/compiler/exo/src/Dialect/Service/TFLShapeInferenceRule.h
@@ -0,0 +1,33 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LOCOEX_SERVICE_TFLSHAPE_INFERENCE_RULE_H__
+#define __LOCOEX_SERVICE_TFLSHAPE_INFERENCE_RULE_H__
+
+#include <loco/Service/ShapeInference.h>
+
+namespace locoex
+{
+
+struct TFLShapeInferenceRule final : public loco::ShapeInferenceRule
+{
+ bool recognize(const loco::Dialect *) const final;
+ bool infer(const loco::Node *, loco::NodeShape &) const final;
+};
+
+} // namespace locoex
+
+#endif // __LOCOEX_SERVICE_TFLSHAPE_INFERENCE_RULE_H__
diff --git a/compiler/exo/src/Dialect/Service/TFLShapeInferenceRule.test.cpp b/compiler/exo/src/Dialect/Service/TFLShapeInferenceRule.test.cpp
new file mode 100644
index 000000000..35c8f0b2a
--- /dev/null
+++ b/compiler/exo/src/Dialect/Service/TFLShapeInferenceRule.test.cpp
@@ -0,0 +1,277 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "TestGraph.h"
+
+#include "Dialect/IR/TFLNodes.h"
+#include "Dialect/IR/TFLDialect.h"
+#include "Dialect/Service/TFLShapeInferenceRule.h"
+
+#include <loco.h>
+#include <loco/IR/CanonicalDialect.h>
+#include <loco/Service/ShapeInference.h>
+#include <loco/Service/CanonicalShapeInferenceRule.h>
+#include <loco/Service/MultiDialectShapeInferenceRule.h>
+
+#include <stdex/Memory.h>
+
+#include <gtest/gtest.h>
+
+TEST(TFLShapeInferenceRuleTest, minimal_with_TFLRelu)
+{
+ // Create a simple network
+ exo::test::TestGraph graph;
+ auto tfl_node = graph.append<locoex::TFLRelu>(graph.pull);
+ graph.complete(tfl_node);
+
+ // set shape
+ {
+ graph.pull->rank(2);
+ graph.pull->dim(0) = 3;
+ graph.pull->dim(1) = 4;
+ }
+
+ // pre-check
+ ASSERT_FALSE(loco::shape_known(tfl_node));
+
+ // shape inference
+ locoex::TFLShapeInferenceRule tfl_rule;
+ loco::CanonicalShapeInferenceRule canonical_rule;
+ loco::MultiDialectShapeInferenceRule rules;
+
+ rules.bind(loco::CanonicalDialect::get(), &canonical_rule)
+ .bind(locoex::TFLDialect::get(), &tfl_rule);
+
+ loco::apply(&rules).to(graph.g.get());
+
+ // Verify
+ {
+ ASSERT_TRUE(loco::shape_known(tfl_node));
+ ASSERT_EQ(loco::shape_get(tfl_node).domain(), loco::Domain::Tensor);
+
+ auto shape = loco::shape_get(tfl_node).as<loco::TensorShape>();
+ ASSERT_EQ(shape.rank(), 2);
+ ASSERT_EQ(shape.dim(0), 3);
+ ASSERT_EQ(shape.dim(1), 4);
+ }
+}
+
+// based on the case shown in
+// https://www.corvil.com/kb/what-is-the-difference-between-same-and-valid-padding-in-tf-nn-max-pool-of-tensorflow
+TEST(TFLShapeInferenceRuleTest, avgpool2d_valid)
+{
+ exo::test::TestGraph graph;
+ auto tfl_node = graph.append<locoex::TFLAveragePool2D>(graph.pull);
+ graph.complete();
+
+ auto pull = graph.pull;
+ {
+ pull->shape({1, 4, 3, 1});
+ }
+ // setting TFLAveragePool2D
+ {
+ tfl_node->filter()->h(2);
+ tfl_node->filter()->w(2);
+ tfl_node->stride()->h(2);
+ tfl_node->stride()->w(2);
+ tfl_node->fusedActivationFunction(locoex::FusedActFunc::NONE);
+ tfl_node->padding(locoex::Padding::VALID);
+ }
+ ASSERT_FALSE(loco::shape_known(tfl_node));
+
+ // shape inference
+ locoex::TFLShapeInferenceRule tfl_rule;
+ loco::CanonicalShapeInferenceRule canonical_rule;
+ loco::MultiDialectShapeInferenceRule rules;
+
+ rules.bind(loco::CanonicalDialect::get(), &canonical_rule)
+ .bind(locoex::TFLDialect::get(), &tfl_rule);
+
+ loco::apply(&rules).to(graph.g.get());
+
+ // Verify
+ {
+ ASSERT_TRUE(loco::shape_known(tfl_node));
+ ASSERT_EQ(loco::shape_get(tfl_node).domain(), loco::Domain::Tensor);
+
+ auto shape = loco::shape_get(tfl_node).as<loco::TensorShape>();
+ ASSERT_EQ(shape.rank(), 4);
+ ASSERT_EQ(shape.dim(0).value(), 1);
+ ASSERT_EQ(shape.dim(1).value(), 2);
+ ASSERT_EQ(shape.dim(2).value(), 1);
+ ASSERT_EQ(shape.dim(3).value(), 1);
+ }
+}
+
+TEST(TFLShapeInferenceRuleTest, avgpool2d_same)
+{
+ exo::test::TestGraph graph;
+ auto tfl_node = graph.append<locoex::TFLAveragePool2D>(graph.pull);
+ graph.complete();
+
+ auto pull = graph.pull;
+ {
+ pull->shape({1, 4, 3, 1});
+ }
+
+ // setting TFLAveragePool2D
+ {
+ tfl_node->filter()->h(2);
+ tfl_node->filter()->w(2);
+ tfl_node->stride()->h(2);
+ tfl_node->stride()->w(2);
+ tfl_node->fusedActivationFunction(locoex::FusedActFunc::NONE);
+ tfl_node->padding(locoex::Padding::SAME);
+ }
+
+ ASSERT_FALSE(loco::shape_known(tfl_node));
+
+ // shape inference
+ locoex::TFLShapeInferenceRule tfl_rule;
+ loco::CanonicalShapeInferenceRule canonical_rule;
+ loco::MultiDialectShapeInferenceRule rules;
+
+ rules.bind(loco::CanonicalDialect::get(), &canonical_rule)
+ .bind(locoex::TFLDialect::get(), &tfl_rule);
+
+ loco::apply(&rules).to(graph.g.get());
+
+ // Verify
+ {
+ ASSERT_TRUE(loco::shape_known(tfl_node));
+ ASSERT_EQ(loco::shape_get(tfl_node).domain(), loco::Domain::Tensor);
+
+ auto shape = loco::shape_get(tfl_node).as<loco::TensorShape>();
+ ASSERT_EQ(shape.rank(), 4);
+ ASSERT_EQ(shape.dim(0).value(), 1);
+ ASSERT_EQ(shape.dim(1).value(), 2);
+ ASSERT_EQ(shape.dim(2).value(), 2);
+ ASSERT_EQ(shape.dim(3).value(), 1);
+ }
+}
+
+/**
+ * @note Function under test: shape inference for two inputs with different shapes
+ *
+ * The lower-rank input is expanded to the rank of the higher-rank input:
+ *   x(2,1,5) + y(3,5) --> x(2,1,5) + y(1,3,5)
+ * The output shape is then inferred following numpy broadcasting:
+ *   x(2,1,5) + y(1,3,5) --> output(2,3,5)
+ * For each axis, the dim values must be equal, or one of them must be 1
+ */
+TEST(TFLShapeInferenceRuleTest, TFLAdd_shapeinf_different)
+{
+ auto g = loco::make_graph();
+
+ auto x_node = g->nodes()->create<loco::Pull>();
+ {
+ x_node->rank(3);
+ x_node->dim(0) = 2;
+ x_node->dim(1) = 1;
+ x_node->dim(2) = 5;
+ }
+ auto y_node = g->nodes()->create<loco::Pull>();
+ {
+ y_node->rank(2);
+ y_node->dim(0) = 3;
+ y_node->dim(1) = 5;
+ }
+ auto tfl_node = g->nodes()->create<locoex::TFLAdd>();
+ {
+ tfl_node->x(x_node);
+ tfl_node->y(y_node);
+ }
+ auto push_node = g->nodes()->create<loco::Push>();
+ {
+ push_node->from(tfl_node);
+ }
+
+ auto x_input = g->inputs()->create();
+ {
+ x_input->name("x");
+ loco::link(x_input, x_node);
+ }
+ auto y_input = g->inputs()->create();
+ {
+ y_input->name("y");
+ loco::link(y_input, y_node);
+ }
+ auto output = g->outputs()->create();
+ {
+ output->name("output");
+ loco::link(output, push_node);
+ }
+
+ // pre-check
+ ASSERT_FALSE(loco::shape_known(tfl_node));
+
+ exo::ShapeInferencePass pass;
+ while (pass.run(g.get()) == true)
+ {
+ ;
+ }
+
+ // Verify
+ {
+ ASSERT_TRUE(loco::shape_known(tfl_node));
+ ASSERT_EQ(loco::shape_get(tfl_node).domain(), loco::Domain::Tensor);
+
+ auto shape = loco::shape_get(tfl_node).as<loco::TensorShape>();
+ ASSERT_EQ(shape.rank(), 3);
+ ASSERT_EQ(shape.dim(0), 2);
+ ASSERT_EQ(shape.dim(1), 3);
+ ASSERT_EQ(shape.dim(2), 5);
+ }
+}
+
+TEST(TFLShapeInferenceRuleTest, TFLTranspose_simple)
+{
+ exo::test::ExampleGraph<exo::test::ExampleGraphType::TFLTranspose> g;
+
+ g.pull->rank(4);
+ g.pull->dim(0) = 10;
+ g.pull->dim(1) = 20;
+ g.pull->dim(2) = 30;
+ g.pull->dim(3) = 40;
+
+ g.const_perm->dtype(loco::DataType::S32);
+ g.const_perm->rank(1);
+ g.const_perm->dim(0) = 4;
+ g.const_perm->size<loco::DataType::S32>(4);
+ g.const_perm->at<loco::DataType::S32>(0) = 2;
+ g.const_perm->at<loco::DataType::S32>(1) = 3;
+ g.const_perm->at<loco::DataType::S32>(2) = 0;
+ g.const_perm->at<loco::DataType::S32>(3) = 1;
+
+ // pre-check
+ ASSERT_FALSE(loco::shape_known(g.tfl_transpose));
+
+ exo::ShapeInferencePass pass;
+ while (pass.run(g.graph()) == true)
+ ;
+
+ // Verify
+ {
+ ASSERT_TRUE(loco::shape_known(g.tfl_transpose));
+
+ auto shape = loco::shape_get(g.tfl_transpose).as<loco::TensorShape>();
+ ASSERT_EQ(shape.rank(), 4);
+ ASSERT_EQ(shape.dim(0), 30);
+ ASSERT_EQ(shape.dim(1), 40);
+ ASSERT_EQ(shape.dim(2), 10);
+ ASSERT_EQ(shape.dim(3), 20);
+ }
+}
diff --git a/compiler/exo/src/Dialect/Service/TFLTypeInferenceRule.cpp b/compiler/exo/src/Dialect/Service/TFLTypeInferenceRule.cpp
new file mode 100644
index 000000000..3f123a6db
--- /dev/null
+++ b/compiler/exo/src/Dialect/Service/TFLTypeInferenceRule.cpp
@@ -0,0 +1,141 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "TFLTypeInferenceRule.h"
+
+#include "Dialect/IR/TFLDialect.h"
+#include "Dialect/IR/TFLNodeVisitor.h"
+#include "Dialect/IR/TFLNodes.h"
+
+#include <cassert>
+
+namespace
+{
+
+struct TypeInferenceAlgorithm final : public locoex::TFLNodeVisitor<loco::DataType>
+{
+ loco::DataType visit(const locoex::TFLAdd *node) final { return loco::dtype_get(node->x()); }
+
+ loco::DataType visit(const locoex::TFLAveragePool2D *node) final
+ {
+ return loco::dtype_get(node->value());
+ }
+
+ loco::DataType visit(const locoex::TFLConcatenation *node) final
+ {
+ // TODO Support the case when TFLConcatenation has 0 inputs
+ assert(node->numValues() > 0);
+
+ for (uint32_t i = 1; i < node->numValues(); ++i)
+ assert(loco::dtype_get(node->values(i - 1)) == loco::dtype_get(node->values(i)));
+
+ return loco::dtype_get(node->values(0));
+ }
+
+ loco::DataType visit(const locoex::TFLConst *node) final { return node->dtype(); }
+
+ loco::DataType visit(const locoex::TFLConv2D *node) final
+ {
+ return loco::dtype_get(node->input());
+ }
+
+ loco::DataType visit(const locoex::TFLDepthwiseConv2D *node) final
+ {
+ return loco::dtype_get(node->input());
+ }
+
+ loco::DataType visit(const locoex::TFLDiv *node) final { return loco::dtype_get(node->x()); }
+
+ loco::DataType visit(const locoex::TFLFullyConnected *node) final
+ {
+ return loco::dtype_get(node->input());
+ }
+
+ loco::DataType visit(const locoex::TFLMaximum *node) final { return loco::dtype_get(node->x()); }
+
+ loco::DataType visit(const locoex::TFLMaxPool2D *node) final
+ {
+ return loco::dtype_get(node->value());
+ }
+
+ loco::DataType visit(const locoex::TFLMean *node) final { return loco::dtype_get(node->input()); }
+
+ loco::DataType visit(const locoex::TFLMul *node) final { return loco::dtype_get(node->x()); }
+
+ loco::DataType visit(const locoex::TFLRelu *node) final
+ {
+ return loco::dtype_get(node->features());
+ }
+
+ loco::DataType visit(const locoex::TFLRelu6 *node) final
+ {
+ return loco::dtype_get(node->features());
+ }
+
+ loco::DataType visit(const locoex::TFLReshape *node) final
+ {
+ return loco::dtype_get(node->tensor());
+ }
+
+ loco::DataType visit(const locoex::TFLRsqrt *node) final { return loco::dtype_get(node->x()); }
+
+ // TODO TFLSoftmax
+
+ loco::DataType visit(const locoex::TFLSqrt *node) final { return loco::dtype_get(node->x()); }
+
+ loco::DataType visit(const locoex::TFLSquaredDifference *node) final
+ {
+ return loco::dtype_get(node->x());
+ }
+
+ loco::DataType visit(const locoex::TFLSub *node) final { return loco::dtype_get(node->x()); }
+
+ // TODO TFLTanh
+
+ loco::DataType visit(const locoex::TFLTranspose *node) final
+ {
+ return loco::dtype_get(node->a());
+ }
+
+ loco::DataType visit(const locoex::TFLTransposeConv *node) final
+ {
+ return loco::dtype_get(node->outBackprop());
+ }
+};
+
+} // namespace
+
+namespace locoex
+{
+
+bool TFLTypeInferenceRule::recognize(const loco::Dialect *d) const
+{
+ return TFLDialect::get() == d;
+}
+
+bool TFLTypeInferenceRule::infer(const loco::Node *node, loco::DataType &dtype) const
+{
+ assert(node->dialect() == TFLDialect::get());
+
+ TypeInferenceAlgorithm alg;
+
+ dtype = dynamic_cast<const TFLNode *>(node)->accept(&alg);
+ assert(dtype != loco::DataType::Unknown);
+
+ return true;
+}
+
+} // namespace locoex
diff --git a/compiler/exo/src/Dialect/Service/TFLTypeInferenceRule.h b/compiler/exo/src/Dialect/Service/TFLTypeInferenceRule.h
new file mode 100644
index 000000000..31765dcba
--- /dev/null
+++ b/compiler/exo/src/Dialect/Service/TFLTypeInferenceRule.h
@@ -0,0 +1,37 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LOCOEX_SERVICE_TFLTYPE_INFERENCE_RULE_H__
+#define __LOCOEX_SERVICE_TFLTYPE_INFERENCE_RULE_H__
+
+#include <loco/Service/TypeInference.h>
+
+namespace locoex
+{
+
+/**
+ * @brief Type Inference Rule for TFLDialect
+ */
+struct TFLTypeInferenceRule final : public loco::TypeInferenceRule
+{
+ bool recognize(const loco::Dialect *) const final;
+
+ bool infer(const loco::Node *, loco::DataType &) const final;
+};
+
+} // namespace locoex
+
+#endif // __LOCOEX_SERVICE_TFLTYPE_INFERENCE_RULE_H__
diff --git a/compiler/exo/src/Dialect/Service/TFLTypeInferenceRule.test.cpp b/compiler/exo/src/Dialect/Service/TFLTypeInferenceRule.test.cpp
new file mode 100644
index 000000000..dd1f93c4d
--- /dev/null
+++ b/compiler/exo/src/Dialect/Service/TFLTypeInferenceRule.test.cpp
@@ -0,0 +1,57 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Dialect/IR/TFLNodes.h"
+#include "Dialect/IR/TFLDialect.h"
+#include "Dialect/Service/TFLTypeInferenceRule.h"
+
+#include "TestGraph.h"
+
+#include <loco.h>
+#include <loco/IR/CanonicalDialect.h>
+#include <loco/Service/TypeInference.h>
+
+#include <stdex/Memory.h>
+
+#include <gtest/gtest.h>
+
+TEST(TFLTypeInferenceRuleTest, minimal_with_TFLRelu)
+{
+ // Create a simple network
+ exo::test::TestGraph graph;
+ auto tfl_node = graph.append<locoex::TFLRelu>(graph.pull);
+ graph.complete(tfl_node);
+
+ graph.pull->dtype(loco::DataType::S32);
+
+ // pre-check
+ ASSERT_FALSE(loco::dtype_known(tfl_node));
+
+ // type inference
+ locoex::TFLTypeInferenceRule tfl_rule;
+ loco::CanonicalTypeInferenceRule canon_rule;
+ loco::MultiDialectTypeInferenceRule rules;
+
+ rules.bind(loco::CanonicalDialect::get(), &canon_rule);
+ rules.bind(locoex::TFLDialect::get(), &tfl_rule);
+
+ loco::apply(&rules).to(graph.g.get());
+
+ // Verify
+ ASSERT_TRUE(loco::dtype_known(tfl_node));
+ auto type = loco::dtype_get(tfl_node);
+ ASSERT_EQ(type, loco::DataType::S32);
+}
diff --git a/compiler/exo/src/ExoFormattedGraph.cpp b/compiler/exo/src/ExoFormattedGraph.cpp
new file mode 100644
index 000000000..5d3b18be1
--- /dev/null
+++ b/compiler/exo/src/ExoFormattedGraph.cpp
@@ -0,0 +1,525 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ExoFormattedGraph.h"
+
+#include "Dialect/IR/TFLDialect.h"
+#include "Dialect/IR/TFLNodes.h"
+
+#include "Dialect/IR/CircleDialect.h"
+#include "Dialect/IR/CircleNodes.h"
+
+#include <locoex/Service/COpFormattedGraph.h>
+#include <pepper/str.h>
+
+#include <sstream>
+#include <cassert>
+
+// For TF lite
+namespace
+{
+
+const char *to_str(locoex::FusedActFunc fused)
+{
+ switch (fused)
+ {
+ case locoex::FusedActFunc::NONE:
+ return "NONE";
+ case locoex::FusedActFunc::RELU:
+ return "RELU";
+ case locoex::FusedActFunc::RELU6:
+ return "RELU6";
+ default:
+ return "Error";
+ }
+}
+
+const char *to_str(locoex::Padding padding)
+{
+ switch (padding)
+ {
+ case locoex::Padding::SAME:
+ return "SAME";
+ case locoex::Padding::VALID:
+ return "VALID";
+ default:
+ return "Error";
+ }
+}
+
+std::string to_str(const locoex::Stride *stride)
+{
+ return pepper::str(stride->h(), ",", stride->w());
+}
+
+std::string to_str(const locoex::Filter *filter)
+{
+ return pepper::str(filter->h(), ",", filter->w());
+}
+
+std::string tfl_opname(uint32_t opnum)
+{
+ static std::string prefix{"tfl."};
+
+ switch (static_cast<locoex::TFLOpcode>(opnum))
+ {
+#define TFL_NODE(OPCODE, CLASS) \
+ case locoex::TFLOpcode::OPCODE: \
+ return prefix + #OPCODE;
+#include "Dialect/IR/TFLNodes.lst"
+#undef TFL_NODE
+ default:
+ break;
+ };
+
+ return prefix + "Invalid";
+}
+
+// TFLNodeSummaryBuilder with default implementation
+class TFLNodeSummaryBuilderBase : public locop::NodeSummaryBuilder
+{
+public:
+ TFLNodeSummaryBuilderBase(const locop::SymbolTable *tbl) : _tbl{tbl}
+ {
+ // DO NOTHING
+ }
+
+public:
+ bool build(const loco::Node *, locop::NodeSummary &s) const final;
+
+protected:
+#define TFL_NODE(OPCODE, CLASS) \
+ virtual bool summary(const CLASS *, locop::NodeSummary &s) const \
+ { \
+ s.comments().append("Emitted by Default TFLNodeSummaryBuilder"); \
+ s.state(locop::NodeSummary::State::PartiallyKnown); \
+ return true; \
+ }
+#include "Dialect/IR/TFLNodes.lst"
+#undef TFL_NODE
+
+protected:
+ const locop::SymbolTable *tbl(void) const { return _tbl; }
+
+ // Please do not use _tbl directly; use tbl() instead.
+ // This will be changed to private in the near future.
+protected:
+ const locop::SymbolTable *_tbl;
+};
+
+class TFLNodeSummaryBuilder final : public TFLNodeSummaryBuilderBase
+{
+public:
+ TFLNodeSummaryBuilder(const locop::SymbolTable *tbl) : TFLNodeSummaryBuilderBase(tbl)
+ {
+ // DO NOTHING
+ }
+
+private:
+#define IMPLEMENT(CLASS) bool summary(const CLASS *, locop::NodeSummary &) const final;
+ IMPLEMENT(locoex::TFLAdd)
+ IMPLEMENT(locoex::TFLAveragePool2D)
+ IMPLEMENT(locoex::TFLConcatenation)
+ IMPLEMENT(locoex::TFLConst)
+ IMPLEMENT(locoex::TFLConv2D)
+ IMPLEMENT(locoex::TFLDepthwiseConv2D)
+ IMPLEMENT(locoex::TFLDiv)
+ IMPLEMENT(locoex::TFLMaximum)
+ IMPLEMENT(locoex::TFLMaxPool2D)
+ IMPLEMENT(locoex::TFLMean)
+ IMPLEMENT(locoex::TFLMul)
+ IMPLEMENT(locoex::TFLRelu)
+ IMPLEMENT(locoex::TFLRelu6)
+ IMPLEMENT(locoex::TFLReshape)
+ IMPLEMENT(locoex::TFLRsqrt)
+ IMPLEMENT(locoex::TFLSqrt)
+ IMPLEMENT(locoex::TFLSquaredDifference)
+ IMPLEMENT(locoex::TFLSub)
+ IMPLEMENT(locoex::TFLTranspose)
+ IMPLEMENT(locoex::TFLTransposeConv)
+#undef IMPLEMENT
+};
+
+bool TFLNodeSummaryBuilderBase::build(const loco::Node *node, locop::NodeSummary &s) const
+{
+ if (node->dialect() != locoex::TFLDialect::get())
+ return false;
+
+#define TFL_NODE(OPCODE, CLASS) \
+ if (dynamic_cast<const CLASS *>(node)) \
+ { \
+ s.opname(tfl_opname(node->opnum())); \
+ return summary(dynamic_cast<const CLASS *>(node), s); \
+ }
+#include "Dialect/IR/TFLNodes.lst"
+#undef TFL_NODE
+
+ return false;
+}
+
+bool TFLNodeSummaryBuilder::summary(const locoex::TFLAdd *node, locop::NodeSummary &s) const
+{
+ assert(node->fusedActivationFunction() != locoex::FusedActFunc::UNDEFINED);
+
+ s.args().append("x", tbl()->lookup(node->x()));
+ s.args().append("y", tbl()->lookup(node->y()));
+ s.args().append("fused_activation_function", to_str(node->fusedActivationFunction()));
+ s.state(locop::NodeSummary::State::Complete);
+ return true;
+}
+
+bool TFLNodeSummaryBuilder::summary(const locoex::TFLAveragePool2D *node,
+ locop::NodeSummary &s) const
+{
+ assert(node->fusedActivationFunction() != locoex::FusedActFunc::UNDEFINED);
+
+ s.args().append("value", tbl()->lookup(node->value()));
+ s.args().append("filter(h,w)", to_str(node->filter()));
+ s.args().append("stride(h,w)", to_str(node->stride()));
+ s.args().append("padding", to_str(node->padding()));
+ s.args().append("fused", to_str(node->fusedActivationFunction()));
+
+ s.state(locop::NodeSummary::State::Complete);
+
+ return true;
+}
+
+bool TFLNodeSummaryBuilder::summary(const locoex::TFLConcatenation *node,
+ locop::NodeSummary &s) const
+{
+ assert(node->fusedActivationFunction() != locoex::FusedActFunc::UNDEFINED);
+
+ for (uint32_t i = 0; i < node->numValues(); ++i)
+ s.args().append("values", tbl()->lookup(node->values(i)));
+ s.args().append("axis", pepper::str(node->axis()));
+ s.args().append("fused", to_str(node->fusedActivationFunction()));
+ s.state(locop::NodeSummary::State::Complete);
+ return true;
+}
+
+bool TFLNodeSummaryBuilder::summary(const locoex::TFLConst *, locop::NodeSummary &s) const
+{
+ s.state(locop::NodeSummary::State::PartiallyKnown);
+ return true;
+}
+
+bool TFLNodeSummaryBuilder::summary(const locoex::TFLConv2D *node, locop::NodeSummary &s) const
+{
+ assert(node->fusedActivationFunction() != locoex::FusedActFunc::UNDEFINED);
+ assert(node->padding() != locoex::Padding::UNDEFINED);
+
+ s.args().append("input", tbl()->lookup(node->input()));
+ s.args().append("filter", tbl()->lookup(node->filter()));
+ s.args().append("bias", tbl()->lookup(node->bias()));
+
+ s.args().append("stride(h,w)", to_str(node->stride()));
+ s.args().append("padding", to_str(node->padding()));
+ s.args().append("fused", to_str(node->fusedActivationFunction()));
+
+ s.state(locop::NodeSummary::State::Complete);
+
+ return true;
+}
+
+bool TFLNodeSummaryBuilder::summary(const locoex::TFLDepthwiseConv2D *node,
+ locop::NodeSummary &s) const
+{
+ assert(node->fusedActivationFunction() != locoex::FusedActFunc::UNDEFINED);
+ assert(node->padding() != locoex::Padding::UNDEFINED);
+
+ s.args().append("input", tbl()->lookup(node->input()));
+ s.args().append("filter", tbl()->lookup(node->filter()));
+ s.args().append("bias", tbl()->lookup(node->bias()));
+
+ s.args().append("stride(h,w)", to_str(node->stride()));
+ s.args().append("padding", to_str(node->padding()));
+ s.args().append("depthMultiplier", std::to_string(node->depthMultiplier()));
+ s.args().append("fused", to_str(node->fusedActivationFunction()));
+
+ s.state(locop::NodeSummary::State::Complete);
+
+ return true;
+}
+
+bool TFLNodeSummaryBuilder::summary(const locoex::TFLDiv *node, locop::NodeSummary &s) const
+{
+ s.args().append("x", tbl()->lookup(node->x()));
+ s.args().append("y", tbl()->lookup(node->y()));
+ s.state(locop::NodeSummary::State::Complete);
+ return true;
+}
+
+bool TFLNodeSummaryBuilder::summary(const locoex::TFLMaximum *node, locop::NodeSummary &s) const
+{
+ s.args().append("x", tbl()->lookup(node->x()));
+ s.args().append("y", tbl()->lookup(node->y()));
+ s.state(locop::NodeSummary::State::Complete);
+ return true;
+}
+
+bool TFLNodeSummaryBuilder::summary(const locoex::TFLMaxPool2D *node, locop::NodeSummary &s) const
+{
+ assert(node->fusedActivationFunction() != locoex::FusedActFunc::UNDEFINED);
+
+ s.args().append("value", tbl()->lookup(node->value()));
+ s.args().append("filter(h,w)", to_str(node->filter()));
+ s.args().append("stride(h,w)", to_str(node->stride()));
+ s.args().append("padding", to_str(node->padding()));
+ s.args().append("fused", to_str(node->fusedActivationFunction()));
+
+ s.state(locop::NodeSummary::State::Complete);
+
+ return true;
+}
+
+bool TFLNodeSummaryBuilder::summary(const locoex::TFLMean *node, locop::NodeSummary &s) const
+{
+ s.args().append("input", tbl()->lookup(node->input()));
+ s.args().append("reduction_indices", tbl()->lookup(node->reduction_indices()));
+ s.args().append("keep_dims", node->keep_dims() ? "true" : "false");
+ s.state(locop::NodeSummary::State::Complete);
+ return true;
+}
+
+bool TFLNodeSummaryBuilder::summary(const locoex::TFLMul *node, locop::NodeSummary &s) const
+{
+ assert(node->fusedActivationFunction() != locoex::FusedActFunc::UNDEFINED);
+
+ s.args().append("x", tbl()->lookup(node->x()));
+ s.args().append("y", tbl()->lookup(node->y()));
+ s.args().append("fused_activation_function", to_str(node->fusedActivationFunction()));
+ s.state(locop::NodeSummary::State::Complete);
+ return true;
+}
+
+bool TFLNodeSummaryBuilder::summary(const locoex::TFLRelu *node, locop::NodeSummary &s) const
+{
+ s.args().append("features", tbl()->lookup(node->features()));
+ s.state(locop::NodeSummary::State::Complete);
+ return true;
+}
+
+bool TFLNodeSummaryBuilder::summary(const locoex::TFLRelu6 *node, locop::NodeSummary &s) const
+{
+ s.args().append("features", tbl()->lookup(node->features()));
+ s.state(locop::NodeSummary::State::Complete);
+ return true;
+}
+
+bool TFLNodeSummaryBuilder::summary(const locoex::TFLReshape *node, locop::NodeSummary &s) const
+{
+ s.args().append("tensor", tbl()->lookup(node->tensor()));
+ s.args().append("shape", tbl()->lookup(node->shape()));
+ // TODO Show newShape info
+ s.state(locop::NodeSummary::State::PartiallyKnown);
+ return true;
+}
+
+bool TFLNodeSummaryBuilder::summary(const locoex::TFLRsqrt *node, locop::NodeSummary &s) const
+{
+ s.args().append("x", tbl()->lookup(node->x()));
+ s.state(locop::NodeSummary::State::Complete);
+ return true;
+}
+
+// TODO TFLSoftmax
+
+bool TFLNodeSummaryBuilder::summary(const locoex::TFLSqrt *node, locop::NodeSummary &s) const
+{
+ s.args().append("x", tbl()->lookup(node->x()));
+ s.state(locop::NodeSummary::State::Complete);
+ return true;
+}
+
+bool TFLNodeSummaryBuilder::summary(const locoex::TFLSquaredDifference *node,
+ locop::NodeSummary &s) const
+{
+ s.args().append("x", tbl()->lookup(node->x()));
+ s.args().append("y", tbl()->lookup(node->y()));
+ s.state(locop::NodeSummary::State::Complete);
+ return true;
+}
+
+bool TFLNodeSummaryBuilder::summary(const locoex::TFLSub *node, locop::NodeSummary &s) const
+{
+ s.args().append("x", tbl()->lookup(node->x()));
+ s.args().append("y", tbl()->lookup(node->y()));
+ s.state(locop::NodeSummary::State::Complete);
+ return true;
+}
+
+// TODO TFLTanh
+
+bool TFLNodeSummaryBuilder::summary(const locoex::TFLTranspose *node, locop::NodeSummary &s) const
+{
+ s.args().append("a", tbl()->lookup(node->a()));
+ s.args().append("perm", tbl()->lookup(node->perm()));
+ s.state(locop::NodeSummary::State::Complete);
+ return true;
+}
+
+bool TFLNodeSummaryBuilder::summary(const locoex::TFLTransposeConv *node,
+ locop::NodeSummary &s) const
+{
+ assert(node->padding() != locoex::Padding::UNDEFINED);
+
+ s.args().append("inputSizes", tbl()->lookup(node->inputSizes()));
+ s.args().append("filter", tbl()->lookup(node->filter()));
+ s.args().append("outBackprop", tbl()->lookup(node->outBackprop()));
+
+ s.args().append("stride(h,w)", to_str(node->stride()));
+ s.args().append("padding", to_str(node->padding()));
+
+ s.state(locop::NodeSummary::State::Complete);
+
+ return true;
+}
+
+} // namespace
+
+// For Circle
+namespace
+{
+
+std::string circle_opname(uint32_t opnum)
+{
+ static std::string prefix{"circle."};
+
+ switch (static_cast<locoex::CircleOpcode>(opnum))
+ {
+#define CIRCLE_NODE(OPCODE, CLASS) \
+ case locoex::CircleOpcode::OPCODE: \
+ return prefix + #OPCODE;
+#include "Dialect/IR/CircleNodes.lst"
+#undef CIRCLE_NODE
+ default:
+ break;
+ };
+
+ return prefix + "Invalid";
+}
+
+// CircleNodeSummaryBuilder with default implementation
+class CircleNodeSummaryBuilderBase : public locop::NodeSummaryBuilder
+{
+public:
+ CircleNodeSummaryBuilderBase(const locop::SymbolTable *tbl) : _tbl{tbl}
+ {
+ // DO NOTHING
+ }
+
+public:
+ bool build(const loco::Node *, locop::NodeSummary &s) const final;
+
+protected:
+#define CIRCLE_NODE(OPCODE, CLASS) \
+ virtual bool summary(const CLASS *, locop::NodeSummary &s) const \
+ { \
+ s.comments().append("Emitted by Default CircleNodeSummaryBuilder"); \
+ s.state(locop::NodeSummary::State::PartiallyKnown); \
+ return true; \
+ }
+#include "Dialect/IR/CircleNodes.lst"
+#undef CIRCLE_NODE
+
+protected:
+ const locop::SymbolTable *tbl(void) const { return _tbl; }
+
+ // Please do not use _tbl directly; use tbl() instead.
+ // This will be changed to private in the near future.
+protected:
+ const locop::SymbolTable *_tbl;
+};
+
+class CircleNodeSummaryBuilder final : public CircleNodeSummaryBuilderBase
+{
+public:
+ CircleNodeSummaryBuilder(const locop::SymbolTable *tbl) : CircleNodeSummaryBuilderBase(tbl)
+ {
+ // DO NOTHING
+ }
+
+private:
+#define IMPLEMENT(CLASS) bool summary(const CLASS *, locop::NodeSummary &) const final;
+ IMPLEMENT(locoex::CircleInstanceNorm)
+#undef IMPLEMENT
+};
+
+bool CircleNodeSummaryBuilderBase::build(const loco::Node *node, locop::NodeSummary &s) const
+{
+ if (node->dialect() != locoex::CircleDialect::get())
+ return false;
+
+#define CIRCLE_NODE(OPCODE, CLASS) \
+ if (dynamic_cast<const CLASS *>(node)) \
+ { \
+ s.opname(circle_opname(node->opnum())); \
+ return summary(dynamic_cast<const CLASS *>(node), s); \
+ }
+#include "Dialect/IR/CircleNodes.lst"
+#undef CIRCLE_NODE
+
+ return false;
+}
+
+bool CircleNodeSummaryBuilder::summary(const locoex::CircleInstanceNorm *node,
+ locop::NodeSummary &s) const
+{
+ auto fused = node->fusedActivationFunction();
+ assert(fused != locoex::FusedActFunc::UNDEFINED);
+
+ s.args().append("input", tbl()->lookup(node->input()));
+ s.args().append("gamma", tbl()->lookup(node->gamma()));
+ s.args().append("beta", tbl()->lookup(node->beta()));
+ s.args().append("epsilon", pepper::str(node->epsilon()));
+ s.args().append("fused_activation_function", to_str(fused));
+
+ s.state(locop::NodeSummary::State::Complete);
+
+ return true;
+}
+
+} // namespace
+
+namespace exo
+{
+
+bool NodeSummaryBuilder::build(const loco::Node *node, locop::NodeSummary &s) const
+{
+ if (locop::CanonicalNodeSummaryBuilder(_tbl).build(node, s))
+ {
+ return true;
+ }
+
+ if (TFLNodeSummaryBuilder(_tbl).build(node, s))
+ {
+ return true;
+ }
+
+ if (CircleNodeSummaryBuilder(_tbl).build(node, s))
+ {
+ return true;
+ }
+
+ if (locoex::COpNodeSummaryBuilder(_tbl).build(node, s))
+ {
+ return true;
+ }
+
+ return false;
+}
+
+} // namespace exo
diff --git a/compiler/exo/src/ExoFormattedGraph.h b/compiler/exo/src/ExoFormattedGraph.h
new file mode 100644
index 000000000..714e483b5
--- /dev/null
+++ b/compiler/exo/src/ExoFormattedGraph.h
@@ -0,0 +1,56 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __EXO_FORMATTED_GRAPH_H__
+#define __EXO_FORMATTED_GRAPH_H__
+
+#include <locop/FormattedGraph.h>
+
+#include <stdex/Memory.h>
+
+namespace exo
+{
+
+class NodeSummaryBuilder final : public locop::NodeSummaryBuilder
+{
+public:
+ NodeSummaryBuilder(const locop::SymbolTable *tbl) : _tbl{tbl}
+ {
+ // DO NOTHING
+ }
+
+public:
+ bool build(const loco::Node *node, locop::NodeSummary &s) const final;
+
+private:
+ const locop::SymbolTable *_tbl;
+};
+
+class NodeSummaryBuilderFactory final : public locop::NodeSummaryBuilderFactory
+{
+public:
+ NodeSummaryBuilderFactory() = default;
+
+public:
+ std::unique_ptr<locop::NodeSummaryBuilder> create(const locop::SymbolTable *tbl) const final
+ {
+ return stdex::make_unique<NodeSummaryBuilder>(tbl);
+ }
+};
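+
+// Usage sketch (illustrative only): the factory is typically handed to locop's
+// graph formatter so that graph dumps use exo's node summaries, e.g.
+//
+//   std::cout << locop::fmt<locop::LinearV1>(g)
+//                    .with(stdex::make_unique<exo::NodeSummaryBuilderFactory>());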
+
+} // namespace exo
+
+#endif // __EXO_FORMATTED_GRAPH_H__
diff --git a/compiler/exo/src/ExoOptimize.cpp b/compiler/exo/src/ExoOptimize.cpp
new file mode 100644
index 000000000..d7278e900
--- /dev/null
+++ b/compiler/exo/src/ExoOptimize.cpp
@@ -0,0 +1,74 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ExoOptimize.h"
+
+#include "Knob.h"
+#include "Passes.h"
+#include "ProgressReporter.h"
+
+#include <logo/Phase.h>
+
+#include <stdex/Memory.h>
+
+namespace exo
+{
+
+void optimize(loco::Graph *g)
+{
+ logo::Phase phase;
+ {
+ // prepare type and shape before optimization
+ phase.emplace_back(stdex::make_unique<TypeInferencePass>());
+ phase.emplace_back(stdex::make_unique<ShapeInferencePass>());
+
+ phase.emplace_back(stdex::make_unique<FoldReshapeOfConstPass>());
+ phase.emplace_back(stdex::make_unique<FoldTransposeOfConstPass>());
+
+ if (get<Knob::UseFuseBiasAddPass>())
+ {
+ phase.emplace_back(stdex::make_unique<FuseBiasAddPass>());
+ }
+
+ if (get<Knob::UseFuseInstanceNormPass>())
+ {
+ phase.emplace_back(stdex::make_unique<FuseInstanceNormPass>());
+ }
+
+ if (get<Knob::UseFuseReluPass>())
+ {
+ phase.emplace_back(stdex::make_unique<FuseReluPass>());
+ }
+ phase.emplace_back(stdex::make_unique<FuseRsqrtPass>());
+
+ if (get<Knob::UseFuseSquaredDifferencePass>())
+ {
+ phase.emplace_back(stdex::make_unique<FuseSquaredDifferencePass>());
+ }
+
+ phase.emplace_back(stdex::make_unique<MergeConcatNodesPass>());
+
+ phase.emplace_back(stdex::make_unique<logo::RemoveDeadNodePass>());
+ }
+
+ logo::PhaseRunner<logo::PhaseStrategy::Restart> phase_runner{g};
+
+ ProgressReporter prog(g, logo::PhaseStrategy::Restart);
+ phase_runner.attach(&prog);
+ phase_runner.run(phase);
+}
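+
+// Usage sketch (illustrative, hypothetical call site): callers are expected to
+// select a dialect first so that knob defaults are in place, e.g.
+//
+//   exo::set(exo::Dialect::TFLITE); // see Knob.h
+//   exo::optimize(graph);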
+
+} // namespace exo
diff --git a/compiler/exo/src/ExoOptimize.h b/compiler/exo/src/ExoOptimize.h
new file mode 100644
index 000000000..4769c1193
--- /dev/null
+++ b/compiler/exo/src/ExoOptimize.h
@@ -0,0 +1,34 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __OPTIMIZE_H__
+#define __OPTIMIZE_H__
+
+#include <loco.h>
+
+namespace exo
+{
+
+/**
+ * @brief Run optimization passes over a graph after canonical nodes have been converted into TFL nodes.
+ *
+ * TODO Separate optimize passes dedicated to the TFL and Circle dialects when necessary
+ */
+void optimize(loco::Graph *);
+
+} // namespace exo
+
+#endif // __OPTIMIZE_H__
diff --git a/compiler/exo/src/ExporterUtils.cpp b/compiler/exo/src/ExporterUtils.cpp
new file mode 100644
index 000000000..41ccdcd71
--- /dev/null
+++ b/compiler/exo/src/ExporterUtils.cpp
@@ -0,0 +1,139 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ExporterUtils.h"
+
+#include <oops/InternalExn.h>
+
+#include <cassert>
+
+namespace exo
+{
+
+ShapeDescription to_shape_description(const loco::TensorShape &shape)
+{
+ ShapeDescription res;
+
+ res._rank_known = true;
+
+ res._dims.resize(shape.rank());
+ for (uint32_t axis = 0; axis < shape.rank(); ++axis)
+ {
+ // All the dimensions SHOULD be known
+ assert(shape.dim(axis).known());
+ res._dims.at(axis) = shape.dim(axis).value();
+ }
+
+ return res;
+}
+
+ShapeDescription to_shape_description(const loco::FeatureShape &shape)
+{
+ ShapeDescription res;
+
+ res._rank_known = true;
+
+ // T/F Lite encodes a feature map as a NHWC tensor
+ res._dims.resize(4);
+ res._dims.at(0) = shape.count().value();
+ res._dims.at(1) = shape.height().value();
+ res._dims.at(2) = shape.width().value();
+ res._dims.at(3) = shape.depth().value();
+
+ return res;
+}
+
+ShapeDescription to_shape_description(const loco::FilterShape &shape)
+{
+ ShapeDescription res;
+
+ res._rank_known = true;
+
+ // T/F Lite encodes a convolution filter as a NHWC tensor
+ res._dims.resize(4);
+ res._dims.at(0) = shape.count().value();
+ res._dims.at(1) = shape.height().value();
+ res._dims.at(2) = shape.width().value();
+ res._dims.at(3) = shape.depth().value();
+
+ return res;
+}
+
+ShapeDescription to_shape_description(const loco::DepthwiseFilterShape &shape)
+{
+ ShapeDescription res;
+
+ res._rank_known = true;
+
+ // T/F Lite encodes a depthwise convolution filter as a [1, H, W, C*M] tensor
+ res._dims.resize(4);
+ res._dims.at(0) = 1;
+ res._dims.at(1) = shape.height().value();
+ res._dims.at(2) = shape.width().value();
+ res._dims.at(3) = shape.depth().value() * shape.multiplier().value();
+
+ return res;
+}
+
+ShapeDescription to_shape_description(const loco::BiasShape &shape)
+{
+ ShapeDescription res;
+
+ res._rank_known = true;
+
+ res._dims.resize(1);
+ res._dims.at(0) = shape.length().value();
+
+ return res;
+}
+
+ShapeDescription to_shape_description(const loco::MatrixShape &shape)
+{
+ ShapeDescription res;
+
+ res._rank_known = true;
+
+ res._dims.resize(2);
+ res._dims.at(0) = shape.height().value();
+ res._dims.at(1) = shape.width().value();
+
+ return res;
+}
+
+ShapeDescription to_shape_description(const loco::NodeShape &shape)
+{
+ switch (shape.domain())
+ {
+ case loco::Domain::Tensor:
+ return to_shape_description(shape.as<loco::TensorShape>());
+ case loco::Domain::Feature:
+ return to_shape_description(shape.as<loco::FeatureShape>());
+ case loco::Domain::Filter:
+ return to_shape_description(shape.as<loco::FilterShape>());
+ case loco::Domain::DepthwiseFilter:
+ return to_shape_description(shape.as<loco::DepthwiseFilterShape>());
+ case loco::Domain::Bias:
+ return to_shape_description(shape.as<loco::BiasShape>());
+ case loco::Domain::Matrix:
+ return to_shape_description(shape.as<loco::MatrixShape>());
+ default:
+ break;
+ }
+
+ INTERNAL_EXN_V("Unsupported loco domain", oops::to_uint32(shape.domain()));
+}
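+
+// Example (illustrative): a loco::FeatureShape with count=1, height=4, width=3
+// and depth=2 is encoded as the NHWC tensor shape {1, 4, 3, 2}:
+//
+//   loco::FeatureShape fs;
+//   fs.count() = 1; fs.height() = 4; fs.width() = 3; fs.depth() = 2;
+//   auto desc = to_shape_description(fs); // desc._dims == {1, 4, 3, 2}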
+
+} // namespace exo
diff --git a/compiler/exo/src/ExporterUtils.h b/compiler/exo/src/ExporterUtils.h
new file mode 100644
index 000000000..e1f1f66a8
--- /dev/null
+++ b/compiler/exo/src/ExporterUtils.h
@@ -0,0 +1,57 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __EXPORTER_UTILS_H__
+#define __EXPORTER_UTILS_H__
+
+#include "loco.h"
+
+#include "loco/IR/PermutingCodec.h"
+#include "loco/IR/NodeShape.h"
+
+namespace exo
+{
+
+struct ShapeDescription
+{
+ std::vector<int32_t> _dims;
+ bool _rank_known;
+};
+
+ShapeDescription to_shape_description(const loco::TensorShape &shape);
+ShapeDescription to_shape_description(const loco::FeatureShape &shape);
+ShapeDescription to_shape_description(const loco::FilterShape &shape);
+ShapeDescription to_shape_description(const loco::DepthwiseFilterShape &shape);
+ShapeDescription to_shape_description(const loco::BiasShape &shape);
+ShapeDescription to_shape_description(const loco::MatrixShape &shape);
+ShapeDescription to_shape_description(const loco::NodeShape &shape);
+
+template <typename Permutation> inline bool isNHWC(Permutation *perm);
+
+template <> inline bool isNHWC(loco::Permutation<loco::Domain::Feature> *perm)
+{
+ return perm->axis(loco::FeatureAxis::Count) == 0 && perm->axis(loco::FeatureAxis::Height) == 1 &&
+ perm->axis(loco::FeatureAxis::Width) == 2 && perm->axis(loco::FeatureAxis::Depth) == 3;
+}
+
+template <> inline bool isNHWC(loco::Permutation<loco::Domain::Filter> *perm)
+{
+ return perm->axis(loco::FilterAxis::Count) == 0 && perm->axis(loco::FilterAxis::Height) == 1 &&
+ perm->axis(loco::FilterAxis::Width) == 2 && perm->axis(loco::FilterAxis::Depth) == 3;
+}
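+
+// Example (illustrative): checking whether a permuting feature encoder already
+// uses the NHWC layout:
+//
+//   auto enc = dynamic_cast<loco::PermutingEncoder<loco::Domain::Feature> *>(encoder);
+//   if (enc != nullptr && isNHWC(enc->perm()))
+//   {
+//     // the encode is a no-op in an NHWC-tensor world
+//   }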
+
+} // namespace exo
+
+#endif // __EXPORTER_UTILS_H__
diff --git a/compiler/exo/src/GraphBlock.cpp b/compiler/exo/src/GraphBlock.cpp
new file mode 100644
index 000000000..0a45ce8ad
--- /dev/null
+++ b/compiler/exo/src/GraphBlock.cpp
@@ -0,0 +1,243 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "GraphBlock.h"
+
+#include "Check.h"
+
+#include <loco.h>
+#include <stdex/Memory.h>
+
+namespace
+{
+
+template <exo::FeatureLayout T> loco::Permutation<loco::Domain::Feature> perm();
+
+template <> loco::Permutation<loco::Domain::Feature> perm<exo::FeatureLayout::NHWC>()
+{
+ // Make NHWC permutation for encoder and decoder
+ loco::Permutation<loco::Domain::Feature> NHWC;
+
+ NHWC.axis(loco::FeatureAxis::Count) = 0;
+ NHWC.axis(loco::FeatureAxis::Height) = 1;
+ NHWC.axis(loco::FeatureAxis::Width) = 2;
+ NHWC.axis(loco::FeatureAxis::Depth) = 3;
+
+ return NHWC;
+}
+
+template <exo::FilterLayout T> loco::Permutation<loco::Domain::Filter> perm();
+
+template <> loco::Permutation<loco::Domain::Filter> perm<exo::FilterLayout::HWIO>()
+{
+ loco::Permutation<loco::Domain::Filter> HWIO; // a.k.a., HWCN
+
+ HWIO.axis(loco::FilterAxis::Height) = 0;
+ HWIO.axis(loco::FilterAxis::Width) = 1;
+ HWIO.axis(loco::FilterAxis::Depth) = 2;
+ HWIO.axis(loco::FilterAxis::Count) = 3;
+
+ return HWIO;
+}
+
+template <> loco::Permutation<loco::Domain::Filter> perm<exo::FilterLayout::OHWI>()
+{
+ // Make OHWI (a.k.a. NHWC) permutation for encoder and decoder
+ loco::Permutation<loco::Domain::Filter> OHWI; // a.k.a., NHWC
+
+ OHWI.axis(loco::FilterAxis::Count) = 0;
+ OHWI.axis(loco::FilterAxis::Height) = 1;
+ OHWI.axis(loco::FilterAxis::Width) = 2;
+ OHWI.axis(loco::FilterAxis::Depth) = 3;
+
+ return OHWI;
+}
+
+template <exo::DepthwiseFilterLayout T> loco::Permutation<loco::Domain::DepthwiseFilter> perm();
+
+template <>
+loco::Permutation<loco::Domain::DepthwiseFilter> perm<exo::DepthwiseFilterLayout::HWCM>()
+{
+ loco::Permutation<loco::Domain::DepthwiseFilter> HWCM;
+
+ HWCM.axis(loco::DepthwiseFilterAxis::Height) = 0;
+ HWCM.axis(loco::DepthwiseFilterAxis::Width) = 1;
+ HWCM.axis(loco::DepthwiseFilterAxis::Depth) = 2;
+ HWCM.axis(loco::DepthwiseFilterAxis::Multiplier) = 3;
+
+ return HWCM;
+}
+
+template <exo::MatrixLayout T> loco::Permutation<loco::Domain::Matrix> perm();
+
+template <> loco::Permutation<loco::Domain::Matrix> perm<exo::MatrixLayout::HW>()
+{
+ loco::Permutation<loco::Domain::Matrix> HW;
+
+ HW.axis(loco::MatrixAxis::Height) = 0;
+ HW.axis(loco::MatrixAxis::Width) = 1;
+
+ return HW;
+}
+
+template <> loco::Permutation<loco::Domain::Matrix> perm<exo::MatrixLayout::WH>()
+{
+ loco::Permutation<loco::Domain::Matrix> WH;
+
+ WH.axis(loco::MatrixAxis::Height) = 1;
+ WH.axis(loco::MatrixAxis::Width) = 0;
+
+ return WH;
+}
+
+} // namespace
+
+namespace exo
+{
+
+template <FeatureLayout T> loco::FeatureEncode *make_feature_encode(loco::Node *input_for_encode)
+{
+ EXO_ASSERT(input_for_encode != nullptr, "input should not be nullptr");
+ loco::Graph *g = input_for_encode->graph();
+
+ auto encoder = stdex::make_unique<loco::PermutingEncoder<loco::Domain::Feature>>();
+
+ encoder->perm(perm<T>());
+
+ auto enc = g->nodes()->create<loco::FeatureEncode>();
+ enc->input(input_for_encode);
+ enc->encoder(std::move(encoder));
+
+ return enc;
+}
+
+template <FeatureLayout T> loco::FeatureDecode *make_feature_decode(loco::Node *input_for_decode)
+{
+ EXO_ASSERT(input_for_decode != nullptr, "input should not be nullptr");
+ loco::Graph *g = input_for_decode->graph();
+
+ auto decoder = stdex::make_unique<loco::PermutingDecoder<loco::Domain::Feature>>();
+
+ decoder->perm(perm<T>());
+
+ auto dec = g->nodes()->create<loco::FeatureDecode>();
+ dec->input(input_for_decode);
+ dec->decoder(std::move(decoder));
+
+ return dec;
+}
+
+template <FilterLayout T> loco::FilterEncode *make_filter_encode(loco::Node *input_for_encode)
+{
+ EXO_ASSERT(input_for_encode != nullptr, "filter should not be nullptr");
+ loco::Graph *g = input_for_encode->graph();
+
+ auto encoder = stdex::make_unique<loco::PermutingEncoder<loco::Domain::Filter>>();
+
+ encoder->perm(perm<T>());
+
+ auto enc = g->nodes()->create<loco::FilterEncode>();
+ enc->input(input_for_encode);
+ enc->encoder(std::move(encoder));
+
+ return enc;
+}
+
+template <FilterLayout T> loco::FilterDecode *make_filter_decode(loco::Node *input_for_decode)
+{
+ EXO_ASSERT(input_for_decode != nullptr, "filter should not be nullptr");
+ loco::Graph *g = input_for_decode->graph();
+
+ auto decoder = stdex::make_unique<loco::PermutingDecoder<loco::Domain::Filter>>();
+
+ decoder->perm(perm<T>());
+
+ auto dec = g->nodes()->create<loco::FilterDecode>();
+ dec->input(input_for_decode);
+ dec->decoder(std::move(decoder));
+
+ return dec;
+}
+
+template <DepthwiseFilterLayout T>
+loco::DepthwiseFilterDecode *make_dw_filter_decode(loco::Node *input_for_decode)
+{
+ EXO_ASSERT(input_for_decode != nullptr, "filter should not be nullptr");
+ loco::Graph *g = input_for_decode->graph();
+
+ auto decoder = stdex::make_unique<loco::PermutingDecoder<loco::Domain::DepthwiseFilter>>();
+
+ decoder->perm(perm<T>());
+
+ auto dec = g->nodes()->create<loco::DepthwiseFilterDecode>();
+ dec->input(input_for_decode);
+ dec->decoder(std::move(decoder));
+
+ return dec;
+}
+
+template <MatrixLayout T> loco::MatrixEncode *make_matrix_encode(loco::Node *input_for_encode)
+{
+ EXO_ASSERT(input_for_encode != nullptr, "input should not be nullptr");
+ loco::Graph *g = input_for_encode->graph();
+
+ auto encoder = stdex::make_unique<loco::PermutingEncoder<loco::Domain::Matrix>>();
+
+ encoder->perm(perm<T>());
+
+ auto enc = g->nodes()->create<loco::MatrixEncode>();
+ enc->input(input_for_encode);
+ enc->encoder(std::move(encoder));
+
+ return enc;
+}
+
+template <MatrixLayout T> loco::MatrixDecode *make_matrix_decode(loco::Node *input_for_decode)
+{
+ EXO_ASSERT(input_for_decode != nullptr, "input should not be nullptr");
+ loco::Graph *g = input_for_decode->graph();
+
+ auto decoder = stdex::make_unique<loco::PermutingDecoder<loco::Domain::Matrix>>();
+
+ decoder->perm(perm<T>());
+
+ auto dec = g->nodes()->create<loco::MatrixDecode>();
+ dec->input(input_for_decode);
+ dec->decoder(std::move(decoder));
+
+ return dec;
+}
+
+// template instantiation
+template loco::FeatureEncode *
+make_feature_encode<FeatureLayout::NHWC>(loco::Node *input_for_encode);
+
+template loco::FeatureDecode *
+make_feature_decode<FeatureLayout::NHWC>(loco::Node *input_for_encode);
+
+template loco::FilterEncode *make_filter_encode<FilterLayout::HWIO>(loco::Node *input_for_encode);
+template loco::FilterDecode *make_filter_decode<FilterLayout::OHWI>(loco::Node *input_for_decode);
+
+template loco::DepthwiseFilterDecode *
+make_dw_filter_decode<DepthwiseFilterLayout::HWCM>(loco::Node *input_for_decode);
+
+template loco::MatrixEncode *make_matrix_encode<MatrixLayout::HW>(loco::Node *input_for_encode);
+template loco::MatrixEncode *make_matrix_encode<MatrixLayout::WH>(loco::Node *input_for_encode);
+template loco::MatrixDecode *make_matrix_decode<MatrixLayout::HW>(loco::Node *input_for_decode);
+template loco::MatrixDecode *make_matrix_decode<MatrixLayout::WH>(loco::Node *input_for_decode);
+
+} // namespace exo
diff --git a/compiler/exo/src/GraphBlock.h b/compiler/exo/src/GraphBlock.h
new file mode 100644
index 000000000..b771c821b
--- /dev/null
+++ b/compiler/exo/src/GraphBlock.h
@@ -0,0 +1,199 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __GRAPH_BLOCK_H__
+#define __GRAPH_BLOCK_H__
+
+#include <loco.h>
+#include <loco/Service/ShapeInference.h>
+
+#include <oops/InternalExn.h>
+
+#include <functional>
+
+namespace exo
+{
+
+/// @brief feature layout of TFLITE file
+enum class FeatureLayout
+{
+ NHWC,
+};
+
+/// @brief Creates a loco::FeatureEncode with T layout (NHWC for tflite) and adds it to the graph.
+template <FeatureLayout T> loco::FeatureEncode *make_feature_encode(loco::Node *input_for_encode);
+
+/// @brief Creates a loco::FeatureDecode with T layout (NHWC for tflite) and adds it to the graph.
+template <FeatureLayout T> loco::FeatureDecode *make_feature_decode(loco::Node *input_for_decode);
+
+enum class FilterLayout
+{
+ OHWI, // a.k.a., NHWC; TensorFlow Lite uses this layout for filters
+ HWIO, // a.k.a., HWCN; TensorFlow uses this layout for filters
+};
+
+/// @brief Create a loco::FilterEncode of given layout
+template <FilterLayout T> loco::FilterEncode *make_filter_encode(loco::Node *input_for_encode);
+
+/// @brief Create a loco::FilterDecode of given layout
+template <FilterLayout T> loco::FilterDecode *make_filter_decode(loco::Node *input_for_decode);
+
+enum class DepthwiseFilterLayout
+{
+ HWCM,
+};
+
+/// @brief Create a loco::DepthwiseFilterDecode of given layout
+template <DepthwiseFilterLayout T>
+loco::DepthwiseFilterDecode *make_dw_filter_decode(loco::Node *input_for_decode);
+
+enum class MatrixLayout
+{
+ HW,
+ WH
+};
+
+/// @brief Create a loco::MatrixEncode of given layout
+template <MatrixLayout T> loco::MatrixEncode *make_matrix_encode(loco::Node *input_for_encode);
+
+/// @brief Create a loco::MatrixDecode of given layout
+template <MatrixLayout T> loco::MatrixDecode *make_matrix_decode(loco::Node *input_for_decode);
+
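+// Usage sketch (illustrative): wrapping a node with an NHWC decode/encode pair:
+//
+//   auto dec = make_feature_decode<FeatureLayout::NHWC>(some_node);
+//   auto enc = make_feature_encode<FeatureLayout::NHWC>(dec);
+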
+} // namespace exo
+
+//
+// DomainConverter
+//
+
+/**
+ * Some canonical nodes can have inputs of various loco::Domain, e.g., loco::Domain::Tensor,
+ * loco::Domain::Feature, etc. However, a TFL node accepts only loco::Domain::Tensor.
+ * So, when converting such a canonical node to a TFL node, if the input(s) of the canonical
+ * node are not loco::Domain::Tensor, additional nodes need to be inserted.
+ *
+ * The following two classes help with this insertion.
+ *
+ * For example, in case of loco::Relu conversion,
+ *
+ * Before:
+ *
+ * A (output: feature) -- loco::ReLU --- B (input:feature)
+ *
+ * After:
+ *
+ * A -- loco::FeatureDecode -- locoex::TFLRelu -- loco::FeatureEncode --- B
+ *
+ * loco::ReLU (dead node)
+ */
+
+namespace exo
+{
+
+/**
+ * @brief Handles input(s) while converting a canonical node to TFL node(s).
+ * This class informs DomainConverter how to handle inputs of a specific canonical node.
+ */
+template <class CanonicalT, class TFLT> class InputHandler
+{
+public:
+ /**
+ * @brief Assign origin's inputs to replacer's inputs.
+ * (This is called when origin belongs in Tensor domain.)
+ */
+ virtual void handover(CanonicalT *origin, TFLT *replacer) = 0;
+
+ /**
+ * @brief Returns the list of inputs that needs to have FeatureDecode as its input.
+ * (This is called when origin belongs in Feature domain.)
+ */
+ virtual std::vector<loco::Node *> getInputsToConvert(CanonicalT *origin) = 0;
+
+ /// @brief Set the inputs of replacer to new_inputs
+ virtual void set(TFLT *replacer, std::vector<loco::Node *> &new_inputs) = 0;
+
+ /// @brief Set the inputs to nullptr
+ virtual void nullify(CanonicalT *origin) = 0;
+};
+
+/**
+ * @brief Class to handle domain conversion while converting a canonical node to TFL node(s)
+ */
+template <class CanonicalT, class TFLT> class DomainConverter
+{
+public:
+ template <FeatureLayout FeatureLayoutT>
+ TFLT *convert(CanonicalT *origin, InputHandler<CanonicalT, TFLT> &input_handler);
+};
+
+/**
+ * @brief Performs domain conversion
+ *
+ * 1. If origin belongs to loco::Domain::Tensor, simply replace origin with a TFL node.
+ * 2. If origin belongs to loco::Domain::Feature, insert loco::FeatureDecode for the input(s) and
+ *    loco::FeatureEncode for the output, then replace origin with a TFL node.
+ *
+ * @return the new TFL node; nullptr if the shape of origin is not known
+ */
+template <class CanonicalT, class TFLT>
+template <FeatureLayout FeatureLayoutT>
+TFLT *DomainConverter<CanonicalT, TFLT>::convert(CanonicalT *origin,
+ InputHandler<CanonicalT, TFLT> &input_handler)
+{
+ static_assert(FeatureLayoutT == FeatureLayout::NHWC, "Feature layout should be NHWC");
+
+ if (!loco::shape_known(origin))
+ {
+ return nullptr;
+ }
+
+ auto tfl_node = origin->graph()->nodes()->template create<TFLT>();
+
+ // When the input is a Tensor, just replace the canonical node with the TFL node.
+ if (loco::shape_get(origin).domain() == loco::Domain::Tensor)
+ {
+ input_handler.handover(origin, tfl_node);
+
+ loco::replace(origin).with(tfl_node);
+ input_handler.nullify(origin);
+
+ return tfl_node;
+ }
+ else if (loco::shape_get(origin).domain() == loco::Domain::Feature)
+ {
+ std::vector<loco::Node *> feature_decodes;
+
+ for (auto input : input_handler.getInputsToConvert(origin))
+ {
+ auto dec = make_feature_decode<FeatureLayoutT>(input);
+ feature_decodes.emplace_back(dec);
+ }
+
+ input_handler.set(tfl_node, feature_decodes);
+
+ auto enc = make_feature_encode<FeatureLayoutT>(tfl_node);
+
+ loco::replace(origin).with(enc);
+ input_handler.nullify(origin);
+
+ return tfl_node;
+ }
+ else
+ INTERNAL_EXN_V("Unsupported loco::Domain", oops::to_uint32(loco::shape_get(origin).domain()));
+}
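+
+// Usage sketch (illustrative only; ReluInputHandler is a hypothetical example):
+//
+//   struct ReluInputHandler final : public InputHandler<loco::ReLU, locoex::TFLRelu>
+//   {
+//     void handover(loco::ReLU *origin, locoex::TFLRelu *replacer) override
+//     {
+//       replacer->features(origin->input());
+//     }
+//
+//     std::vector<loco::Node *> getInputsToConvert(loco::ReLU *origin) override
+//     {
+//       return {origin->input()};
+//     }
+//
+//     void set(locoex::TFLRelu *replacer, std::vector<loco::Node *> &new_inputs) override
+//     {
+//       replacer->features(new_inputs.at(0));
+//     }
+//
+//     void nullify(loco::ReLU *origin) override { origin->input(nullptr); }
+//   };
+//
+//   ReluInputHandler input_handler;
+//   DomainConverter<loco::ReLU, locoex::TFLRelu> domain_converter;
+//   auto tfl_node = domain_converter.convert<FeatureLayout::NHWC>(origin, input_handler);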
+
+} // namespace exo
+
+#endif //__GRAPH_BLOCK_H__
diff --git a/compiler/exo/src/Knob.cpp b/compiler/exo/src/Knob.cpp
new file mode 100644
index 000000000..50d78f4b7
--- /dev/null
+++ b/compiler/exo/src/Knob.cpp
@@ -0,0 +1,122 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Knob.h"
+
+#include <pepper/strcast.h>
+
+#include <iostream>
+#include <string>
+#include <map>
+#include <stdexcept>
+
+// Basic Infrastructure to declare and access Knob values
+namespace
+{
+
+using KnobName = std::string;
+
+/**
+ * @brief Load configuration (from somewhere)
+ */
+struct KnobLoader
+{
+ virtual ~KnobLoader() = default;
+
+ virtual bool load(const KnobName &name, bool default_value) const = 0;
+};
+
+/**
+ * @brief Load configuration from environment variables
+ *
+ * Given a prefix P, EnvKnobLoader reads a configuration K from concat(P, K).
+ *
+ * For example, let us assume that P is "MY_" and K is "CONFIG".
+ *
+ * Then, EnvKnobLoader reads configuration CONFIG from environment variable MY_CONFIG.
+ */
+class EnvKnobLoader final : public KnobLoader
+{
+public:
+ EnvKnobLoader() = default;
+
+public:
+ bool load(const KnobName &knob_name, bool default_value) const override
+ {
+ auto envvar = _prefix + knob_name;
+ auto s = std::getenv(envvar.c_str());
+
+ return pepper::safe_strcast<int>(s, default_value ? 1 : 0) != 0;
+ }
+ void knob_set(const KnobName &knob_name, bool value) { _knob[knob_name] = value; }
+ void dialect_set(const exo::Dialect &dialect_name) { _prefix = _label[dialect_name]; }
+ bool knob_get(const KnobName &knob_name) { return load(knob_name, _knob[knob_name]); }
+
+private:
+ /// @brief Environment variable prefix
+ std::string _prefix;
+ std::map<KnobName, bool> _knob;
+ std::map<exo::Dialect, KnobName> _label = {{exo::Dialect::TFLITE, "TFL_"},
+ {exo::Dialect::CIRCLE, "CIRCLE_"}};
+};
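+
+// Example: after exo::set(exo::Dialect::TFLITE) the prefix is "TFL_", so the knob
+// "UseFuseBiasAddPass" is looked up in the environment variable TFL_UseFuseBiasAddPass.
+// A value of 0 turns the knob off, any other integer turns it on, and leaving the
+// variable unset keeps the default installed via knob_set().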
+
+} // namespace
+
+namespace
+{
+
+EnvKnobLoader &knob_loader(void)
+{
+ // TODO separate "EXOTFLITE_" and "EXOCIRCLE_" when necessary
+ static EnvKnobLoader loader;
+ return loader;
+}
+
+} // namespace
+
+namespace exo
+{
+
+#define KNOB_BOOL(NAME, TFL_DEFAULT, CIRCLE_DEFAULT, DESC) \
+ template <> typename KnobTrait<Knob::NAME>::ValueType get<Knob::NAME>(void) \
+ { \
+ return ::knob_loader().knob_get(#NAME); \
+ }
+#include "Knob.lst"
+#undef KNOB_BOOL
+
+void set(Dialect d)
+{
+ ::knob_loader().dialect_set(d);
+ switch (d)
+ {
+ case Dialect::TFLITE:
+#define KNOB_BOOL(NAME, TFL_DEFAULT, CIRCLE_DEFAULT, DESC) \
+ ::knob_loader().knob_set(#NAME, TFL_DEFAULT);
+#include "Knob.lst"
+#undef KNOB_BOOL
+ break;
+ case Dialect::CIRCLE:
+#define KNOB_BOOL(NAME, TFL_DEFAULT, CIRCLE_DEFAULT, DESC) \
+ ::knob_loader().knob_set(#NAME, CIRCLE_DEFAULT);
+#include "Knob.lst"
+#undef KNOB_BOOL
+ break;
+ default:
+    throw std::runtime_error("Unknown dialect");
+ }
+}
+
+} // namespace exo
diff --git a/compiler/exo/src/Knob.h b/compiler/exo/src/Knob.h
new file mode 100644
index 000000000..98613120c
--- /dev/null
+++ b/compiler/exo/src/Knob.h
@@ -0,0 +1,51 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __KNOB_H__
+#define __KNOB_H__
+
+namespace exo
+{
+
+enum class Dialect
+{
+ TFLITE,
+ CIRCLE
+};
+
+enum class Knob
+{
+#define KNOB_BOOL(NAME, TFL_DEFAULT, CIRCLE_DEFAULT, DESC) NAME,
+#include "Knob.lst"
+#undef KNOB_BOOL
+};
+
+template <Knob K> struct KnobTrait;
+
+#define KNOB_BOOL(NAME, TFL_DEFAULT, CIRCLE_DEFAULT, DESC) \
+ template <> struct KnobTrait<Knob::NAME> \
+ { \
+ using ValueType = bool; \
+ };
+#include "Knob.lst"
+#undef KNOB_BOOL
+
+template <Knob K> typename KnobTrait<K>::ValueType get(void);
+void set(Dialect);
+
+} // namespace exo
+
+#endif // __KNOB_H__
diff --git a/compiler/exo/src/Knob.lst b/compiler/exo/src/Knob.lst
new file mode 100644
index 000000000..7f59c93f3
--- /dev/null
+++ b/compiler/exo/src/Knob.lst
@@ -0,0 +1,11 @@
+#ifndef KNOB_BOOL
+#error "KNOB_BOOL is not defined"
+#endif // KNOB_BOOL
+
+// KNOB_BOOL(KNOB_NAME, TFL_DEFAULT, CIRCLE_DEFAULT, DESCRIPTION)
+
+// Optimization pass
+KNOB_BOOL(UseFuseBiasAddPass, true, true, Fuse TFLAdd or TFLSub into TFLConv2D)
+KNOB_BOOL(UseFuseInstanceNormPass, false, true, Fuse InstanceNorm pattern)
+KNOB_BOOL(UseFuseReluPass, true, true, Fuse TFLRelu or TFLRelu6 into the preceding TFL node)
+KNOB_BOOL(UseFuseSquaredDifferencePass, false, true, Fuse SquaredDifference pattern)
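+
+// For illustration: a line such as KNOB_BOOL(UseFuseBiasAddPass, true, true, ...)
+// expands in Knob.h to an enum entry plus a KnobTrait specialization with
+// ValueType = bool, and in Knob.cpp to an accessor roughly equivalent to
+//
+//   template <> bool get<Knob::UseFuseBiasAddPass>(void)
+//   {
+//     return ::knob_loader().knob_get("UseFuseBiasAddPass");
+//   }
+//
+// so client code can query it as exo::get<exo::Knob::UseFuseBiasAddPass>().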
diff --git a/compiler/exo/src/Log.cpp b/compiler/exo/src/Log.cpp
new file mode 100644
index 000000000..aa762968b
--- /dev/null
+++ b/compiler/exo/src/Log.cpp
@@ -0,0 +1,84 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Log.h"
+
+#include <hermes/ConsoleReporter.h>
+#include <stdex/Memory.h>
+
+#include <cstdlib>
+#include <iostream>
+
+// TODO Extract these lexical conversion routines as a library
+namespace
+{
+
+/**
+ * @brief Convert C-string as a value of type T
+ *
+ * safecast(s, v) returns v if s is nullptr.
+ */
+template <typename T> T safecast(const char *, const T &);
+
+template <> bool safecast<bool>(const char *s, const bool &value)
+{
+ return (s == nullptr) ? value : (std::stoi(s) != 0);
+}
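+
+// e.g. safecast<bool>(nullptr, false) returns false (the default), while
+// safecast<bool>("1", false) returns true.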
+
+} // namespace
+
+namespace exo
+{
+
+//
+// Logger
+//
+Logger::Logger(hermes::Context *ctx) { activate(ctx->sources(), ctx->bus()); }
+Logger::~Logger() { deactivate(); }
+
+//
+// LoggerConfig
+//
+LoggerConfig::LoggerConfig()
+{
+  // Turn on logging if EXO_LOG is set to a non-zero value
+ _enabled = safecast<bool>(std::getenv("EXO_LOG"), false);
+}
+
+void LoggerConfig::configure(const hermes::Source *source, hermes::Source::Setting &setting) const
+{
+  // Let's ignore hermes::Sources if that is not an exo logger
+ if (auto logger = dynamic_cast<const Logger *>(source))
+ {
+ configure(logger, setting);
+ }
+}
+
+void LoggerConfig::configure(const Logger *, hermes::Source::Setting &setting) const
+{
+ if (_enabled)
+ {
+    // Enable all categories
+ setting.accept_all();
+ }
+ else
+ {
+    // Disable all categories
+ setting.reject_all();
+ }
+}
+
+} // namespace exo
diff --git a/compiler/exo/src/Log.h b/compiler/exo/src/Log.h
new file mode 100644
index 000000000..8ca38c3ec
--- /dev/null
+++ b/compiler/exo/src/Log.h
@@ -0,0 +1,75 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LOG_H__
+#define __LOG_H__
+
+#include "exo/LoggingContext.h"
+
+#include <hermes.h>
+
+namespace exo
+{
+
+/**
+ * @brief Logger Implementation
+ */
+class Logger final : public hermes::Source
+{
+public:
+ Logger(hermes::Context *ctx);
+ ~Logger();
+};
+
+/**
+ * @brief Logger Configuration
+ *
+ * Users are able to turn logging on/off via EXO_LOG environment variable.
+ */
+class LoggerConfig final : public hermes::Config
+{
+public:
+ LoggerConfig();
+
+public:
+ void configure(const hermes::Source *, hermes::Source::Setting &) const final;
+ void configure(const Logger *, hermes::Source::Setting &) const;
+
+private:
+ bool _enabled;
+};
+
+} // namespace exo
+
+/**
+ * HOW TO USE:
+ *
+ * LOGGER(l);
+ *
+ * INFO(l) << "Hello, World" << std::endl;
+ *
+ */
+#define LOGGER(name) ::exo::Logger name{::exo::LoggingContext::get()};
+
+// TODO Support FATAL, ERROR, WARN, and VERBOSE
+#define INFO(name) HERMES_INFO(name)
+
+// WARNING!
+//
+// THE CURRENT IMPLEMENTATION IS NOT THREAD SAFE.
+//
+
+#endif // __LOG_H__
diff --git a/compiler/exo/src/LogHelper.cpp b/compiler/exo/src/LogHelper.cpp
new file mode 100644
index 000000000..7520b7ec8
--- /dev/null
+++ b/compiler/exo/src/LogHelper.cpp
@@ -0,0 +1,79 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "LogHelper.h"
+
+namespace loco
+{
+
+std::ostream &operator<<(std::ostream &os, const loco::FeatureShape &feature_shape)
+{
+ os << "[" << feature_shape.count().value() << "," << feature_shape.height().value() << ","
+ << feature_shape.width().value() << "," << feature_shape.depth().value() << "]";
+ return os;
+}
+
+std::ostream &operator<<(std::ostream &os, const loco::FilterShape &filter_shape)
+{
+ os << "[" << filter_shape.height().value() << "," << filter_shape.width().value() << ","
+ << filter_shape.depth().value() << "," << filter_shape.count().value() << "]";
+ return os;
+}
+
+std::ostream &operator<<(std::ostream &os, const loco::TensorShape &tensor_shape)
+{
+ os << "[";
+ for (uint32_t r = 0; r < tensor_shape.rank(); ++r)
+ {
+ if (r)
+ os << ",";
+ os << tensor_shape.dim(r).value();
+ }
+ os << "]";
+ return os;
+}
+
+std::ostream &operator<<(std::ostream &os, const loco::Padding2D &pad)
+{
+ os << "[TLBR " << pad.top() << "," << pad.left() << "," << pad.bottom() << "," << pad.right()
+ << "]";
+
+ return os;
+}
+
+} // namespace loco
+
+std::ostream &operator<<(std::ostream &os, const std::vector<int64_t> &vi64)
+{
+ for (auto vi : vi64)
+ {
+ os << vi << " ";
+ }
+ return os;
+}
+
+#include "ExoFormattedGraph.h"
+
+namespace exo
+{
+
+FormattedGraph fmt(loco::Graph *g)
+{
+ auto node_summary_builder = stdex::make_unique<NodeSummaryBuilderFactory>();
+ return std::move(locop::fmt<locop::LinearV1>(g).with(std::move(node_summary_builder)));
+}
+
+} // namespace exo
diff --git a/compiler/exo/src/LogHelper.h b/compiler/exo/src/LogHelper.h
new file mode 100644
index 000000000..69d81af9e
--- /dev/null
+++ b/compiler/exo/src/LogHelper.h
@@ -0,0 +1,70 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LOG_HELPER_H__
+#define __LOG_HELPER_H__
+
+#include <locop/FormattedGraph.h>
+
+#include <loco/IR/FeatureShape.h>
+#include <loco/IR/FilterShape.h>
+#include <loco/IR/TensorShape.h>
+
+#include <sstream>
+#include <vector>
+
+namespace loco
+{
+
+/**
+ * @brief dump FeatureShape values to stream
+ */
+std::ostream &operator<<(std::ostream &os, const loco::FeatureShape &feature_shape);
+
+/**
+ * @brief dump FilterShape values to stream
+ */
+std::ostream &operator<<(std::ostream &os, const loco::FilterShape &filter_shape);
+
+/**
+ * @brief dump TensorShape values to stream
+ */
+std::ostream &operator<<(std::ostream &os, const loco::TensorShape &tensor_shape);
+
+/**
+ * @brief dump Padding2D values to stream
+ */
+std::ostream &operator<<(std::ostream &os, const loco::Padding2D &pad);
+
+} // namespace loco
+
+/**
+ * @brief dump std::vector<int64_t> values to stream
+ */
+std::ostream &operator<<(std::ostream &os, const std::vector<int64_t> &vi64);
+
+namespace exo
+{
+
+using FormattedGraph = locop::FormattedGraphImpl<locop::Formatter::LinearV1>;
+
+FormattedGraph fmt(loco::Graph *g);
+
+static inline FormattedGraph fmt(const std::unique_ptr<loco::Graph> &g) { return fmt(g.get()); }
+
+} // namespace exo
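+
+// A minimal usage sketch combining this helper with Log.h (assuming 'g' is a
+// valid loco::Graph *):
+//
+//   LOGGER(l);
+//   INFO(l) << "Graph dump:" << std::endl << exo::fmt(g) << std::endl;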
+
+#endif // __LOG_HELPER_H__
diff --git a/compiler/exo/src/LoggingContext.cpp b/compiler/exo/src/LoggingContext.cpp
new file mode 100644
index 000000000..1c14d97b9
--- /dev/null
+++ b/compiler/exo/src/LoggingContext.cpp
@@ -0,0 +1,40 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "exo/LoggingContext.h"
+#include "Log.h" // To use LoggerConfig
+
+#include <hermes/ConsoleReporter.h>
+#include <stdex/Memory.h>
+
+namespace exo
+{
+
+hermes::Context *LoggingContext::get(void)
+{
+ static hermes::Context *ctx = nullptr;
+
+ if (ctx == nullptr)
+ {
+ ctx = new hermes::Context;
+ ctx->sinks()->append(stdex::make_unique<hermes::ConsoleReporter>());
+ ctx->config(stdex::make_unique<LoggerConfig>());
+ }
+
+ return ctx;
+}
+
+} // namespace exo
diff --git a/compiler/exo/src/Pass/FoldReshapeOfConstPass.cpp b/compiler/exo/src/Pass/FoldReshapeOfConstPass.cpp
new file mode 100644
index 000000000..0fdcea939
--- /dev/null
+++ b/compiler/exo/src/Pass/FoldReshapeOfConstPass.cpp
@@ -0,0 +1,116 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "FoldReshapeOfConstPass.h"
+
+#include "Check.h"
+
+#include "Dialect/IR/TFLNodes.h"
+#include "Dialect/IR/TFLNodeVisitor.h"
+
+#include <loco/Service/ShapeInference.h>
+
+#include <oops/InternalExn.h>
+
+namespace
+{
+
+/**
+ * @brief Check if node is TFLReshape and its input is TFLConst
+ * @return Casted TFLReshape for foldable candidate, nullptr otherwise
+ */
+locoex::TFLReshape *as_candidate(loco::Node *node)
+{
+ auto reshape = dynamic_cast<locoex::TFLReshape *>(node);
+ if (not reshape)
+ return nullptr;
+
+ // Only accept Constant input of Reshape
+ if (not dynamic_cast<locoex::TFLConst *>(reshape->tensor()))
+ return nullptr;
+
+ return reshape;
+}
+
+uint32_t volume(loco::Node *tensor_node)
+{
+ auto shape = loco::shape_get(tensor_node).as<loco::TensorShape>();
+
+ uint32_t vol = 1;
+ for (uint32_t axis = 0; axis < shape.rank(); ++axis)
+ vol *= shape.dim(axis).value();
+
+ return vol;
+}
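+
+// e.g. volume() of a tensor with shape [2, 3, 4] is 2 * 3 * 4 = 24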
+
+void fold_reshape_of_const(locoex::TFLReshape *reshape)
+{
+ const loco::DataType FLOAT32 = loco::DataType::FLOAT32;
+
+ auto const_orig = dynamic_cast<locoex::TFLConst *>(reshape->tensor());
+
+ // Exceptions
+ {
+ EXO_ASSERT(const_orig, "Only support for Reshape-Const pair");
+ // TODO support other data types
+ if (const_orig->dtype() != FLOAT32)
+ INTERNAL_EXN_V("NYI for this data type", oops::to_uint32(const_orig->dtype()));
+
+ if (volume(const_orig) != volume(reshape))
+      INTERNAL_EXN("New shape of Reshape does not match");
+ }
+
+ auto new_shape = loco::shape_get(reshape).as<loco::TensorShape>();
+
+ // TFLConst to replace
+ auto const_new = reshape->graph()->nodes()->create<locoex::TFLConst>();
+
+ const_new->dtype(FLOAT32);
+ const_new->rank(new_shape.rank());
+ const_new->size<FLOAT32>(const_orig->size<FLOAT32>());
+ for (uint32_t axis = 0; axis < new_shape.rank(); ++axis)
+ const_new->dim(axis) = new_shape.dim(axis);
+
+ for (uint32_t i = 0; i < const_new->size<FLOAT32>(); ++i)
+ {
+ const_new->at<FLOAT32>(i) = const_orig->at<FLOAT32>(i);
+ }
+
+ // replace
+ loco::replace(reshape).with(const_new);
+}
+
+} // namespace
+
+namespace exo
+{
+
+bool FoldReshapeOfConstPass::run(loco::Graph *g)
+{
+ bool changed = false;
+ for (auto node : loco::active_nodes(loco::output_nodes(g)))
+ {
+ if (auto reshape = as_candidate(node))
+ {
+ fold_reshape_of_const(reshape);
+ changed = true;
+ }
+ }
+
+ return changed;
+}
+
+} // namespace exo
diff --git a/compiler/exo/src/Pass/FoldReshapeOfConstPass.h b/compiler/exo/src/Pass/FoldReshapeOfConstPass.h
new file mode 100644
index 000000000..10f8004bf
--- /dev/null
+++ b/compiler/exo/src/Pass/FoldReshapeOfConstPass.h
@@ -0,0 +1,46 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __PASS_FOLD_RESHAPE_OF_CONST_PASS_H__
+#define __PASS_FOLD_RESHAPE_OF_CONST_PASS_H__
+
+#include <logo/Pass.h>
+
+namespace exo
+{
+
+/**
+ * @brief Class to fuse TFLReshape + TFLConst into one equivalent TFLConst
+ *
+ * <before>
+ * TFLConst --- TFLReshape --- Out
+ *
+ * <after>
+ * TFLConst --- TFLReshape ---
+ * TFLConst (new) ------------ Out
+ *
+ * TODO This pass is temporary. Deprecate this pass.
+ */
+struct FoldReshapeOfConstPass final : public logo::Pass
+{
+ const char *name(void) const final { return "exo::FoldReshapeOfConstPass"; }
+
+ bool run(loco::Graph *g) final;
+};
+
+} // namespace exo
+
+#endif // __PASS_FOLD_RESHAPE_OF_CONST_PASS_H__
diff --git a/compiler/exo/src/Pass/FoldTransposeOfConstPass.cpp b/compiler/exo/src/Pass/FoldTransposeOfConstPass.cpp
new file mode 100644
index 000000000..005c42944
--- /dev/null
+++ b/compiler/exo/src/Pass/FoldTransposeOfConstPass.cpp
@@ -0,0 +1,154 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "FoldTransposeOfConstPass.h"
+
+#include "Check.h"
+
+#include "Dialect/IR/TFLNodes.h"
+#include "Dialect/IR/TFLNodeVisitor.h"
+
+// TODO remove dependency to angkor
+#include <nncc/core/ADT/tensor/IndexEnumerator.h>
+#include <nncc/core/ADT/tensor/LexicalLayout.h>
+
+#include <oops/InternalExn.h>
+
+namespace
+{
+
+/**
+ * @brief Check if node is TFLTranspose and its input is TFLConst
+ * @return Casted TFLTranspose for foldable candidate, nullptr otherwise
+ */
+locoex::TFLTranspose *as_candidate(loco::Node *node)
+{
+ auto transpose = dynamic_cast<locoex::TFLTranspose *>(node);
+ if (not transpose)
+ return nullptr;
+
+ // Only accept Constant input of Transpose
+ if (not dynamic_cast<locoex::TFLConst *>(transpose->a()))
+ return nullptr;
+
+ // Only accept Constant permutation of Transpose
+ if (not dynamic_cast<locoex::TFLConst *>(transpose->perm()))
+ return nullptr;
+
+ return transpose;
+}
+
+nncc::core::ADT::tensor::Shape angkor_shape(locoex::TFLConst *node)
+{
+ nncc::core::ADT::tensor::Shape ret;
+
+ ret.resize(node->rank());
+ for (uint32_t axis = 0; axis < node->rank(); ++axis)
+ {
+ ret.dim(axis) = node->dim(axis).value();
+ }
+
+ return ret;
+}
+
+void fold_transpose_of_const(locoex::TFLTranspose *transpose)
+{
+ const loco::DataType FLOAT32 = loco::DataType::FLOAT32;
+ const loco::DataType S32 = loco::DataType::S32;
+
+ auto const_orig = dynamic_cast<locoex::TFLConst *>(transpose->a());
+ auto perm = dynamic_cast<locoex::TFLConst *>(transpose->perm());
+
+ // Exceptions
+ {
+ EXO_ASSERT(const_orig, "Only support for Transpose-Const pair");
+ // TODO support other data types
+ if (const_orig->dtype() != FLOAT32)
+ INTERNAL_EXN_V("NYI for this data type", oops::to_uint32(const_orig->dtype()));
+
+ EXO_ASSERT(perm, "Only support for constant permutation for Transpose");
+ // TODO support other data types
+ if (perm->dtype() != S32)
+ INTERNAL_EXN_V("NYI for this data type", oops::to_uint32(perm->dtype()));
+
+ auto okay = [&]() {
+ if (perm->rank() != 1)
+ return false;
+ if (perm->dim(0).value() != const_orig->rank())
+ return false;
+ return true;
+ };
+ if (not okay())
+      INTERNAL_EXN("Input and permutation for Transpose are not congruent");
+ }
+
+ uint32_t rank = const_orig->rank();
+
+ // TFLConst to replace
+ auto const_new = transpose->graph()->nodes()->create<locoex::TFLConst>();
+
+ const_new->dtype(FLOAT32);
+ const_new->rank(rank);
+ const_new->size<FLOAT32>(const_orig->size<FLOAT32>());
+ for (uint32_t axis = 0; axis < rank; ++axis)
+ const_new->dim(axis) = const_orig->dim(perm->at<S32>(axis)).value();
+
+ // TODO remove dependency to angkor
+ auto shape_orig = angkor_shape(const_orig);
+ auto shape_new = angkor_shape(const_new);
+
+ nncc::core::ADT::tensor::LexicalLayout l;
+ nncc::core::ADT::tensor::IndexEnumerator e{shape_new};
+
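+  // Example: with perm = [0, 3, 1, 2] (NHWC -> NCHW), the element at new index
+  // (n, c, h, w) is copied from original index (n, h, w, c), because the loop
+  // below sets index_orig[perm[axis]] = index_new[axis] for every axis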
+ for (; e.valid(); e.advance())
+ {
+ loco::TensorIndex index_new = e.current();
+ loco::TensorIndex index_orig;
+
+ // Set original index from matching new index
+ index_orig.resize(rank);
+ for (uint32_t axis = 0; axis < rank; ++axis)
+ index_orig.at(perm->at<S32>(axis)) = index_new.at(axis);
+
+ const_new->at<FLOAT32>(l.offset(shape_new, index_new)) =
+ const_orig->at<FLOAT32>(l.offset(shape_orig, index_orig));
+ }
+
+ // replace
+ loco::replace(transpose).with(const_new);
+}
+
+} // namespace
+
+namespace exo
+{
+
+bool FoldTransposeOfConstPass::run(loco::Graph *g)
+{
+ bool changed = false;
+ for (auto node : loco::active_nodes(loco::output_nodes(g)))
+ {
+ if (auto transpose = as_candidate(node))
+ {
+ fold_transpose_of_const(transpose);
+ changed = true;
+ }
+ }
+
+ return changed;
+}
+
+} // namespace exo
diff --git a/compiler/exo/src/Pass/FoldTransposeOfConstPass.h b/compiler/exo/src/Pass/FoldTransposeOfConstPass.h
new file mode 100644
index 000000000..26656a118
--- /dev/null
+++ b/compiler/exo/src/Pass/FoldTransposeOfConstPass.h
@@ -0,0 +1,46 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __PASS_FOLD_TRANSPOSE_OF_CONST_PASS_H__
+#define __PASS_FOLD_TRANSPOSE_OF_CONST_PASS_H__
+
+#include <logo/Pass.h>
+
+namespace exo
+{
+
+/**
+ * @brief Class to fuse TFLTranspose + TFLConst into one equivalent TFLConst
+ *
+ * <before>
+ * TFLConst --- TFLTranspose --- Out
+ *
+ * <after>
+ * TFLConst --- TFLTranspose ---
+ * TFLConst (new) -------------- Out
+ *
+ * TODO This pass is temporary. Deprecate this pass.
+ */
+struct FoldTransposeOfConstPass final : public logo::Pass
+{
+ const char *name(void) const final { return "exo::FoldTransposeOfConstPass"; }
+
+ bool run(loco::Graph *g) final;
+};
+
+} // namespace exo
+
+#endif // __PASS_FOLD_TRANSPOSE_OF_CONST_PASS_H__
diff --git a/compiler/exo/src/Pass/FuseBiasAddPass.cpp b/compiler/exo/src/Pass/FuseBiasAddPass.cpp
new file mode 100644
index 000000000..aab820995
--- /dev/null
+++ b/compiler/exo/src/Pass/FuseBiasAddPass.cpp
@@ -0,0 +1,362 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "FuseBiasAddPass.h"
+
+#include "Dialect/IR/TFLNodes.h"
+#include "Dialect/IR/TFLDialect.h"
+#include "Dialect/IR/TFLNodeVisitor.h"
+
+#include <loco/Service/TypeInference.h>
+#include <loco/Service/ShapeInference.h>
+
+#include <oops/InternalExn.h>
+
+#include <set>
+
+/*
+  Note: Terms for variables in this implementation are as follows:
+
+ ex) subgraph handled: TFLConv2D -------- TFLAdd
+ (or TFLDepthwiseConv2D) (or TFLSub)
+ | |
+ \|/ \|/
+ variable name : former latter
+ Type : FormerT LatterT
+ (shortened name from Mixin) (template type)
+*/
+namespace
+{
+
+using FormerT = locoex::TFLNodeMixin<locoex::TFLNodeTrait::Bias>;
+
+loco::Node *as_loco_node(FormerT *former)
+{
+ auto loco_node = dynamic_cast<loco::Node *>(former);
+ assert(loco_node != nullptr);
+
+ return loco_node;
+}
+
+locoex::TFLConst *get_const(loco::Node *x, loco::Node *y)
+{
+ if (auto const_node = dynamic_cast<locoex::TFLConst *>(x))
+ return const_node;
+ else if (auto const_node = dynamic_cast<locoex::TFLConst *>(y))
+ return const_node;
+
+ return nullptr;
+}
+
+FormerT *get_former(loco::Node *x, loco::Node *y)
+{
+ if (auto node = dynamic_cast<FormerT *>(x))
+ return node;
+ else if (auto node = dynamic_cast<FormerT *>(y))
+ return node;
+
+ return nullptr;
+}
+
+/// @brief Finds the input that is a TFLConst and replaces it with new_input
+void set_const_input(locoex::TFLNode *node, locoex::TFLConst *new_input)
+{
+ if (auto add = dynamic_cast<locoex::TFLAdd *>(node))
+ {
+ if (dynamic_cast<locoex::TFLConst *>(add->x()))
+ add->x(new_input);
+ else if (dynamic_cast<locoex::TFLConst *>(add->y()))
+ add->y(new_input);
+ else
+ assert(false and "One node should be TFLConst");
+
+ return;
+ }
+
+ if (auto sub = dynamic_cast<locoex::TFLSub *>(node))
+ {
+ if (dynamic_cast<locoex::TFLConst *>(sub->x()))
+ sub->x(new_input);
+ else if (dynamic_cast<locoex::TFLConst *>(sub->y()))
+ sub->y(new_input);
+ else
+ assert(false and "One node should be TFLConst");
+
+ return;
+ }
+
+ assert(false and "Param should be TFLAdd or TFLSub");
+}
+
+/**
+ * @brief Creates a TFLConst whose shape is [to] and values are all const_node->at(0),
+ *        where const_node has only one element (a scalar or a tensor of shape [1])
+ */
+locoex::TFLConst *create_widened(locoex::TFLConst *const_node, uint32_t to)
+{
+ auto const_shape = loco::shape_get(const_node).as<loco::TensorShape>();
+
+ assert(const_shape.rank() == 0 or (const_shape.rank() == 1 and const_shape.dim(0) == 1));
+
+ auto g = const_node->graph();
+
+ auto widened_const = g->nodes()->create<locoex::TFLConst>();
+ {
+ widened_const->dtype(loco::DataType::FLOAT32);
+ widened_const->rank(1);
+ widened_const->dim(0) = to;
+ widened_const->size<loco::DataType::FLOAT32>(to);
+ for (uint32_t x = 0; x < to; x++)
+ widened_const->at<loco::DataType::FLOAT32>(x) = const_node->at<loco::DataType::FLOAT32>(0);
+ }
+ return widened_const;
+}
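+
+// e.g. create_widened(c, 4), where c holds the single value 1.5, returns a new
+// TFLConst of shape [4] with values {1.5, 1.5, 1.5, 1.5}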
+
+template <typename TFLType> float calc(float, float);
+
+template <> float calc<locoex::TFLAdd>(float x, float y) { return x + y; }
+template <> float calc<locoex::TFLSub>(float x, float y) { return x - y; }
+
+template <class LatterT> class Fuser
+{
+public:
+ Fuser(LatterT *latter)
+ {
+ static_assert(std::is_same<LatterT, locoex::TFLAdd>::value ||
+ std::is_same<LatterT, locoex::TFLSub>::value,
+ "wrong template type");
+
+ _latter = latter;
+ _graph = _latter->graph();
+ _const_node = get_const(_latter->x(), _latter->y());
+ _former = get_former(_latter->x(), _latter->y());
+
+ assert(_const_node && _former);
+ }
+
+ void fuse(void);
+
+private:
+ loco::Graph *_graph;
+ LatterT *_latter;
+ locoex::TFLConst *_const_node;
+ FormerT *_former;
+
+ locoex::TFLConst *create_fused_bias_const();
+};
+
+// instantiation
+template class Fuser<locoex::TFLAdd>;
+template class Fuser<locoex::TFLSub>;
+
+template <class LatterT> locoex::TFLConst *Fuser<LatterT>::create_fused_bias_const()
+{
+  // We have to create a new bias const by adding/subtracting the bias and the const node (of TFLAdd or
+ // TFLSub)
+ auto bias = dynamic_cast<locoex::TFLConst *>(_former->bias());
+ assert(bias->dtype() == loco::DataType::FLOAT32 &&
+ _const_node->dtype() == loco::DataType::FLOAT32);
+
+ assert(bias->rank() == 1 && _const_node->rank() == 1);
+ assert(bias->dim(0) == _const_node->dim(0));
+
+ // build a new bias const
+ auto new_bias = _graph->nodes()->create<locoex::TFLConst>();
+ {
+ new_bias->dtype(loco::DataType::FLOAT32);
+
+ new_bias->rank(1);
+ new_bias->dim(0) = bias->dim(0);
+
+ new_bias->size<loco::DataType::FLOAT32>(bias->dim(0).value());
+
+ for (uint32_t x = 0; x < bias->dim(0).value(); x++)
+ new_bias->at<loco::DataType::FLOAT32>(x) = calc<LatterT>(
+ bias->at<loco::DataType::FLOAT32>(x), _const_node->at<loco::DataType::FLOAT32>(x));
+ }
+
+ return new_bias;
+}
+
+// FuseBiasAddPass works when former->fusedActivationFunction() == NONE
+bool check_act_func(FormerT *former)
+{
+ using FusedActFuncMixin = locoex::TFLNodeMixin<locoex::TFLNodeTrait::FusedActFunc>;
+
+ if (auto node = dynamic_cast<FusedActFuncMixin *>(former))
+ return node->fusedActivationFunction() == locoex::FusedActFunc::NONE;
+ else
+ return true;
+}
+
+template <class LatterT> void set_act_func(FormerT *former, LatterT *latter)
+{
+ using FusedActFuncMixin = locoex::TFLNodeMixin<locoex::TFLNodeTrait::FusedActFunc>;
+
+ if (auto node = dynamic_cast<FusedActFuncMixin *>(former))
+ node->fusedActivationFunction(latter->fusedActivationFunction());
+}
+
+// instantiation
+template void set_act_func(FormerT *, locoex::TFLAdd *);
+template void set_act_func(FormerT *, locoex::TFLSub *);
+
+/**
+ * @brief Fuse TFLAdd or TFLSub (latter) into TFLConv2D or TFLDepthwiseConv2D (former).
+ * All conditions should be checked before calling this.
+ *
+ * @note TFLAdd can have fused activation function (let's call this FAF for simplicity).
+ *
+ * Conv2D's FAF | TFLAdd's FAF => FAF after fusing TFLAdd into TFLConv2D
+ * ----------------|--------------- --------------------------------------
+ * NONE | NONE, RELU or RELU6 => TFLAdd's FAF
+ * other than NONE | anything => cannot be fused
+ */
+template <class LatterT> void Fuser<LatterT>::fuse(void)
+{
+ // check fused activation function
+ {
+ assert(check_act_func(_former));
+
+ set_act_func<LatterT>(_former, _latter);
+ }
+
+ auto new_bias = create_fused_bias_const();
+
+ // replace node with new_bias
+  // Note that loco::replace() is not used because the old bias could be an input of another op
+ _former->bias(new_bias);
+
+ // remove TFLAdd or TFLSub node
+ loco::replace(_latter).with(as_loco_node(_former));
+ _latter->x(nullptr);
+ _latter->y(nullptr);
+}
+
+struct Collector final : public locoex::TFLNodeMutableVisitor<void>
+{
+ template <class LatterT>
+ void setCandidate(FormerT *former, LatterT *latter, locoex::TFLConst *const_node)
+ {
+ static_assert(std::is_same<LatterT, locoex::TFLAdd>::value ||
+ std::is_same<LatterT, locoex::TFLSub>::value,
+ "wrong template type");
+
+ if (!check_act_func(former))
+ return;
+
+ auto depth =
+ loco::shape_get(as_loco_node(former)).template as<loco::TensorShape>().dim(3).value();
+ auto const_shape = loco::shape_get(const_node).template as<loco::TensorShape>();
+
+ if (const_shape.rank() == 1 and const_shape.dim(0) == depth)
+ {
+ candidates.insert(latter);
+ }
+ // when Const has only one value, create a new const with shape [depth]
+ else if (const_shape.rank() == 0 or (const_shape.rank() == 1 and const_shape.dim(0) == 1))
+ {
+ if (!(loco::dtype_get(as_loco_node(former)) == loco::DataType::FLOAT32))
+ INTERNAL_EXN_V("Unsupported data type",
+ oops::to_uint32(loco::dtype_get(as_loco_node(former))));
+ if (!(const_node->dtype() == loco::DataType::FLOAT32))
+ INTERNAL_EXN_V("Unsupported data type", oops::to_uint32(const_node->dtype()));
+
+ auto new_bias_node = create_widened(const_node, depth);
+
+ // Replacing TFLConst input of TFLAdd or TFLSub.
+ // Note that calling loco::replace(const_node).with(new_bias_node) could be dangerous
+ // because const_node could be the input of many nodes
+ set_const_input(latter, new_bias_node);
+
+ candidates.insert(latter);
+ }
+ }
+
+ void visit(locoex::TFLAdd *latter) final
+ {
+ auto former = get_former(latter->x(), latter->y());
+ auto const_node = get_const(latter->x(), latter->y());
+
+ if (former && const_node)
+ setCandidate<locoex::TFLAdd>(former, latter, const_node);
+ }
+
+ void visit(locoex::TFLSub *latter) final
+ {
+    // TFLSub, of which x() = TFLConv2D or TFLDepthwiseConv2D, y() = TFLConst, is the fusing target
+ auto former = dynamic_cast<FormerT *>(latter->x());
+ auto const_node = dynamic_cast<locoex::TFLConst *>(latter->y());
+
+ if (former && const_node)
+ setCandidate<locoex::TFLSub>(former, latter, const_node);
+ }
+
+ void visit(locoex::TFLNode *) final { return; }
+
+ std::set<locoex::TFLNode *> candidates;
+};
+
+struct Performer final : public locoex::TFLNodeMutableVisitor<void>
+{
+ void visit(locoex::TFLAdd *latter) final
+ {
+ assert(get_former(latter->x(), latter->y()));
+
+ Fuser<locoex::TFLAdd> fuser(latter);
+ fuser.fuse();
+ }
+
+ void visit(locoex::TFLSub *latter) final
+ {
+ assert(get_former(latter->x(), latter->y()));
+
+ Fuser<locoex::TFLSub> fuser(latter);
+ fuser.fuse();
+ }
+
+ void visit(locoex::TFLNode *) final { assert(false && "should not be called"); }
+};
+
+} // namespace
+
+namespace exo
+{
+
+bool FuseBiasAddPass::run(loco::Graph *g)
+{
+ Collector collector;
+
+ for (auto node : loco::active_nodes(loco::output_nodes(g)))
+ {
+ if (node->dialect() == locoex::TFLDialect::get())
+ {
+ auto tfl_node = dynamic_cast<locoex::TFLNode *>(node);
+ tfl_node->accept(&collector);
+ }
+ }
+
+ Performer performer;
+
+ for (auto node : collector.candidates)
+ {
+ node->accept(&performer);
+ }
+
+ return collector.candidates.size() > 0;
+}
+
+} // namespace exo
diff --git a/compiler/exo/src/Pass/FuseBiasAddPass.h b/compiler/exo/src/Pass/FuseBiasAddPass.h
new file mode 100644
index 000000000..68e624c6b
--- /dev/null
+++ b/compiler/exo/src/Pass/FuseBiasAddPass.h
@@ -0,0 +1,61 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __PASS_FUSE_BIASADD_PASS_H__
+#define __PASS_FUSE_BIASADD_PASS_H__
+
+#include <logo/Pass.h>
+
+namespace exo
+{
+
+/**
+ * @brief Class to fuse TFLAdd or TFLSub into Bias input of the following ops:
+ * - TFLConv2D, TFLDepthwiseConv2D
+ *       - TODO Consider adding FullyConnected, etc.
+ *
+ * Case 1. Conv2D and TFLAdd
+ *
+ * BEFORE:
+ *
+ * TFLConst A (a scalar or a tensor of shape [1] or [depth of TFLConv2D])
+ * |
+ * Foo -- TFLConv2D -- TFLAdd (or TFLSub) -- Bar
+ * |
+ * TFLConst B --+ (bias)
+ *
+ * AFTER:
+ * Foo ----- TFLConv2D ----- Bar
+ * |
+ * TFLConst A' --+ (bias)
+ *
+ * TFLConst B (dead node)
+ *
+ * TFLAdd (or TFLSub) (dead node)
+ *
+ * @note TFLSub, of which x() == TFLConv2D and y() == TFLConst, will be fused.
+ * If x() == TFLConst and y() == TFLConv2D, it won't be fused.
+ */
+struct FuseBiasAddPass final : public logo::Pass
+{
+ const char *name(void) const final { return "exo::FuseBiasAddPass"; }
+
+ bool run(loco::Graph *g) final;
+};
+
+} // namespace exo
+
+#endif // __PASS_FUSE_BIASADD_PASS_H__
diff --git a/compiler/exo/src/Pass/FuseBiasAddPass.test.cpp b/compiler/exo/src/Pass/FuseBiasAddPass.test.cpp
new file mode 100644
index 000000000..6ba728de0
--- /dev/null
+++ b/compiler/exo/src/Pass/FuseBiasAddPass.test.cpp
@@ -0,0 +1,361 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "FuseBiasAddPass.h"
+
+#include "Dialect/IR/TFLNodes.h"
+#include "TestGraph.h"
+#include "TestHelper.h"
+
+#include <loco.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+void init(loco::Pull *pull)
+{
+ pull->dtype(loco::DataType::FLOAT32);
+ pull->shape({2, 3, 3, 2});
+}
+
+/// @brief Initializes TFLConv2D and related filter and bias
+void init(locoex::TFLConv2D *conv2d, locoex::TFLConst *filter, locoex::TFLConst *bias)
+{
+ // set conv2d
+ {
+ conv2d->fusedActivationFunction(locoex::FusedActFunc::NONE);
+ conv2d->padding(locoex::Padding::VALID);
+ }
+
+ // set filter
+ {
+ filter->dtype(loco::DataType::FLOAT32);
+ filter->shape({2, 3, 3, 2});
+ filter->size<loco::DataType::FLOAT32>(2 * 3 * 3 * 2);
+
+ for (uint32_t x = 0; x < 2 * 3 * 3 * 2; x++)
+ filter->at<loco::DataType::FLOAT32>(x) = 0.0;
+ }
+
+ // set bias
+ {
+ bias->dtype(loco::DataType::FLOAT32);
+ bias->shape({2});
+ bias->size<loco::DataType::FLOAT32>(2);
+
+ for (uint32_t x = 0; x < 2; x++)
+ bias->at<loco::DataType::FLOAT32>(x) = 0.0;
+ }
+}
+
+template <class T> void init(T *node, locoex::FusedActFunc f)
+{
+ static_assert(std::is_same<T, locoex::TFLAdd>::value || std::is_same<T, locoex::TFLSub>::value,
+ "wrong template type");
+
+ node->fusedActivationFunction(f);
+}
+
+/// @brief Initializes one param of TFLAdd or TFLSub
+void init(locoex::TFLConst *addsub_param)
+{
+ // set addsub_param : y() value of TFLAdd or TFLSub
+ addsub_param->dtype(loco::DataType::FLOAT32);
+ addsub_param->shape({2});
+ addsub_param->size<loco::DataType::FLOAT32>(2);
+
+ for (uint32_t x = 0; x < 2; x++)
+ addsub_param->at<loco::DataType::FLOAT32>(x) = (x + 1) * 1.5; // 1.5, 3
+}
+
+} // namespace
+
+// A case when
+// - TFLConv2D has bias (0, 0)
+// - TFLAdd, of which x() or y() == TFLConv2D
+// - Another param of TFLAdd is TFLConst, (1.5, 3)
+//
+// After fusion, bias should be (1.5, 3)
+TEST(FuseBiasAddPassTest, Conv2D_Add_01_basic)
+{
+ exo::test::TestGraph g;
+ auto filter = g.append<locoex::TFLConst>();
+ auto bias = g.append<locoex::TFLConst>();
+ auto conv2d = g.append<locoex::TFLConv2D>(g.pull, filter, bias);
+
+ auto add_y = g.append<locoex::TFLConst>();
+ auto add = g.append<locoex::TFLAdd>(conv2d, add_y);
+
+ g.complete(add);
+
+ init(g.pull);
+ init(conv2d, filter, bias);
+ init(add, locoex::FusedActFunc::NONE);
+ init(add_y);
+
+ // let's run fusion
+ {
+ exo::test::TypeShapeReadyPhase test_phase;
+
+ test_phase.add_pass<exo::FuseBiasAddPass>();
+ test_phase.run(g.graph());
+ }
+
+ auto a_conv2d = exo::test::find_first_node_bytype<locoex::TFLConv2D>(g.graph());
+ ASSERT_TRUE(a_conv2d != nullptr);
+
+ auto a_bias = dynamic_cast<locoex::TFLConst *>(a_conv2d->bias());
+ ASSERT_TRUE(a_bias != nullptr);
+
+ ASSERT_TRUE(a_bias->dim(0) == 2);
+ ASSERT_FLOAT_EQ(a_bias->at<loco::DataType::FLOAT32>(0),
+ bias->at<loco::DataType::FLOAT32>(0) + add_y->at<loco::DataType::FLOAT32>(0));
+ ASSERT_FLOAT_EQ(a_bias->at<loco::DataType::FLOAT32>(1),
+ bias->at<loco::DataType::FLOAT32>(1) + add_y->at<loco::DataType::FLOAT32>(1));
+}
+
+// A case when
+// - TFLConv2D has bias (0, 0)
+// - TFLAdd, of which x() or y() == TFLConv2D
+// - Another param of TFLAdd is TFLConst, (1.5) <-- scalar
+//
+// After fusion, bias should be (1.5, 1.5)
+TEST(FuseBiasAddPassTest, Conv2D_Add_02_TFLAdd_y_is_scalar)
+{
+ exo::test::TestGraph g;
+ auto filter = g.append<locoex::TFLConst>();
+ auto bias = g.append<locoex::TFLConst>();
+ auto conv2d = g.append<locoex::TFLConv2D>(g.pull, filter, bias);
+
+ auto add_y = g.append<locoex::TFLConst>();
+ auto add = g.append<locoex::TFLAdd>(conv2d, add_y);
+
+ g.complete(add);
+
+ init(g.pull);
+ init(conv2d, filter, bias); // channel of conv2d is 2
+
+ {
+ // Size of this TFLConst is 1.
+ // Note that this should be widened later to the shape of [channel of Conv2D], which is [2]
+ add_y->dtype(loco::DataType::FLOAT32);
+ add_y->shape({1});
+ add_y->size<loco::DataType::FLOAT32>(1);
+ add_y->at<loco::DataType::FLOAT32>(0) = 1.5;
+ }
+ init(add, locoex::FusedActFunc::NONE);
+
+ // let's run fusion
+ {
+ exo::test::TypeShapeReadyPhase test_phase;
+
+ test_phase.add_pass<exo::FuseBiasAddPass>();
+ test_phase.run(g.graph());
+ }
+
+ auto a_conv2d = exo::test::find_first_node_bytype<locoex::TFLConv2D>(g.graph());
+ ASSERT_TRUE(a_conv2d != nullptr);
+
+ auto a_bias = dynamic_cast<locoex::TFLConst *>(a_conv2d->bias());
+ ASSERT_TRUE(a_bias != nullptr);
+
+ ASSERT_TRUE(a_bias->dim(0) == 2);
+ ASSERT_FLOAT_EQ(a_bias->at<loco::DataType::FLOAT32>(0),
+ bias->at<loco::DataType::FLOAT32>(0) + 1.5);
+ ASSERT_FLOAT_EQ(a_bias->at<loco::DataType::FLOAT32>(1),
+ bias->at<loco::DataType::FLOAT32>(1) + 1.5);
+}
+
+// A case when
+// - TFLConv2D has bias (0, 0)
+// - TFLSub.x() == TFLConv2D
+// - TFLSub.y() == TFLConst, (1.5, 3)
+//
+// After fusion, bias should be (-1.5, -3)
+TEST(FuseBiasAddPassTest, Conv2D_Sub_01_basic)
+{
+ exo::test::TestGraph g;
+ auto filter = g.append<locoex::TFLConst>();
+ auto bias = g.append<locoex::TFLConst>();
+ auto conv2d = g.append<locoex::TFLConv2D>(g.pull, filter, bias);
+
+ auto sub_y = g.append<locoex::TFLConst>();
+ auto sub = g.append<locoex::TFLSub>(conv2d, sub_y);
+
+ g.complete(sub);
+
+ init(g.pull);
+ init(conv2d, filter, bias);
+ init(sub, locoex::FusedActFunc::NONE);
+ init(sub_y);
+
+ // let's run fusion
+ {
+ exo::test::TypeShapeReadyPhase test_phase;
+
+ test_phase.add_pass<exo::FuseBiasAddPass>();
+ test_phase.run(g.graph());
+ }
+
+ auto a_conv2d = exo::test::find_first_node_bytype<locoex::TFLConv2D>(g.graph());
+ ASSERT_TRUE(a_conv2d != nullptr);
+
+ auto a_bias = dynamic_cast<locoex::TFLConst *>(a_conv2d->bias());
+ ASSERT_TRUE(a_bias != nullptr);
+
+ ASSERT_TRUE(a_bias->dim(0) == 2);
+ ASSERT_FLOAT_EQ(a_bias->at<loco::DataType::FLOAT32>(0),
+ bias->at<loco::DataType::FLOAT32>(0) - sub_y->at<loco::DataType::FLOAT32>(0));
+ ASSERT_FLOAT_EQ(a_bias->at<loco::DataType::FLOAT32>(1),
+ bias->at<loco::DataType::FLOAT32>(1) - sub_y->at<loco::DataType::FLOAT32>(1));
+}
+
+// A case when TFLConv2D is an input of TFLSub but fusion cannot be performed.
+// - TFLSub.x() == TFLConst
+// - TFLSub.y() == TFLConv2D
+//
+// Here, TFLSub cannot be fused into the bias of TFLConv2D. To be fused, TFLSub.x() should be TFLConv2D and
+// TFLSub.y() should be TFLConst. So fusion will NOT happen.
+TEST(FuseBiasAddPassTest, Conv2D_Sub_02_fusing_will_not_performed)
+{
+ exo::test::TestGraph g;
+ auto filter = g.append<locoex::TFLConst>();
+ auto bias = g.append<locoex::TFLConst>();
+ auto conv2d = g.append<locoex::TFLConv2D>(g.pull, filter, bias);
+
+ auto sub_y = g.append<locoex::TFLConst>();
+ auto sub = g.append<locoex::TFLSub>(sub_y, conv2d); // This WON'T be fused
+
+ g.complete(sub);
+
+ init(g.pull);
+ init(conv2d, filter, bias);
+ init(sub, locoex::FusedActFunc::NONE);
+ init(sub_y);
+
+ // let's run fusion
+ {
+ exo::test::TypeShapeReadyPhase test_phase;
+
+ test_phase.add_pass<exo::FuseBiasAddPass>();
+ test_phase.run(g.graph());
+ }
+
+ auto a_conv2d = exo::test::find_first_node_bytype<locoex::TFLConv2D>(g.graph());
+ ASSERT_TRUE(a_conv2d != nullptr);
+
+ auto a_bias = dynamic_cast<locoex::TFLConst *>(a_conv2d->bias());
+ ASSERT_TRUE(a_bias != nullptr);
+
+ ASSERT_TRUE(a_bias->dim(0) == 2);
+ ASSERT_FLOAT_EQ(a_bias->at<loco::DataType::FLOAT32>(0), 0);
+ ASSERT_FLOAT_EQ(a_bias->at<loco::DataType::FLOAT32>(1), 0);
+
+ auto a_sub = exo::test::find_first_node_bytype<locoex::TFLSub>(g.graph());
+ ASSERT_TRUE(a_sub != nullptr);
+ ASSERT_TRUE(a_sub->y() == a_conv2d); // Checking 'not-fused' state
+}
+
+// A case when
+// - TFLConv2D has a Relu fused activation function
+// - TFLAdd has no fused activation function
+//
+// No fusion should happen
+TEST(FuseBiasAddPassTest, Regression_Conv2D_Add_fused_action_00)
+{
+ exo::test::TestGraph g;
+ auto filter = g.append<locoex::TFLConst>();
+ auto bias = g.append<locoex::TFLConst>();
+ auto conv2d = g.append<locoex::TFLConv2D>(g.pull, filter, bias);
+
+ auto add_y = g.append<locoex::TFLConst>();
+ auto add = g.append<locoex::TFLAdd>(conv2d, add_y);
+
+ g.complete(add);
+
+ init(g.pull);
+ init(conv2d, filter, bias);
+ init(add, locoex::FusedActFunc::NONE);
+ init(add_y);
+
+ // Updating Fused Activation for this test
+ conv2d->fusedActivationFunction(locoex::FusedActFunc::RELU);
+
+ // let's run fusion
+ {
+ exo::test::TypeShapeReadyPhase test_phase;
+
+ test_phase.add_pass<exo::FuseBiasAddPass>();
+ test_phase.run(g.graph());
+ }
+
+ auto a_conv2d = exo::test::find_first_node_bytype<locoex::TFLConv2D>(g.graph());
+ ASSERT_TRUE(a_conv2d != nullptr);
+ ASSERT_TRUE(a_conv2d->fusedActivationFunction() == locoex::FusedActFunc::RELU);
+
+ auto an_add = exo::test::find_first_node_bytype<locoex::TFLAdd>(g.graph());
+ ASSERT_TRUE(an_add != nullptr);
+ ASSERT_TRUE(an_add->fusedActivationFunction() == locoex::FusedActFunc::NONE);
+
+ ASSERT_TRUE(an_add->x() == a_conv2d or an_add->y() == a_conv2d);
+}
+
+// A case when
+// - TFLConv2D has NONE activation function
+// - TFLAdd has Relu activation function
+//
+// After fusion, TFLConv2D should have the Relu activation function and TFLAdd is fused into the bias input
+TEST(FuseBiasAddPassTest, Regression_Conv2D_Add_fused_action_01)
+{
+ exo::test::TestGraph g;
+ auto filter = g.append<locoex::TFLConst>();
+ auto bias = g.append<locoex::TFLConst>();
+ auto conv2d = g.append<locoex::TFLConv2D>(g.pull, filter, bias);
+
+ auto add_y = g.append<locoex::TFLConst>();
+ auto add = g.append<locoex::TFLAdd>(conv2d, add_y);
+
+ g.complete(add);
+
+ init(g.pull);
+ init(conv2d, filter, bias);
+ init(add, locoex::FusedActFunc::RELU);
+ init(add_y);
+
+ // let's run fusion
+ {
+ exo::test::TypeShapeReadyPhase test_phase;
+
+ test_phase.add_pass<exo::FuseBiasAddPass>();
+ test_phase.run(g.graph());
+ }
+
+ auto a_conv2d = exo::test::find_first_node_bytype<locoex::TFLConv2D>(g.graph());
+ ASSERT_TRUE(a_conv2d != nullptr);
+
+ auto a_bias = dynamic_cast<locoex::TFLConst *>(a_conv2d->bias());
+ ASSERT_TRUE(a_bias != nullptr);
+
+ ASSERT_TRUE(a_bias->dim(0) == 2);
+ ASSERT_FLOAT_EQ(a_bias->at<loco::DataType::FLOAT32>(0),
+ bias->at<loco::DataType::FLOAT32>(0) + add_y->at<loco::DataType::FLOAT32>(0));
+ ASSERT_FLOAT_EQ(a_bias->at<loco::DataType::FLOAT32>(1),
+ bias->at<loco::DataType::FLOAT32>(1) + add_y->at<loco::DataType::FLOAT32>(1));
+
+ ASSERT_TRUE(a_conv2d->fusedActivationFunction() == locoex::FusedActFunc::RELU);
+}
diff --git a/compiler/exo/src/Pass/FuseInstanceNormPass.cpp b/compiler/exo/src/Pass/FuseInstanceNormPass.cpp
new file mode 100644
index 000000000..04d4a62cd
--- /dev/null
+++ b/compiler/exo/src/Pass/FuseInstanceNormPass.cpp
@@ -0,0 +1,402 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "FuseInstanceNormPass.h"
+
+#include "Dialect/IR/TFLNodes.h"
+#include "Dialect/IR/CircleNodes.h"
+
+#include <loco/Service/ShapeInference.h>
+
+#include <cassert>
+#include <set>
+
+// Helper to find commutative node's arguments
+namespace
+{
+
+/**
+ * INTRODUCTION
+ * Binary operation f(x,y) is 'commutative' when
+ * f(x,y) == f(y,x) holds for all x, y.
+ *   For example, ADD, MUL and SQUARED_DIFFERENCE are commutative.
+ *   These helpers make it easy to find the arguments of a commutative node.
+ *
+ * HOW TO USE
+ * COMM_NODE *node;
+ * ARG_TYPE_1 *arg1;
+ * ARG_TYPE_2 *arg2;
+ *
+ * bool ok = fill(&arg1, &arg2).with_commutative_args_of(node);
+ *
+ * Result
+ * If 'node's commutative argument types are actually {ARG_TYPE_1, ARG_TYPE_2}
+ *    (as a set), 'arg1' and 'arg2' are set to 'node's actual arguments with the
+ *    matching types, and the return value 'ok' is true.
+ *    Otherwise, 'arg1' and 'arg2' are not changed and 'ok' is false.
+ */
+
+template <class ARG_TYPE_1, class ARG_TYPE_2> class NodeFiller final
+{
+public:
+ NodeFiller(ARG_TYPE_1 **arg_1, ARG_TYPE_2 **arg_2) : _arg_1(arg_1), _arg_2(arg_2)
+ {
+ // DO NOTHING
+ }
+
+ /**
+ * @return true When 'node's argument types are 'ARG_TYPE_1' and 'ARG_TYPE_2'
+   *              In such case, it assigns '_arg_1' and '_arg_2' to the actual arguments
+ *
+ * @return false When 'node's argument types are NOT matched with 'ARG_TYPE_*'
+ * In such case, it does not amend '_arg_1' and '_arg_2'
+ *
+ * @require COMM_NODE has member x() and y()
+ */
+ template <class COMM_NODE> bool with_commutative_args_of(const COMM_NODE *node);
+
+private:
+ ARG_TYPE_1 **_arg_1;
+ ARG_TYPE_2 **_arg_2;
+};
+
+template <class ARG_TYPE_1, class ARG_TYPE_2>
+inline NodeFiller<ARG_TYPE_1, ARG_TYPE_2> fill(ARG_TYPE_1 **arg_1, ARG_TYPE_2 **arg_2)
+{
+ return NodeFiller<ARG_TYPE_1, ARG_TYPE_2>{arg_1, arg_2};
+}
+
+template <class ARG_TYPE_1, class ARG_TYPE_2>
+template <class COMM_NODE>
+bool NodeFiller<ARG_TYPE_1, ARG_TYPE_2>::with_commutative_args_of(const COMM_NODE *node)
+{
+ // Case 1) X == ARG_TYPE_1 / Y == ARG_TYPE_2
+ {
+ auto x = dynamic_cast<ARG_TYPE_1 *>(node->x());
+ auto y = dynamic_cast<ARG_TYPE_2 *>(node->y());
+
+ if (x && y)
+ {
+ *_arg_1 = x;
+ *_arg_2 = y;
+ return true;
+ }
+ }
+
+ // Case 2) X == ARG_TYPE_2 / Y == ARG_TYPE_1
+ {
+ auto x = dynamic_cast<ARG_TYPE_2 *>(node->x());
+ auto y = dynamic_cast<ARG_TYPE_1 *>(node->y());
+
+ if (x && y)
+ {
+ *_arg_1 = y;
+ *_arg_2 = x;
+ return true;
+ }
+ }
+
+ return false;
+}
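+
+// Concrete example (mirrors how InstanceNormPattern::matched() uses it below):
+// given a TFLMul whose operands are a TFLRsqrt and a TFLConst in either order,
+//
+//   locoex::TFLRsqrt *rsqrt = nullptr;
+//   locoex::TFLConst *gamma = nullptr;
+//   bool ok = fill(&rsqrt, &gamma).with_commutative_args_of(mul);
+//
+// sets both pointers and returns true regardless of which side each operand is on.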
+
+} // namespace
+
+// Helper to check detail
+namespace
+{
+
+/// @return true When node has shape of '1 x .. x 1 x depth'
+bool is_1D_with_dummy_dim(locoex::TFLConst *node, uint32_t depth)
+{
+ auto rank = node->rank();
+ uint32_t axis;
+ for (axis = 0; axis < rank - 1; ++axis)
+ {
+ if (node->dim(axis).value() != 1)
+ return false;
+ }
+ return node->dim(axis).value() == depth;
+}
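+
+// e.g. with depth == 8, shapes [8], [1, 8] and [1, 1, 1, 8] all pass the check above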
+
+bool is_instance_mean(locoex::TFLMean *mean)
+{
+ //
+ // CHECK 1) input is rank 4
+ //
+ auto input = mean->input();
+ if (not loco::shape_known(input))
+ return false;
+ auto input_shape = loco::shape_get(input).as<loco::TensorShape>();
+ if (input_shape.rank() != 4)
+ return false;
+
+ //
+ // CHECK 2) 'reduction indices' is TFLConst of value [1,2], that is HW of NHWC
+ //
+ // TODO Support equivalent case, like [-3,-2]
+ // TODO Support non-Const case?
+ // TODO What if input is NCHW format in Circle?
+ auto red_indices = dynamic_cast<locoex::TFLConst *>(mean->reduction_indices());
+ if (not red_indices)
+ return false;
+ if (red_indices->rank() != 1)
+ return false;
+ std::set<int32_t> red_indices_set;
+ {
+ // TODO Currently only support S32, support other types
+ assert(red_indices->dtype() == loco::DataType::S32);
+ for (uint32_t i = 0; i < red_indices->dim(0).value(); ++i)
+ red_indices_set.insert(red_indices->at<loco::DataType::S32>(i));
+ }
+ if (red_indices_set.size() != 2)
+ return false;
+ if (red_indices_set.find(1) == red_indices_set.end())
+ return false;
+ if (red_indices_set.find(2) == red_indices_set.end())
+ return false;
+
+ //
+ // CHECK 3) keep_dims == true (?)
+ //
+  // We only have the 'keep_dims == true' case so far, but it might be okay with 'keep_dims == false'
+ // TODO Check this fact, and if true, return true regardless of keep_dims
+ return mean->keep_dims();
+}
+
+} // namespace
+
+// Helper to fuse Instance Norm
+namespace
+{
+
+/**
+ * SUBGRAPH PATTERN
+ *
+ * - The diagram below shows the Instance Norm pattern to fuse.
+ * - Execution dependency order is top to bottom.
+ * - Each node name matches a variable name of the InstanceNormPattern class.
+ * - Usually, the first word of a node name (variable name) is the node type, e.g.
+ *   variable 'mean_as_variance' is a pointer to TFLMean.
+ * - (An item in parentheses) actually exists, but has no name and is not
+ *   a variable of the InstanceNormPattern class.
+ *
+ * TODO support other semantically same patterns for instance norm
+ *
+ * [In]
+ * |
+ * V
+ * +----------- ifm -----+ (reduction indicies)
+ * | | | |
+ * | | V V
+ * | | mean_of_ifm ----------------+
+ * | V | |
+ * | sqdiff <--+ (reduction indicies) |
+ * | | | |
+ * | V | |
+ * | mean_as_variance <---+ const_as_epsilon |
+ * | | | |
+ * | V | |
+ * | add_as_variance <--------+ |
+ * | | |
+ * | V |
+ * | rsqrt const_as_gamma |
+ * | | | |
+ * | V | |
+ * | mul_gamma <--+ |
+ * | | | |
+ * V V V |
+ * mul_as_scaled_ifm mul_as_scaled_mean <-------------+
+ * | |
+ * | const_as_beta |
+ * | | V
+ * | +------> sub
+ * V |
+ * add_as_terminal <----------+
+ * |
+ * V
+ * [Out]
+ */
+class InstanceNormPattern final
+{
+public:
+ InstanceNormPattern(locoex::TFLAdd *candidate)
+ {
+ assert(candidate);
+ add_as_terminal = candidate;
+ }
+
+public:
+ bool matched();
+ bool matched() const { return _matched; }
+
+public:
+ // Context
+ loco::Node *ifm = nullptr;
+ locoex::TFLMean *mean_of_ifm = nullptr;
+ locoex::TFLSquaredDifference *sqdiff = nullptr;
+ locoex::TFLMean *mean_as_variance = nullptr;
+ locoex::TFLConst *const_as_epsilon = nullptr;
+ locoex::TFLAdd *add_as_variance = nullptr;
+ locoex::TFLRsqrt *rsqrt = nullptr;
+ locoex::TFLConst *const_as_gamma = nullptr;
+ locoex::TFLMul *mul_gamma = nullptr;
+ locoex::TFLMul *mul_as_scaled_ifm = nullptr;
+ locoex::TFLMul *mul_as_scaled_mean = nullptr;
+ locoex::TFLConst *const_as_beta = nullptr;
+ locoex::TFLSub *sub = nullptr;
+ locoex::TFLAdd *add_as_terminal = nullptr;
+
+private:
+ bool _matched = false;
+};
+
+bool InstanceNormPattern::matched()
+{
+ if (_matched)
+ return true;
+
+#define CHECK_OR_FALSE(condition) \
+ if (not(condition)) \
+ return false;
+
+ // Check order is DFS
+
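+  // NOTE fill(&a, &b).with_commutative_args_of(n) tries to bind n's two
+  //      arguments to 'a' and 'b' by type, in either order (the ops matched
+  //      here are commutative); it returns false when the types do not match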
+ CHECK_OR_FALSE(fill(&mul_as_scaled_ifm, &sub).with_commutative_args_of(add_as_terminal));
+ CHECK_OR_FALSE(fill(&ifm, &mul_gamma).with_commutative_args_of(mul_as_scaled_ifm));
+
+ CHECK_OR_FALSE(loco::shape_known(ifm));
+ auto ifm_shape = loco::shape_get(ifm);
+ CHECK_OR_FALSE(ifm_shape.domain() == loco::Domain::Tensor);
+ auto ifm_tensor_shape = ifm_shape.as<loco::TensorShape>();
+ CHECK_OR_FALSE(ifm_tensor_shape.rank() == 4);
+ uint32_t ifm_channel_depth = ifm_tensor_shape.dim(3).value();
+
+ CHECK_OR_FALSE(fill(&rsqrt, &const_as_gamma).with_commutative_args_of(mul_gamma));
+ CHECK_OR_FALSE(is_1D_with_dummy_dim(const_as_gamma, ifm_channel_depth));
+
+ add_as_variance = dynamic_cast<locoex::TFLAdd *>(rsqrt->x());
+ CHECK_OR_FALSE(add_as_variance);
+
+ CHECK_OR_FALSE(
+ fill(&mean_as_variance, &const_as_epsilon).with_commutative_args_of(add_as_variance));
+
+ CHECK_OR_FALSE(const_as_epsilon->dtype() == loco::DataType::FLOAT32);
+ // TODO Support regarding broadcast
+ CHECK_OR_FALSE(const_as_epsilon->size<loco::DataType::FLOAT32>() == 1);
+
+ CHECK_OR_FALSE(is_instance_mean(mean_as_variance));
+ sqdiff = dynamic_cast<locoex::TFLSquaredDifference *>(mean_as_variance->input());
+ CHECK_OR_FALSE(sqdiff);
+
+ loco::Node *ifm_should_be = nullptr;
+ CHECK_OR_FALSE(fill(&ifm_should_be, &mean_of_ifm).with_commutative_args_of(sqdiff));
+ CHECK_OR_FALSE(ifm == ifm_should_be);
+ CHECK_OR_FALSE(is_instance_mean(mean_of_ifm));
+ CHECK_OR_FALSE(ifm == mean_of_ifm->input());
+
+ const_as_beta = dynamic_cast<locoex::TFLConst *>(sub->x());
+ CHECK_OR_FALSE(const_as_beta);
+ CHECK_OR_FALSE(is_1D_with_dummy_dim(const_as_beta, ifm_channel_depth));
+
+ mul_as_scaled_mean = dynamic_cast<locoex::TFLMul *>(sub->y());
+ CHECK_OR_FALSE(mul_as_scaled_mean);
+
+ locoex::TFLMul *mul_gamma_should_be = nullptr;
+ locoex::TFLMean *mean_of_ifm_should_be = nullptr;
+ CHECK_OR_FALSE(fill(&mul_gamma_should_be, &mean_of_ifm_should_be)
+ .with_commutative_args_of(mul_as_scaled_mean));
+ CHECK_OR_FALSE(mul_gamma == mul_gamma_should_be);
+ CHECK_OR_FALSE(mean_of_ifm == mean_of_ifm_should_be);
+#undef CHECK_OR_FALSE
+ _matched = true;
+ return true;
+}
+
+/**
+ * The instance norm pattern is fused as in the following diagram:
+ *
+ *   [In] ---------------------------- CircleInstanceNorm --- [Out]
+ *                                      /         /
+ *   const_as_gamma --- TFLReshape ----+         /
+ *                                              /
+ *   const_as_beta ---- TFLReshape ------------+
+ *
+ * Note
+ * - 'const_as_gamma' and 'const_as_beta' are from the original graph
+ * - The value of 'const_as_epsilon' is copied into CircleInstanceNorm's attribute
+ * - TFLReshape is added because CircleInstanceNorm only accepts 1D tensors
+ * - 'TFLConst --- TFLReshape' is expected to be fused by constant folding for Reshape
+ */
+void fuse_instance_norm(const InstanceNormPattern &p)
+{
+ assert(p.matched());
+
+ auto graph = p.add_as_terminal->graph();
+
+ // Make reshape for gamma & beta
+ auto reshape_gamma = graph->nodes()->create<locoex::TFLReshape>();
+ auto reshape_beta = graph->nodes()->create<locoex::TFLReshape>();
+ {
+ auto ifm_shape = loco::shape_get(p.ifm).as<loco::TensorShape>();
+ uint32_t ifm_channel_depth = ifm_shape.dim(3).value();
+
+ int32_t new_shape[1] = {static_cast<int32_t>(ifm_channel_depth)};
+
+ reshape_gamma->tensor(p.const_as_gamma);
+ reshape_beta->tensor(p.const_as_beta);
+
+ locoex::set_new_shape(reshape_gamma, new_shape, 1);
+ locoex::set_new_shape(reshape_beta, new_shape, 1);
+ }
+
+ // Make Instance Norm to replace
+ auto instance_norm = graph->nodes()->create<locoex::CircleInstanceNorm>();
+ instance_norm->input(p.ifm);
+ instance_norm->gamma(reshape_gamma);
+ instance_norm->beta(reshape_beta);
+ float epsilon = p.const_as_epsilon->at<loco::DataType::FLOAT32>(0);
+ instance_norm->epsilon(epsilon);
+ instance_norm->fusedActivationFunction(p.add_as_terminal->fusedActivationFunction());
+
+ replace(p.add_as_terminal).with(instance_norm);
+}
+
+} // namespace
+
+namespace exo
+{
+
+bool FuseInstanceNormPass::run(loco::Graph *g)
+{
+ bool changed = false;
+ for (auto node : loco::active_nodes(loco::output_nodes(g)))
+ {
+ auto add = dynamic_cast<locoex::TFLAdd *>(node);
+ if (not add)
+ continue;
+
+ InstanceNormPattern pattern(add);
+ if (not pattern.matched())
+ continue;
+
+ fuse_instance_norm(pattern);
+ changed = true;
+ }
+
+ return changed;
+}
+
+} // namespace exo
diff --git a/compiler/exo/src/Pass/FuseInstanceNormPass.h b/compiler/exo/src/Pass/FuseInstanceNormPass.h
new file mode 100644
index 000000000..e6361021c
--- /dev/null
+++ b/compiler/exo/src/Pass/FuseInstanceNormPass.h
@@ -0,0 +1,40 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __FUSE_INSTANCE_NORM_PASS_H__
+#define __FUSE_INSTANCE_NORM_PASS_H__
+
+#include <logo/Pass.h>
+
+namespace exo
+{
+
+/**
+ * @brief Class to fuse a certain subgraph pattern into CircleInstanceNorm,
+ *        together with auxiliary nodes
+ *
+ * For detailed subgraph pattern to be fused, please check its implementation.
+ */
+struct FuseInstanceNormPass final : public logo::Pass
+{
+ const char *name(void) const final { return "exo::FuseInstanceNormPass"; }
+
+ bool run(loco::Graph *g) final;
+};
+
+} // namespace exo
+
+#endif // __FUSE_INSTANCE_NORM_PASS_H__
diff --git a/compiler/exo/src/Pass/FuseReluPass.cpp b/compiler/exo/src/Pass/FuseReluPass.cpp
new file mode 100644
index 000000000..d7af0c506
--- /dev/null
+++ b/compiler/exo/src/Pass/FuseReluPass.cpp
@@ -0,0 +1,115 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "FuseReluPass.h"
+
+#include "Dialect/IR/TFLNodes.h"
+#include "Dialect/IR/TFLDialect.h"
+#include "Dialect/IR/TFLNodeVisitor.h"
+
+#include <cassert>
+#include <set>
+
+namespace
+{
+
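+/// @brief Check whether 'node' can absorb a following ReLU/ReLU6, i.e., it
+///        supports a fused activation function and currently has NONE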
+bool is_pred_fusable(loco::Node *node)
+{
+ using namespace locoex;
+
+ auto fusable_node = dynamic_cast<TFLNodeMixin<TFLNodeTrait::FusedActFunc> *>(node);
+
+ return (fusable_node and fusable_node->fusedActivationFunction() == FusedActFunc::NONE);
+}
+
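+/// @brief Gather TFLRelu/TFLRelu6 nodes whose producer can absorb the activation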
+struct Collector final : public locoex::TFLNodeMutableVisitor<void>
+{
+ void visit(locoex::TFLRelu *node) final
+ {
+ if (is_pred_fusable(node->features()))
+ candidates.insert(node);
+ }
+
+ void visit(locoex::TFLRelu6 *node) final
+ {
+ if (is_pred_fusable(node->features()))
+ candidates.insert(node);
+ }
+
+ void visit(locoex::TFLNode *) final { return; }
+
+ std::set<locoex::TFLNode *> candidates;
+};
+
+void set_activation_fusion(loco::Node *node, locoex::FusedActFunc f)
+{
+ using namespace locoex;
+
+ if (auto fusable_node = dynamic_cast<TFLNodeMixin<TFLNodeTrait::FusedActFunc> *>(node))
+ fusable_node->fusedActivationFunction(f);
+ else
+ assert(false);
+}
+
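+/// @brief For each candidate, set the fused activation on the producer node
+///        and bypass the now-redundant ReLU/ReLU6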
+struct Performer final : public locoex::TFLNodeMutableVisitor<void>
+{
+ void visit(locoex::TFLRelu *the_relu) final
+ {
+ set_activation_fusion(the_relu->features(), locoex::FusedActFunc::RELU);
+
+ loco::replace(the_relu).with(the_relu->features());
+ the_relu->features(nullptr);
+ }
+
+ void visit(locoex::TFLRelu6 *the_relu6) final
+ {
+ set_activation_fusion(the_relu6->features(), locoex::FusedActFunc::RELU6);
+
+ loco::replace(the_relu6).with(the_relu6->features());
+ the_relu6->features(nullptr);
+ }
+
+ void visit(locoex::TFLNode *) final { assert(false && "should not be called"); }
+};
+
+} // namespace
+
+namespace exo
+{
+
+bool FuseReluPass::run(loco::Graph *g)
+{
+ Collector collector;
+
+ for (auto node : loco::active_nodes(loco::output_nodes(g)))
+ {
+ if (node->dialect() == locoex::TFLDialect::get())
+ {
+      auto tfl_node = dynamic_cast<locoex::TFLNode *>(node);
+      assert(tfl_node != nullptr); // guaranteed by the dialect check above
+      tfl_node->accept(&collector);
+ }
+ }
+
+ Performer performer;
+
+ for (auto node : collector.candidates)
+ {
+ node->accept(&performer);
+ }
+
+ return collector.candidates.size() > 0;
+}
+
+} // namespace exo
diff --git a/compiler/exo/src/Pass/FuseReluPass.h b/compiler/exo/src/Pass/FuseReluPass.h
new file mode 100644
index 000000000..1cd276b29
--- /dev/null
+++ b/compiler/exo/src/Pass/FuseReluPass.h
@@ -0,0 +1,40 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __PASS_FUSE_RELU_PASS_H__
+#define __PASS_FUSE_RELU_PASS_H__
+
+#include <logo/Pass.h>
+
+namespace exo
+{
+
+/**
+ * @brief Class to fuse TFLRelu or TFLRelu6 into the TensorFlow Lite ops below:
+ *
+ * ADD, AVERAGE_POOL_2D, CONCATENATION, CONV_2D, DEPTHWISE_CONV_2D,
+ * FULLY_CONNECTED, L2_NORMALIZATION, L2_POOL_2D, MAX_POOL_2D, MUL
+ */
+struct FuseReluPass final : public logo::Pass
+{
+ const char *name(void) const final { return "exo::FuseReluPass"; }
+
+ bool run(loco::Graph *g) final;
+};
+
+} // namespace exo
+
+#endif // __PASS_FUSE_RELU_PASS_H__
diff --git a/compiler/exo/src/Pass/FuseReluPass.test.cpp b/compiler/exo/src/Pass/FuseReluPass.test.cpp
new file mode 100644
index 000000000..6f83d4dd0
--- /dev/null
+++ b/compiler/exo/src/Pass/FuseReluPass.test.cpp
@@ -0,0 +1,115 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "FuseReluPass.h"
+
+#include "Dialect/IR/TFLNodes.h"
+#include "TestGraph.h"
+
+#include <loco.h>
+#include <logo/RemoveDeadNodePass.h>
+
+#include <gtest/gtest.h>
+
+#include <type_traits> // for std::is_same
+
+namespace
+{
+
+void init(loco::Pull *pull)
+{
+ pull->dtype(loco::DataType::FLOAT32);
+ pull->shape({2, 3, 3, 2});
+}
+
+/// @brief Initializes TFLConv2D and related filter and bias
+void init(locoex::TFLConv2D *conv2d, locoex::TFLConst *filter, locoex::TFLConst *bias)
+{
+ // set conv2d
+ {
+ conv2d->fusedActivationFunction(locoex::FusedActFunc::NONE);
+ conv2d->padding(locoex::Padding::VALID);
+ }
+
+ // set filter
+ {
+ filter->dtype(loco::DataType::FLOAT32);
+ filter->shape({2, 3, 3, 2});
+ filter->size<loco::DataType::FLOAT32>(2 * 3 * 3 * 2);
+
+ for (uint32_t x = 0; x < 2 * 3 * 3 * 2; x++)
+ filter->at<loco::DataType::FLOAT32>(x) = 0.0;
+ }
+
+ // set bias
+ {
+ bias->dtype(loco::DataType::FLOAT32);
+ bias->shape({2});
+ bias->size<loco::DataType::FLOAT32>(2);
+
+ for (uint32_t x = 0; x < 2; x++)
+ bias->at<loco::DataType::FLOAT32>(x) = 0.0;
+ }
+}
+
+} // namespace
+
+/// Test code called by TEST(..)
+/// This tests whether Conv2D - FusedTFLType is fused.
+template <class FusedTFLType, locoex::FusedActFunc FusedActFunc> void test()
+{
+ static_assert((std::is_same<FusedTFLType, locoex::TFLRelu>::value &&
+ FusedActFunc == locoex::FusedActFunc::RELU) ||
+ (std::is_same<FusedTFLType, locoex::TFLRelu6>::value &&
+ FusedActFunc == locoex::FusedActFunc::RELU6),
+ "wrong template type");
+
+ exo::test::TestGraph g;
+ {
+ auto filter = g.append<locoex::TFLConst>();
+ auto bias = g.append<locoex::TFLConst>();
+ auto conv2d = g.append<locoex::TFLConv2D>(g.pull, filter, bias);
+
+ auto fusable_node = g.append<FusedTFLType>(conv2d);
+
+ g.complete(fusable_node);
+
+ init(g.pull);
+ init(conv2d, filter, bias);
+ }
+
+ // let's run fusion
+ {
+ exo::test::TypeShapeReadyPhase test_phase;
+
+ test_phase.add_pass<exo::FuseReluPass>();
+ test_phase.add_pass<logo::RemoveDeadNodePass>(); // to remove TFLRelu
+ test_phase.run(g.graph());
+ }
+
+ auto a_conv2d = exo::test::find_first_node_bytype<locoex::TFLConv2D>(g.graph());
+ ASSERT_TRUE(a_conv2d != nullptr);
+ ASSERT_TRUE(a_conv2d->fusedActivationFunction() == FusedActFunc);
+
+ auto removed_fusable_node = exo::test::find_first_node_bytype<FusedTFLType>(g.graph());
+ ASSERT_TRUE(removed_fusable_node == nullptr);
+}
+
+// A case with Conv2D-Relu
+TEST(FuseReluTest, Conv2D_Relu_basic) { test<locoex::TFLRelu, locoex::FusedActFunc::RELU>(); }
+
+// A case with Conv2D-Relu6
+TEST(FuseReluTest, Conv2D_Relu6_basic) { test<locoex::TFLRelu6, locoex::FusedActFunc::RELU6>(); }
diff --git a/compiler/exo/src/Pass/FuseRsqrtPass.cpp b/compiler/exo/src/Pass/FuseRsqrtPass.cpp
new file mode 100644
index 000000000..08d704139
--- /dev/null
+++ b/compiler/exo/src/Pass/FuseRsqrtPass.cpp
@@ -0,0 +1,95 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "FuseRsqrtPass.h"
+
+#include "Check.h"
+
+#include "Dialect/IR/TFLNodes.h"
+
+namespace
+{
+
+/**
+ * @return Casted TFLDiv for fusable candidate, nullptr otherwise
+ *
+ * This helper checks fusability with the following conditions:
+ * - TFLDiv has no fused activation function
+ * - TFLDiv's first argument is a TFLConst whose elements are all 1
+ * - TFLDiv's second argument is TFLSqrt
+ */
+locoex::TFLDiv *as_candidate(loco::Node *node)
+{
+ auto div = dynamic_cast<locoex::TFLDiv *>(node);
+ if (not div)
+ return nullptr;
+
+ // Cannot fuse Div with activation function
+ if (div->fusedActivationFunction() != locoex::FusedActFunc::NONE)
+ return nullptr;
+
+ auto const_one = dynamic_cast<locoex::TFLConst *>(div->x());
+ if (not const_one)
+ return nullptr;
+
+ const loco::DataType FLOAT32 = loco::DataType::FLOAT32;
+ // TODO Support other dtype
+ EXO_ASSERT(const_one->dtype() == FLOAT32, "Only support FLOAT32 now");
+ for (uint32_t i = 0; i < const_one->size<FLOAT32>(); ++i)
+ if (const_one->at<FLOAT32>(i) != 1.0f)
+ return nullptr;
+
+ auto sqrt = dynamic_cast<locoex::TFLSqrt *>(div->y());
+ if (not sqrt)
+ return nullptr;
+
+ return div;
+}
+
+void fuse_rsqrt(locoex::TFLDiv *div)
+{
+ auto sqrt = dynamic_cast<locoex::TFLSqrt *>(div->y());
+ EXO_ASSERT(sqrt, "sqrt should be valid at this point");
+
+ // TFLRsqrt to replace
+ auto rsqrt = div->graph()->nodes()->create<locoex::TFLRsqrt>();
+ rsqrt->x(sqrt->x());
+
+ // replace
+ loco::replace(div).with(rsqrt);
+}
+
+} // namespace
+
+namespace exo
+{
+
+bool FuseRsqrtPass::run(loco::Graph *g)
+{
+ bool changed = false;
+ for (auto node : loco::active_nodes(loco::output_nodes(g)))
+ {
+ if (auto div = as_candidate(node))
+ {
+ fuse_rsqrt(div);
+ changed = true;
+ }
+ }
+
+ return changed;
+}
+
+} // namespace exo
diff --git a/compiler/exo/src/Pass/FuseRsqrtPass.h b/compiler/exo/src/Pass/FuseRsqrtPass.h
new file mode 100644
index 000000000..1e60e4a49
--- /dev/null
+++ b/compiler/exo/src/Pass/FuseRsqrtPass.h
@@ -0,0 +1,47 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __FUSE_RSQRT_PASS_H__
+#define __FUSE_RSQRT_PASS_H__
+
+#include <logo/Pass.h>
+
+namespace exo
+{
+
+/**
+ * @brief Class to fuse the '1 / TFLSqrt' pattern (a TFLDiv with constant-1
+ *        numerator) into TFLRsqrt
+ *
+ * <BEFORE>
+ *
+ *    TFLConst(1) ------
+ *                      \
+ *    A --- TFLSqrt --- TFLDiv --- B
+ *
+ * <AFTER>
+ *
+ *    A --- TFLRsqrt --- B
+ */
+struct FuseRsqrtPass final : public logo::Pass
+{
+ const char *name(void) const final { return "exo::FuseRsqrtPass"; }
+
+ bool run(loco::Graph *g) final;
+};
+
+} // namespace exo
+
+#endif // __FUSE_RSQRT_PASS_H__
diff --git a/compiler/exo/src/Pass/FuseSquaredDifferencePass.cpp b/compiler/exo/src/Pass/FuseSquaredDifferencePass.cpp
new file mode 100644
index 000000000..3f985a505
--- /dev/null
+++ b/compiler/exo/src/Pass/FuseSquaredDifferencePass.cpp
@@ -0,0 +1,86 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "FuseSquaredDifferencePass.h"
+
+#include "Check.h"
+
+#include "Dialect/IR/TFLNodes.h"
+
+namespace
+{
+
+/**
+ * @return Casted TFLMul for fusable candidate, nullptr otherwise
+ *
+ * This helper checks fusability with the following conditions:
+ * - TFLMul has no fused activation function
+ * - TFLMul's first and second arguments are the same node, and that node is a TFLSub
+ */
+locoex::TFLMul *as_candidate(loco::Node *node)
+{
+ auto mul = dynamic_cast<locoex::TFLMul *>(node);
+ if (not mul)
+ return nullptr;
+
+ // Cannot fuse mul with activation function
+ if (mul->fusedActivationFunction() != locoex::FusedActFunc::NONE)
+ return nullptr;
+
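+  // Both operands must be the very same node (pointer identity),
+  // i.e. 'mul' computes the square of its input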
+ if (mul->x() != mul->y())
+ return nullptr;
+
+ if (not dynamic_cast<locoex::TFLSub *>(mul->x()))
+ return nullptr;
+
+ return mul;
+}
+
+void fuse_squared_difference(locoex::TFLMul *mul)
+{
+ auto sub = dynamic_cast<locoex::TFLSub *>(mul->x());
+ EXO_ASSERT(sub, "sub should be valid at this point");
+
+ // TFLSquaredDifference to replace
+ auto sq_diff = mul->graph()->nodes()->create<locoex::TFLSquaredDifference>();
+ sq_diff->x(sub->x());
+ sq_diff->y(sub->y());
+
+ // replace
+ loco::replace(mul).with(sq_diff);
+}
+
+} // namespace
+
+namespace exo
+{
+
+bool FuseSquaredDifferencePass::run(loco::Graph *g)
+{
+ bool changed = false;
+ for (auto node : loco::active_nodes(loco::output_nodes(g)))
+ {
+ if (auto mul = as_candidate(node))
+ {
+ fuse_squared_difference(mul);
+ changed = true;
+ }
+ }
+
+ return changed;
+}
+
+} // namespace exo
diff --git a/compiler/exo/src/Pass/FuseSquaredDifferencePass.h b/compiler/exo/src/Pass/FuseSquaredDifferencePass.h
new file mode 100644
index 000000000..dbc15149f
--- /dev/null
+++ b/compiler/exo/src/Pass/FuseSquaredDifferencePass.h
@@ -0,0 +1,49 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __FUSE_SQUARED_DIFFERENCE_PASS_H__
+#define __FUSE_SQUARED_DIFFERENCE_PASS_H__
+
+#include <logo/Pass.h>
+
+namespace exo
+{
+
+/**
+ * @brief Class to fuse SquaredDifference pattern
+ *
+ * <BEFORE>
+ *
+ *    A --- TFLSub --- TFLMul --- C
+ *        /       \       /
+ *    B --         -------
+ *
+ * <AFTER>
+ *
+ *    A --- TFLSquaredDifference --- C
+ *        /
+ *    B --
+ */
+struct FuseSquaredDifferencePass final : public logo::Pass
+{
+ const char *name(void) const final { return "exo::FuseSquaredDifferencePass"; }
+
+ bool run(loco::Graph *g) final;
+};
+
+} // namespace exo
+
+#endif // __FUSE_SQUARED_DIFFERENCE_PASS_H__
diff --git a/compiler/exo/src/Pass/MergeConcatNodesPass.cpp b/compiler/exo/src/Pass/MergeConcatNodesPass.cpp
new file mode 100644
index 000000000..8945fcfce
--- /dev/null
+++ b/compiler/exo/src/Pass/MergeConcatNodesPass.cpp
@@ -0,0 +1,191 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "MergeConcatNodesPass.h"
+#include "Dialect/IR/TFLNodes.h"
+
+#include <oops/InternalExn.h>
+
+#include <vector>
+
+namespace
+{
+
+bool canMerge(locoex::TFLConcatenation *node1, locoex::TFLConcatenation *node2)
+{
+ if (node1->fusedActivationFunction() != node2->fusedActivationFunction())
+ return false;
+
+ if (node1->axis() != node2->axis())
+ return false;
+
+ switch (node1->fusedActivationFunction())
+ {
+ case locoex::FusedActFunc::NONE:
+ case locoex::FusedActFunc::RELU:
+ case locoex::FusedActFunc::RELU6:
+ return true;
+
+ // case locoex::FusedActFunc::TANH:
+ // return false;
+
+ default:
+ INTERNAL_EXN_V("Unknown FusedActFunc", oops::to_uint32(node1->fusedActivationFunction()));
+ }
+}
+
+/**
+ * @brief Collect the inputs for the newly created TFLConcatenation node
+ *
+ *   in:0 -------------------------------\
+ *   in:1 ---- TFLConcatenation:0 -------- TFLConcatenation:3 --- C
+ *            (axis = 0, NONE)            (axis = 0, NONE)
+ *   in:2 ---/                            /
+ *   in:3 ---- TFLConcatenation:1 ------/
+ *            (axis = 1, NONE)         /
+ *   in:4 ---/                        /
+ *   in:5 ---- TFLConcatenation:2 ---/
+ *            (axis = 0, RELU)
+ *   in:6 ---/
+ *
+ * For example, if the graph is as above, dfs(TFLConcatenation:3) will
+ * return [in:0, in:1, in:2, TFLConcatenation:1, TFLConcatenation:2].
+ *
+ * TFLConcatenation:0 can be merged into TFLConcatenation:3,
+ * because their axis and fusedActivationFunction are the same.
+ * It means that [in:1, in:2] will be linked as inputs of the new TFLConcatenation.
+ *
+ * However, TFLConcatenation:1 and TFLConcatenation:2 cannot be merged into
+ * TFLConcatenation:3 because the axis and fusedActivationFunction of each differ.
+ * So [in:3, in:4, in:5, in:6] will not be linked as inputs of the new
+ * TFLConcatenation; [TFLConcatenation:1, TFLConcatenation:2] will be linked instead.
+ *
+ * Therefore, the inputs of the newly created TFLConcatenation node that merges
+ * TFLConcatenation:3 will be [in:0, in:1, in:2, TFLConcatenation:1, TFLConcatenation:2],
+ * and dfs(TFLConcatenation:3) will return it.
+ *
+ * @note The input nodes should be traversed left to right
+ *       (input:0 --> input:N), which dfs() does
+ */
+std::vector<loco::Node *> dfs(locoex::TFLConcatenation *root)
+{
+ std::vector<loco::Node *> res;
+
+ for (uint32_t i = 0; i < root->numValues(); ++i)
+ {
+ auto input = dynamic_cast<locoex::TFLConcatenation *>(root->values(i));
+ if (input != nullptr && canMerge(input, root))
+ {
+ auto children = dfs(input);
+ for (auto child : children)
+ res.push_back(child);
+ }
+ else
+ {
+ res.push_back(root->values(i));
+ }
+ }
+
+ return res;
+}
+
+} // namespace
+
+namespace exo
+{
+
+/**
+ * @brief Merge TFLConcatenation nodes whose axis and fusedActivationFunction are the same
+ *
+ * [Before]
+ *   in:0 -------------------------------\
+ *   in:1 ---- TFLConcatenation:0 -------- TFLConcatenation:3 --- C
+ *            (axis = 0, NONE)            (axis = 0, NONE)
+ *   in:2 ---/                            /
+ *   in:3 ---- TFLConcatenation:1 ------/
+ *            (axis = 1, NONE)         /
+ *   in:4 ---/                        /
+ *   in:5 ---- TFLConcatenation:2 ---/
+ *            (axis = 0, RELU)
+ *   in:6 ---/
+ *
+ * [After]
+ *   in:0 -------------------------------\
+ *   in:1 -------------------------------- TFLConcatenation:4 --- C
+ *                                         (axis = 0, NONE)
+ *   in:2 -------------------------------/
+ *   in:3 ---- TFLConcatenation:1 ------/
+ *            (axis = 1, NONE)         /
+ *   in:4 ---/                        /
+ *   in:5 ---- TFLConcatenation:2 ---/
+ *            (axis = 0, RELU)
+ *   in:6 ---/
+ *
+ * TFLConcatenation:0 and TFLConcatenation:3 become dead nodes
+ * (no longer reachable from the graph outputs):
+ *
+ *   in:1 ---- TFLConcatenation:0 ----
+ *            (axis = 0, NONE)
+ *   in:2 ---/
+ *
+ *        ---- TFLConcatenation:3 ----
+ *            (axis = 0, NONE)
+ */
+bool MergeConcatNodesPass::run(loco::Graph *graph)
+{
+ // Let's enumerate nodes required to compute output nodes
+ auto active_nodes = loco::active_nodes(loco::output_nodes(graph));
+
+ // Find TFLConcatenation nodes which have another TFLConcatenation nodes
+ // as inputs, with same axis and same fusedActivationFunction
+ std::vector<locoex::TFLConcatenation *> candidates;
+ for (auto node : active_nodes)
+ {
+ if (auto concat = dynamic_cast<locoex::TFLConcatenation *>(node))
+ {
+ for (uint32_t i = 0; i < concat->numValues(); ++i)
+ {
+ auto input = dynamic_cast<locoex::TFLConcatenation *>(concat->values(i));
+ if (input != nullptr && canMerge(input, concat))
+ {
+ candidates.push_back(concat);
+ break;
+ }
+ }
+ }
+ }
+
+ // Merge multiple TFLConcatenation nodes as one TFLConcatenation node
+ for (auto node : candidates)
+ {
+ auto inputs = dfs(node);
+
+ auto new_concat = graph->nodes()->create<locoex::TFLConcatenation>(inputs.size());
+ new_concat->axis(node->axis());
+ new_concat->fusedActivationFunction(node->fusedActivationFunction());
+
+ for (uint32_t i = 0; i < inputs.size(); ++i)
+ new_concat->values(i, inputs.at(i));
+
+ loco::replace(node).with(new_concat);
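+    // Drop the old node's inputs so that it and any merged inner concats become dead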
+ for (uint32_t i = 0; i < node->numValues(); ++i)
+ node->values(i, nullptr);
+ }
+
+ return candidates.size() > 0;
+}
+
+} // namespace exo
diff --git a/compiler/exo/src/Pass/MergeConcatNodesPass.h b/compiler/exo/src/Pass/MergeConcatNodesPass.h
new file mode 100644
index 000000000..823214f43
--- /dev/null
+++ b/compiler/exo/src/Pass/MergeConcatNodesPass.h
@@ -0,0 +1,41 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __PASS_MERGE_CONCAT_NODES_H__
+#define __PASS_MERGE_CONCAT_NODES_H__
+
+#include <loco.h>
+#include <logo/Pass.h>
+
+namespace exo
+{
+
+/**
+ * @brief Merge concat nodes whose axis and fusedActivationFunction are the same
+ */
+class MergeConcatNodesPass : public logo::Pass
+{
+public:
+ virtual const char *name(void) const { return "exo::MergeConcatNodesPass"; }
+
+public:
+ bool run(loco::Graph *graph);
+};
+
+} // namespace exo
+
+#endif // __PASS_MERGE_CONCAT_NODES_H__
diff --git a/compiler/exo/src/Pass/ShapeInferencePass.cpp b/compiler/exo/src/Pass/ShapeInferencePass.cpp
new file mode 100644
index 000000000..bc60f91c4
--- /dev/null
+++ b/compiler/exo/src/Pass/ShapeInferencePass.cpp
@@ -0,0 +1,59 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ShapeInferencePass.h"
+
+#include "Dialect/IR/TFLDialect.h"
+#include "Dialect/Service/TFLShapeInferenceRule.h"
+
+#include "Dialect/IR/CircleDialect.h"
+#include "Dialect/Service/CircleShapeInferenceRule.h"
+
+#include <loco.h>
+#include <loco/IR/CanonicalDialect.h>
+#include <loco/Service/CanonicalShapeInferenceRule.h>
+#include <loco/Service/ShapeInference.h>
+#include <loco/Service/MultiDialectShapeInferenceRule.h>
+
+#include <locoex/COpDialect.h>
+#include <locoex/Service/COpShapeInferenceRule.h>
+
+namespace exo
+{
+
+/**
+ * @note  Currently, the TFL and Circle backends share this inference. However,
+ *        the TFL backend does not require the rule for the Circle dialect.
+ * TODO Make dedicated inference pass for Circle Dialect.
+ */
+bool ShapeInferencePass::run(loco::Graph *g)
+{
+ loco::CanonicalShapeInferenceRule canonical_rule;
+ locoex::TFLShapeInferenceRule tfl_rule;
+ locoex::CircleShapeInferenceRule circle_rule;
+ locoex::COpShapeInferenceRule cop_rule;
+
+ loco::MultiDialectShapeInferenceRule rules;
+
+ rules.bind(loco::CanonicalDialect::get(), &canonical_rule)
+ .bind(locoex::TFLDialect::get(), &tfl_rule)
+ .bind(locoex::CircleDialect::get(), &circle_rule)
+ .bind(locoex::COpDialect::get(), &cop_rule);
+
+ return loco::apply(&rules).to(g);
+}
+
+} // namespace exo
diff --git a/compiler/exo/src/Pass/ShapeInferencePass.h b/compiler/exo/src/Pass/ShapeInferencePass.h
new file mode 100644
index 000000000..518c87403
--- /dev/null
+++ b/compiler/exo/src/Pass/ShapeInferencePass.h
@@ -0,0 +1,40 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __PASS_SHAPE_INFERENCE_PASS_H__
+#define __PASS_SHAPE_INFERENCE_PASS_H__
+
+#include <loco.h>
+#include <logo/Pass.h>
+
+namespace exo
+{
+
+/**
+ * @brief Pass to infer shape of nodes
+ */
+class ShapeInferencePass : public logo::Pass
+{
+public:
+ virtual const char *name(void) const { return "exo::ShapeInferencePass"; }
+
+public:
+ bool run(loco::Graph *graph);
+};
+
+} // namespace exo
+
+#endif //__PASS_SHAPE_INFERENCE_PASS_H__
diff --git a/compiler/exo/src/Pass/TypeInferencePass.cpp b/compiler/exo/src/Pass/TypeInferencePass.cpp
new file mode 100644
index 000000000..31d4f13b6
--- /dev/null
+++ b/compiler/exo/src/Pass/TypeInferencePass.cpp
@@ -0,0 +1,57 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "TypeInferencePass.h"
+
+#include "Dialect/IR/TFLDialect.h"
+#include "Dialect/Service/TFLTypeInferenceRule.h"
+
+#include "Dialect/IR/CircleDialect.h"
+#include "Dialect/Service/CircleTypeInferenceRule.h"
+
+#include <loco.h>
+#include <loco/IR/CanonicalDialect.h>
+#include <loco/Service/TypeInference.h>
+
+#include <locoex/COpDialect.h>
+#include <locoex/Service/COpTypeInference.h>
+
+namespace exo
+{
+
+/**
+ * @note  Currently, the TFL and Circle backends share this inference. However,
+ *        the TFL backend does not require the rule for the Circle dialect.
+ * TODO Make dedicated inference pass for Circle Dialect.
+ */
+bool TypeInferencePass::run(loco::Graph *g)
+{
+ loco::CanonicalTypeInferenceRule canonical_rule;
+ locoex::TFLTypeInferenceRule tfl_rule;
+ locoex::CircleTypeInferenceRule circle_rule;
+ locoex::COpTypeInferenceRule cop_rule;
+
+ loco::MultiDialectTypeInferenceRule rules;
+
+ rules.bind(loco::CanonicalDialect::get(), &canonical_rule)
+ .bind(locoex::TFLDialect::get(), &tfl_rule)
+ .bind(locoex::CircleDialect::get(), &circle_rule)
+ .bind(locoex::COpDialect::get(), &cop_rule);
+
+ return loco::apply(&rules).to(g);
+}
+
+} // namespace exo
diff --git a/compiler/exo/src/Pass/TypeInferencePass.h b/compiler/exo/src/Pass/TypeInferencePass.h
new file mode 100644
index 000000000..3ede587a0
--- /dev/null
+++ b/compiler/exo/src/Pass/TypeInferencePass.h
@@ -0,0 +1,42 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __PASS_TYPE_INFERENCE_PASS_H__
+#define __PASS_TYPE_INFERENCE_PASS_H__
+
+#include <loco.h>
+
+#include <logo/Pass.h>
+
+namespace exo
+{
+
+/**
+ * @brief Pass to infer type of nodes
+ */
+class TypeInferencePass : public logo::Pass
+{
+public:
+ virtual const char *name(void) const { return "exo::TypeInferencePass"; }
+
+public:
+ bool run(loco::Graph *graph);
+};
+
+} // namespace exo
+
+#endif //__PASS_TYPE_INFERENCE_PASS_H__
diff --git a/compiler/exo/src/Passes.cpp b/compiler/exo/src/Passes.cpp
new file mode 100644
index 000000000..99d229c9c
--- /dev/null
+++ b/compiler/exo/src/Passes.cpp
@@ -0,0 +1,19 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Passes.h"
+
+// This file exists to make sure that Passes.h compiles
diff --git a/compiler/exo/src/Passes.h b/compiler/exo/src/Passes.h
new file mode 100644
index 000000000..2a702d01d
--- /dev/null
+++ b/compiler/exo/src/Passes.h
@@ -0,0 +1,38 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __PASSES_H__
+#define __PASSES_H__
+
+// Please add in alphabetical order
+// Please append 'Pass' suffix to Pass class and file names
+
+#include "Pass/FoldReshapeOfConstPass.h"
+#include "Pass/FoldTransposeOfConstPass.h"
+#include "Pass/FuseBiasAddPass.h"
+#include "Pass/FuseInstanceNormPass.h"
+#include "Pass/FuseReluPass.h"
+#include "Pass/FuseRsqrtPass.h"
+#include "Pass/FuseSquaredDifferencePass.h"
+#include "Pass/MergeConcatNodesPass.h"
+#include "Pass/ShapeInferencePass.h"
+#include "Pass/TypeInferencePass.h"
+
+#include <logo/RemoveDeadNodePass.h>
+#include <logo/RemoveForwardNodePass.h>
+#include <logo/SimplifyDomainConversionPass.h>
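+
+// A typical use (a sketch; assumes logo::Phase is a vector of pass pointers,
+// driven by a logo::PhaseRunner as done in ExoOptimize):
+//
+//   logo::Phase phase;
+//   phase.emplace_back(stdex::make_unique<exo::FuseInstanceNormPass>());
+//   phase.emplace_back(stdex::make_unique<logo::RemoveDeadNodePass>());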
+
+#endif // __PASSES_H__
diff --git a/compiler/exo/src/ProgressReporter.cpp b/compiler/exo/src/ProgressReporter.cpp
new file mode 100644
index 000000000..ff919dae8
--- /dev/null
+++ b/compiler/exo/src/ProgressReporter.cpp
@@ -0,0 +1,84 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ProgressReporter.h"
+
+#include "Log.h"
+#include "LogHelper.h"
+
+#include <logo/Phase.h>
+#include <logo/Pass.h>
+
+#include <cassert>
+
+namespace
+{
+
+char to_char(bool b) { return b ? 'Y' : 'N'; }
+
+const char *to_str(logo::PhaseStrategy s)
+{
+ switch (s)
+ {
+ case logo::PhaseStrategy::Saturate:
+ return "Saturate";
+ case logo::PhaseStrategy::Restart:
+ return "Restart";
+ }
+ assert(false);
+ return "";
+}
+
+} // namespace
+
+namespace exo
+{
+
+void ProgressReporter::notify(const logo::PhaseEventInfo<logo::PhaseEvent::PhaseBegin> *)
+{
+ LOGGER(prime);
+
+ INFO(prime) << "==============================================================";
+ INFO(prime) << "exo::PhaseRunner<" << to_str(strategy()) << ">";
+ INFO(prime) << "Initial graph";
+ INFO(prime) << fmt(graph());
+}
+
+void ProgressReporter::notify(const logo::PhaseEventInfo<logo::PhaseEvent::PhaseEnd> *)
+{
+ LOGGER(prime);
+
+ INFO(prime) << "exo::PhaseRunner<" << to_str(strategy()) << "> - done";
+}
+
+void ProgressReporter::notify(const logo::PhaseEventInfo<logo::PhaseEvent::PassBegin> *info)
+{
+ LOGGER(prime);
+
+ INFO(prime) << "--------------------------------------------------------------";
+ INFO(prime) << "Before " << logo::pass_name(info->pass());
+}
+
+void ProgressReporter::notify(const logo::PhaseEventInfo<logo::PhaseEvent::PassEnd> *info)
+{
+ LOGGER(prime);
+
+ INFO(prime) << "After " << logo::pass_name(info->pass())
+ << " (changed: " << to_char(info->changed()) << ")";
+ INFO(prime) << fmt(graph());
+}
+
+} // namespace exo
diff --git a/compiler/exo/src/ProgressReporter.h b/compiler/exo/src/ProgressReporter.h
new file mode 100644
index 000000000..b0f420df9
--- /dev/null
+++ b/compiler/exo/src/ProgressReporter.h
@@ -0,0 +1,53 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __PROGRESSREPORTER_H__
+#define __PROGRESSREPORTER_H__
+
+#include <logo/Phase.h>
+
+#include <loco.h>
+
+namespace exo
+{
+
+class ProgressReporter : public logo::PhaseEventListener
+{
+public:
+ ProgressReporter(loco::Graph *graph, logo::PhaseStrategy strategy)
+ : _graph{graph}, _strategy{strategy}
+ {
+ // DO NOTHING
+ }
+
+public:
+ void notify(const logo::PhaseEventInfo<logo::PhaseEvent::PhaseBegin> *) override;
+ void notify(const logo::PhaseEventInfo<logo::PhaseEvent::PhaseEnd> *) override;
+ void notify(const logo::PhaseEventInfo<logo::PhaseEvent::PassBegin> *) override;
+ void notify(const logo::PhaseEventInfo<logo::PhaseEvent::PassEnd> *) override;
+
+public:
+ loco::Graph *graph(void) const { return _graph; }
+ logo::PhaseStrategy strategy(void) const { return _strategy; }
+
+private:
+ loco::Graph *_graph;
+ logo::PhaseStrategy _strategy;
+};
+
+} // namespace exo
+
+#endif // __PROGRESSREPORTER_H__
diff --git a/compiler/exo/src/ShapeInference.cpp b/compiler/exo/src/ShapeInference.cpp
new file mode 100644
index 000000000..bceb1495f
--- /dev/null
+++ b/compiler/exo/src/ShapeInference.cpp
@@ -0,0 +1,44 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ShapeInference.h"
+#include "Dialect/IR/TFLDialect.h"
+#include "Dialect/Service/TFLShapeInferenceRule.h"
+
+#include <loco/IR/CanonicalNode.h>
+#include <loco/IR/CanonicalDialect.h>
+#include <loco/IR/CanonicalNodeVisitor.h>
+#include <loco/Service/ShapeInference.h>
+#include <loco/Service/CanonicalShapeInferenceRule.h>
+#include <loco/Service/MultiDialectShapeInferenceRule.h>
+
+#include <locoex/COpCall.h>
+#include <locoex/COpDialect.h>
+#include <locoex/Service/COpShapeInferenceRule.h>
+
+namespace exo
+{
+
+ShapeDescription ShapeInference::get(loco::Node *node)
+{
+ // TODO Adjust indentation level
+ {
+ assert(loco::shape_known(node));
+ return to_shape_description(loco::shape_get(node));
+ }
+}
+
+} // namespace exo
diff --git a/compiler/exo/src/ShapeInference.h b/compiler/exo/src/ShapeInference.h
new file mode 100644
index 000000000..ec141ccfc
--- /dev/null
+++ b/compiler/exo/src/ShapeInference.h
@@ -0,0 +1,41 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __SHAPE_INFERENCE_H__
+#define __SHAPE_INFERENCE_H__
+
+#include "ExporterUtils.h"
+
+#include <loco/IR/Nodes.h>
+
+namespace exo
+{
+
+/**
+ * @brief Get the ShapeDescription of a node via loco shape inference
+ *
+ * HOW TO USE
+ *
+ * ShapeInference::get(g->nodes()->at(..));
+ */
+struct ShapeInference
+{
+ static ShapeDescription get(loco::Node *node);
+};
+
+} // namespace exo
+
+#endif // __SHAPE_INFERENCE_H__
diff --git a/compiler/exo/src/TFLite/TFLExporter.cpp b/compiler/exo/src/TFLite/TFLExporter.cpp
new file mode 100644
index 000000000..cf002b3e1
--- /dev/null
+++ b/compiler/exo/src/TFLite/TFLExporter.cpp
@@ -0,0 +1,49 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "exo/TFLExporter.h"
+
+#include "TFLExporterImpl.h"
+
+#include <stdex/Memory.h>
+
+#include <oops/InternalExn.h>
+
+#include <fstream>
+
+namespace exo
+{
+
+TFLExporter::TFLExporter(loco::Graph *graph) : _impl(stdex::make_unique<Impl>(graph))
+{
+ // NOTHING TO DO
+}
+
+TFLExporter::~TFLExporter() = default;
+
+void TFLExporter::dumpToFile(const char *path) const
+{
+ const char *ptr = _impl->getBufferPointer();
+ const size_t size = _impl->getBufferSize();
+
+ if (!ptr)
+ INTERNAL_EXN("Graph was not serialized by FlatBuffer for some reason");
+
+ std::ofstream file(path, std::ofstream::binary);
+ file.write(ptr, size);
+}
+
+} // namespace exo
diff --git a/compiler/exo/src/TFLite/TFLExporterImpl.cpp b/compiler/exo/src/TFLite/TFLExporterImpl.cpp
new file mode 100644
index 000000000..07adbfb9d
--- /dev/null
+++ b/compiler/exo/src/TFLite/TFLExporterImpl.cpp
@@ -0,0 +1,179 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "TFLExporterImpl.h"
+
+#include "Convert.h"
+#include "ExoOptimize.h"
+
+#include "TFLTensorExporter.h"
+#include "TFLOperationExporter.h"
+#include "TFLExporterUtils.h"
+
+#include "Log.h"
+#include "Knob.h"
+
+#include <oops/InternalExn.h>
+
+#include <cassert>
+#include <unordered_map>
+#include <string>
+#include <stdexcept>
+
+namespace
+{
+
+using namespace exo;
+using namespace exo::tflite_detail;
+
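+/// @brief Record the tensor index of every graph input (Pull node) in the subgraph context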
+void registerGraphInputTensors(loco::Graph *graph, SubGraphContext &ctx)
+{
+ for (uint32_t n = 0; n < graph->inputs()->size(); ++n)
+ {
+ auto node = loco::pull_node(graph, n);
+ assert(node != nullptr);
+ ctx._inputs.push_back(get_tensor_index(node));
+ }
+}
+
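+/// @brief Record the tensor index of every graph output (the node each Push reads 'from')
+///        in the subgraph context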
+void registerGraphOutputTensors(loco::Graph *graph, SubGraphContext &ctx)
+{
+ for (uint32_t n = 0; n < graph->outputs()->size(); ++n)
+ {
+ auto push = loco::push_node(graph, n);
+ assert(push != nullptr);
+ auto node = push->from();
+ assert(node != nullptr);
+ ctx._outputs.push_back(get_tensor_index(node));
+ }
+}
+
+} // namespace
+
+namespace
+{
+using namespace tflite;
+using namespace flatbuffers;
+
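+/**
+ * @brief Serialize the collected opcodes into the flatbuffer 'operator_codes' vector,
+ *        keeping each OperatorCode at the index previously recorded in 'opcodes'
+ */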
+Offset<Vector<Offset<OperatorCode>>>
+encodeOperatorCodes(FlatBufferBuilder &builder, std::unordered_map<OpCode, uint32_t> &opcodes,
+ std::unordered_map<OpCode, std::string> &custom_opcodes)
+{
+ std::vector<Offset<OperatorCode>> operator_codes_vec(opcodes.size());
+ for (auto it : opcodes)
+ {
+ uint32_t idx = it.second;
+ if (it.first.opcode != BuiltinOperator_CUSTOM)
+ {
+ operator_codes_vec[idx] = CreateOperatorCode(builder, it.first.opcode);
+ }
+ else // custom op
+ {
+ auto opCode = it.first;
+ auto custom_code = custom_opcodes.find(opCode);
+ if (custom_code == custom_opcodes.end())
+ INTERNAL_EXN("Cannot find code for custom op");
+
+ operator_codes_vec[idx] =
+ CreateOperatorCode(builder, it.first.opcode, builder.CreateString(custom_code->second));
+ }
+ }
+ return builder.CreateVector(operator_codes_vec);
+}
+
+} // namespace
+
+namespace exo
+{
+
+using namespace exo::tflite_detail;
+using namespace tflite;
+using namespace flatbuffers;
+
+TFLExporter::Impl::Impl(loco::Graph *graph) { exportGraph(graph); }
+
+::flatbuffers::Offset<::tflite::SubGraph> TFLExporter::Impl::exportSubgraph(SerializedModelData &gd)
+{
+ auto tensors = _builder.CreateVector(gd._tensors);
+ auto inputs = _builder.CreateVector(gd._inputs);
+ auto outputs = _builder.CreateVector(gd._outputs);
+ auto operators = _builder.CreateVector(gd._operators);
+ auto subgraph = CreateSubGraph(_builder, tensors, inputs, outputs, operators);
+ return subgraph;
+}
+
+void TFLExporter::Impl::exportGraph(loco::Graph *graph)
+{
+ LOGGER(l);
+
+ // IR-level conversion and optimization
+ {
+ convert_to_TFLNodes(graph);
+ set(Dialect::TFLITE);
+ optimize(graph);
+ }
+
+ _builder.Clear();
+
+ SerializedModelData gd;
+
+  // This version is taken from the comment in the fbs schema
+ constexpr uint32_t version = 3;
+
+ registerGraphIOName(graph, gd);
+
+ // parse graph into SerializedModelData structure
+ exportOpDefinedTensors(graph, _builder, gd);
+
+ // NOTE Invoke these register functions only after each node is annotated with its tensor_index
+ registerGraphInputTensors(graph, gd);
+ registerGraphOutputTensors(graph, gd);
+
+ exportNodes(graph, _builder, gd);
+
+ // encode operator codes
+ auto operator_codes =
+ encodeOperatorCodes(_builder, gd._operator_codes, gd._custom_operator_codes);
+
+ // Subgraphs
+ Offset<SubGraph> subgraph = exportSubgraph(gd);
+ auto subgraphs = _builder.CreateVector(std::vector<Offset<SubGraph>>{subgraph});
+
+ // Description
+ std::string description_str = "nnpackage";
+ auto description = _builder.CreateString(description_str);
+
+ // create array of buffers
+ auto buffers = _builder.CreateVector(gd._buffers);
+
+ // empty metadata
+ std::vector<int> metadata_buffer_vec;
+ auto metadata_buffer = _builder.CreateVector(metadata_buffer_vec);
+
+ // Model
+ auto model_offset = CreateModel(_builder, version, operator_codes, subgraphs, description,
+ buffers, metadata_buffer);
+ FinishModelBuffer(_builder, model_offset);
+}
+
+const char *TFLExporter::Impl::getBufferPointer() const
+{
+ return reinterpret_cast<const char *>(_builder.GetBufferPointer());
+}
+
+size_t TFLExporter::Impl::getBufferSize() const { return _builder.GetSize(); }
+
+} // namespace exo
diff --git a/compiler/exo/src/TFLite/TFLExporterImpl.h b/compiler/exo/src/TFLite/TFLExporterImpl.h
new file mode 100644
index 000000000..01c549a43
--- /dev/null
+++ b/compiler/exo/src/TFLite/TFLExporterImpl.h
@@ -0,0 +1,78 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __TFL_EXPORTER_IMPL_H__
+#define __TFL_EXPORTER_IMPL_H__
+
+#include "exo/TFLExporter.h"
+#include "schema_generated.h"
+
+#include <loco.h>
+
+namespace exo
+{
+
+namespace tflite_detail
+{
+
+struct SerializedModelData;
+
+} // namespace tflite_detail
+
+using namespace tflite_detail;
+
+/**
+ * @brief Internal implementation of the TFLExporter interface class
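+ *
+ * Usage sketch ('sink' is a stand-in for any byte consumer, e.g. std::ofstream::write):
+ *
+ *   exo::TFLExporter::Impl impl{graph};
+ *   sink(impl.getBufferPointer(), impl.getBufferSize());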
+ */
+class TFLExporter::Impl
+{
+public:
+ Impl() = delete;
+ ~Impl() = default;
+
+ explicit Impl(loco::Graph *graph);
+
+ /**
+ * @return pointer to buffer with serialized graph
+ */
+ const char *getBufferPointer() const;
+
+ /**
+ * @return size of buffer with serialized graph
+ */
+ size_t getBufferSize() const;
+
+private:
+ /**
+ * @brief create Subgraph using data stored in SerializedModelData
+   * @param gd information about serialized parts of the model
+ * @return offset in buffer corresponding to serialized subgraph
+ */
+ flatbuffers::Offset<tflite::SubGraph> exportSubgraph(SerializedModelData &gd);
+
+ /**
+ * @brief root function that writes graph into internal buffer
+ * @param graph
+ */
+ void exportGraph(loco::Graph *graph);
+
+private:
+ flatbuffers::FlatBufferBuilder _builder;
+};
+
+} // namespace exo
+
+#endif // __TFL_EXPORTER_IMPL_H__
diff --git a/compiler/exo/src/TFLite/TFLExporterImpl.test.cpp b/compiler/exo/src/TFLite/TFLExporterImpl.test.cpp
new file mode 100644
index 000000000..7d74223c5
--- /dev/null
+++ b/compiler/exo/src/TFLite/TFLExporterImpl.test.cpp
@@ -0,0 +1,413 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "TFLExporterImpl.h"
+
+#include "schema_generated.h"
+
+#include "TestGraph.h"
+#include "GraphBlock.h"
+#include "Knob.h"
+
+#include <loco/IR/PermutingCodec.h>
+#include <stdex/Memory.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+class TFLExporterImplTests : public ::testing::Test
+{
+public:
+ TFLExporterImplTests() { _graph = loco::make_graph(); }
+
+public:
+ virtual ~TFLExporterImplTests() = default;
+
+protected:
+ loco::Graph *graph(void) { return _graph.get(); }
+
+ template <typename NodeT> NodeT *make_node(void);
+
+private:
+ std::unique_ptr<loco::Graph> _graph;
+};
+
+template <typename NodeT> NodeT *TFLExporterImplTests::make_node(void)
+{
+ return graph()->nodes()->create<NodeT>();
+}
+
+template <> loco::FeatureEncode *TFLExporterImplTests::make_node(void)
+{
+ loco::FeatureEncode *encode_layer = graph()->nodes()->create<loco::FeatureEncode>();
+
+ auto encoder = stdex::make_unique<loco::PermutingEncoder<loco::Domain::Feature>>();
+ (*encoder->perm())[loco::FeatureAxis::Count] = 0;
+ (*encoder->perm())[loco::FeatureAxis::Depth] = 1;
+ (*encoder->perm())[loco::FeatureAxis::Height] = 2;
+ (*encoder->perm())[loco::FeatureAxis::Width] = 3;
+ encode_layer->encoder(std::move(encoder));
+
+ return encode_layer;
+}
+
+template <> loco::FeatureDecode *TFLExporterImplTests::make_node(void)
+{
+ loco::FeatureDecode *decode_layer = graph()->nodes()->create<loco::FeatureDecode>();
+
+ auto decoder = stdex::make_unique<loco::PermutingDecoder<loco::Domain::Feature>>();
+ (*decoder->perm())[loco::FeatureAxis::Count] = 0;
+ (*decoder->perm())[loco::FeatureAxis::Depth] = 1;
+ (*decoder->perm())[loco::FeatureAxis::Height] = 2;
+ (*decoder->perm())[loco::FeatureAxis::Width] = 3;
+ decode_layer->decoder(std::move(decoder));
+
+ return decode_layer;
+}
+
+} // namespace
+
+// TODO TFLAdd
+
+// TODO TFLAveragePool2D
+
+TEST_F(TFLExporterImplTests, Concatenate)
+{
+ auto pull1 = make_node<loco::Pull>();
+ {
+ pull1->dtype(loco::DataType::FLOAT32);
+ pull1->shape({1, 2, 3, 4});
+ }
+ auto pull2 = make_node<loco::Pull>();
+ {
+ pull2->dtype(loco::DataType::FLOAT32);
+ pull2->shape({1, 2, 3, 4});
+ }
+ auto concat = make_node<loco::TensorConcat>();
+ {
+ concat->lhs(pull1);
+ concat->rhs(pull2);
+ }
+ auto push = make_node<loco::Push>();
+ {
+ push->from(concat);
+ }
+
+ auto input1 = graph()->inputs()->create();
+ {
+ input1->name("input1");
+ loco::link(input1, pull1);
+ }
+ auto input2 = graph()->inputs()->create();
+ {
+ input2->name("input2");
+ loco::link(input2, pull2);
+ }
+ auto output = graph()->outputs()->create();
+ {
+ output->name("output");
+ loco::link(output, push);
+ }
+
+ exo::TFLExporter::Impl exporter{graph()};
+
+ // TODO Add more checks
+ SUCCEED();
+}
+
+// TODO TFLConv2D
+
+// TODO TFLDepthwiseConv2D
+
+// TODO TFLDiv
+
+// TODO TFLMaxPool2D
+
+// TODO TFLMul
+
+TEST_F(TFLExporterImplTests, Relu6)
+{
+ auto pull = make_node<loco::Pull>();
+ {
+ pull->dtype(loco::DataType::FLOAT32);
+ pull->shape({1, 8, 8, 3});
+ }
+ auto relu6 = make_node<loco::ReLU6>();
+ {
+ relu6->input(pull);
+ }
+ auto push = make_node<loco::Push>();
+ {
+ push->from(relu6);
+ }
+
+ auto input = graph()->inputs()->create();
+ {
+ input->name("input");
+ loco::link(input, pull);
+ }
+ auto output = graph()->outputs()->create();
+ {
+ output->name("output");
+ loco::link(output, push);
+ }
+
+ exo::TFLExporter::Impl exporter{graph()};
+
+ // TODO Add more checks
+ SUCCEED();
+}
+
+// TODO TFLRelu6
+
+// TODO TFLReshape
+
+// TODO TFLSoftmax
+
+// TODO TFLSqrt
+
+// TODO TFLSub
+
+// TODO TFLTanh
+
+TEST(TFLExporterImplTest, Transpose_simple)
+{
+ exo::test::ExampleGraph<exo::test::ExampleGraphType::Transpose> g;
+
+ // pull attribute
+ {
+ g.pull->dtype(loco::DataType::FLOAT32);
+ g.pull->shape({1, 2, 2, 3});
+ }
+
+ // transpose attribute
+ {
+ g.transpose->perm()->size(4);
+ g.transpose->perm()->axis(0) = 1;
+ g.transpose->perm()->axis(1) = 2;
+ g.transpose->perm()->axis(2) = 3;
+ g.transpose->perm()->axis(3) = 0;
+ }
+
+ exo::TFLExporter::Impl exporter{g.graph()};
+ {
+ auto model = tflite::GetModel(exporter.getBufferPointer());
+ auto operators = model->subgraphs()->Get(0)->operators();
+
+ assert(operators->Length() == 1);
+
+ int n = 0; // op index of Transpose in tflite file
+
+ auto opcode_index = operators->Get(n)->opcode_index();
+
+ ASSERT_EQ(model->operator_codes()->Get(opcode_index)->builtin_code(),
+ tflite::BuiltinOperator_TRANSPOSE);
+
+ auto perm = operators->Get(n)->inputs()->Get(1);
+
+ auto perm_tensor = model->subgraphs()->Get(0)->tensors()->Get(perm);
+ ASSERT_EQ(perm_tensor->type(), tflite::TensorType::TensorType_INT32);
+ ASSERT_EQ(perm_tensor->shape()->size(), 1);
+ ASSERT_EQ(perm_tensor->shape()->Get(0), 4);
+
+ auto bufs = (model->buffers());
+ auto *perm_buf =
+ reinterpret_cast<const int32_t *>(bufs->Get(perm_tensor->buffer())->data()->data());
+
+ ASSERT_EQ(perm_buf[0], 1);
+ ASSERT_EQ(perm_buf[1], 2);
+ ASSERT_EQ(perm_buf[2], 3);
+ ASSERT_EQ(perm_buf[3], 0);
+ }
+}
+
+/*
+  test case:
+          Pull ----- FilterEncode ---- FilterDecode --- Push
+            0 -----------> H ---------+   O ----------> 0
+            1              W          +-> H ----------> 1
+            2              I(depth)       W             2
+            3              O(count)       I             3
+
+    axis 0 ----------> H --------------> H -----------> 1
+    axis 1 ----------> W --------------> W -----------> 2
+    axis 2 ----------> I --------------> I -----------> 3
+    axis 3 ----------> O --------------> O -----------> 0
+
+  So, the perm vector of Transpose = [3, 0, 1, 2].
+  Please refer to loco::TensorTranspose for the definition of the perm vector.
+*/
+TEST(TFLExporterImplTest, Transpose_from_FilterEncode_FilterDecode)
+{
+ exo::test::ExampleGraph<exo::test::ExampleGraphType::FilterEncode_FilterDecode> g;
+
+ // pull attribute
+ {
+ g.pull->dtype(loco::DataType::FLOAT32);
+ g.pull->shape({1, 2, 3, 4}); // whatever value of rank 4
+ }
+
+ exo::TFLExporter::Impl exporter{g.graph()};
+ {
+ auto model = tflite::GetModel(exporter.getBufferPointer());
+ auto operators = model->subgraphs()->Get(0)->operators();
+
+ assert(operators->Length() == 1);
+
+ int n = 0; // op index of Transpose in tflite file
+
+ auto opcode_index = operators->Get(n)->opcode_index();
+
+ ASSERT_EQ(model->operator_codes()->Get(opcode_index)->builtin_code(),
+ tflite::BuiltinOperator_TRANSPOSE);
+
+ auto perm = operators->Get(n)->inputs()->Get(1);
+
+ auto perm_tensor = model->subgraphs()->Get(0)->tensors()->Get(perm);
+ ASSERT_EQ(perm_tensor->type(), tflite::TensorType::TensorType_INT32);
+ ASSERT_EQ(perm_tensor->shape()->size(), 1);
+ ASSERT_EQ(perm_tensor->shape()->Get(0), 4);
+
+ auto bufs = (model->buffers());
+ auto *perm_buf =
+ reinterpret_cast<const int32_t *>(bufs->Get(perm_tensor->buffer())->data()->data());
+ ASSERT_EQ(perm_buf[0], 3);
+ ASSERT_EQ(perm_buf[1], 0);
+ ASSERT_EQ(perm_buf[2], 1);
+ ASSERT_EQ(perm_buf[3], 2);
+ }
+}
+
+/**
+ * What happens when there is a mismatch between generation and execution order!?
+ */
+TEST_F(TFLExporterImplTests, Regression_0000)
+{
+ // This test was written without considering fusion.
+ // For this reason, this check is needed.
+ // TODO Rewrite this test
+ if (exo::get<exo::Knob::UseFuseReluPass>())
+ return;
+
+ // Execution Order: MaxPool2D -> ReLU
+ // Generation Order: ReLU -> MaxPool2D
+ auto pull = make_node<loco::Pull>();
+ {
+ pull->dtype(loco::DataType::FLOAT32);
+ pull->shape({1, 8, 8, 3});
+ }
+ auto relu = make_node<loco::ReLU>();
+ auto encode = exo::make_feature_encode<exo::FeatureLayout::NHWC>(pull);
+ auto maxpool = make_node<loco::MaxPool2D>();
+ auto decode = exo::make_feature_decode<exo::FeatureLayout::NHWC>(relu);
+ auto push = make_node<loco::Push>();
+
+ ASSERT_EQ(maxpool->window()->vertical(), 1);
+ ASSERT_EQ(maxpool->window()->horizontal(), 1);
+
+ maxpool->ifm(encode);
+ relu->input(maxpool);
+ push->from(decode);
+
+ auto input = graph()->inputs()->create();
+ {
+ input->name("input");
+ loco::link(input, pull);
+ }
+ auto output = graph()->outputs()->create();
+ {
+ output->name("output");
+ loco::link(output, push);
+ }
+
+ exo::TFLExporter::Impl exporter{graph()};
+ {
+ int64_t maxpool_execution_index = -1;
+ int64_t relu_execution_index = -1;
+
+ auto model = tflite::GetModel(exporter.getBufferPointer());
+ auto operators = model->subgraphs()->Get(0)->operators();
+
+ for (uint32_t n = 0; n < operators->Length(); ++n)
+ {
+ auto opcode_index = operators->Get(n)->opcode_index();
+
+ switch (model->operator_codes()->Get(opcode_index)->builtin_code())
+ {
+ case tflite::BuiltinOperator_RELU:
+ ASSERT_EQ(relu_execution_index, -1);
+ relu_execution_index = static_cast<int64_t>(n);
+ break;
+ case tflite::BuiltinOperator_MAX_POOL_2D:
+ ASSERT_EQ(maxpool_execution_index, -1);
+ maxpool_execution_index = static_cast<int64_t>(n);
+ break;
+ default:
+ break;
+ }
+ }
+
+ ASSERT_NE(maxpool_execution_index, -1);
+ ASSERT_NE(relu_execution_index, -1);
+ // maxpool SHOULD precede ReLU
+ ASSERT_LT(maxpool_execution_index, relu_execution_index);
+ }
+}
+
+/**
+ * @brief Test exporter buffer generation
+ */
+TEST_F(TFLExporterImplTests, Regression_0001)
+{
+ auto cgen = make_node<loco::ConstGen>();
+ cgen->rank(1);
+ cgen->dim(0) = 2;
+ cgen->dtype(loco::DataType::FLOAT32);
+ cgen->size<loco::DataType::FLOAT32>(2);
+ cgen->at<loco::DataType::FLOAT32>(0) = 3.3f;
+ cgen->at<loco::DataType::FLOAT32>(1) = 1.1f;
+
+ auto push = make_node<loco::Push>();
+ push->from(cgen);
+
+ auto output = graph()->outputs()->create();
+ {
+ output->name("output");
+ loco::link(output, push);
+ }
+
+ exo::TFLExporter::Impl exporter{graph()};
+ {
+ auto model = tflite::GetModel(exporter.getBufferPointer());
+ auto buffers = model->buffers();
+
+ // 0'th empty buffer + ConstGen data + ConstGen node output
+ ASSERT_EQ(buffers->Length(), 3);
+
+ // 0'th should be empty buffer
+ auto buffer_0 = (*buffers)[0];
+ auto array_0 = buffer_0->data();
+ ASSERT_EQ(array_0, nullptr);
+
+ // 1'st should be ConstGen data, which is two floats
+ auto buffer_1 = (*buffers)[1];
+ auto array_1 = buffer_1->data();
+ size_t size_1 = array_1->size();
+ ASSERT_EQ(size_1, 2 * sizeof(float));
+ }
+}
diff --git a/compiler/exo/src/TFLite/TFLExporterUtils.cpp b/compiler/exo/src/TFLite/TFLExporterUtils.cpp
new file mode 100644
index 000000000..d35afc9aa
--- /dev/null
+++ b/compiler/exo/src/TFLite/TFLExporterUtils.cpp
@@ -0,0 +1,160 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "TFLExporterUtils.h"
+
+#include <oops/InternalExn.h>
+
+namespace exo
+{
+
+tflite::ActivationFunctionType to_tflite_actfunc(locoex::FusedActFunc func)
+{
+ switch (func)
+ {
+ case locoex::FusedActFunc::NONE:
+ return tflite::ActivationFunctionType_NONE;
+ case locoex::FusedActFunc::RELU:
+ return tflite::ActivationFunctionType_RELU;
+ case locoex::FusedActFunc::RELU6:
+ return tflite::ActivationFunctionType_RELU6;
+ default:
+ INTERNAL_EXN_V("Unsupported locoex FusedActFunc Type", oops::to_uint32(func));
+ }
+}
+
+} // namespace exo
+
+namespace exo
+{
+namespace tflite_detail
+{
+
+uint32_t SerializedModelData::registerBuiltinOpcode(tflite::BuiltinOperator builtin_code)
+{
+ auto it = _operator_codes.find(OpCode{builtin_code});
+ if (it != _operator_codes.end())
+ {
+ return it->second;
+ }
+ auto idx = static_cast<uint32_t>(_operator_codes.size());
+ _operator_codes.emplace(OpCode{builtin_code}, idx);
+ return idx;
+}
+
+uint32_t SerializedModelData::registerCustomOpcode(const std::string &custom_op)
+{
+ tflite::BuiltinOperator custom_code = tflite::BuiltinOperator_CUSTOM;
+ auto idx = registerBuiltinOpcode(custom_code);
+ _custom_operator_codes.emplace(OpCode{custom_code}, custom_op);
+ return idx;
+}
+
+tflite::Padding getOpPadding(const loco::Padding2D *pad, const loco::Stride<2> *stride,
+ const ShapeDescription &ifm, const ShapeDescription &ofm)
+{
+ // VALID padding
+ if (pad->top() == 0 && pad->bottom() == 0 && pad->left() == 0 && pad->right() == 0)
+ return tflite::Padding_VALID;
+
+ // SAME padding
+ //
+ // For same padding, by definition, following equation should hold:
+ // O = floor((I - 1) / S) + 1
+ // where input size I, output size O, stride S
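+ //   e.g. I = 5, S = 2 gives O = floor(4 / 2) + 1 = 3
+ //   (cf. the SAME padding case in TFLExporterUtils.test.cpp)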
+ //
+ // NOTE input and output 'feature' map are shape of NHWC
+ bool same_padding_criterion_1 =
+ (static_cast<uint32_t>(ofm._dims[1]) == (ifm._dims[1] - 1) / stride->vertical() + 1) &&
+ (static_cast<uint32_t>(ofm._dims[2]) == (ifm._dims[2] - 1) / stride->horizontal() + 1);
+
+ // For same padding, rear padding is same or bigger than front padding by at most 1
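+ //   e.g. (top, bottom) = (1, 1) or (0, 1) qualifies, while (2, 0) does not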
+ bool same_padding_criterion_2 =
+ (pad->top() <= pad->bottom()) && (pad->bottom() <= pad->top() + 1) &&
+ (pad->left() <= pad->right()) && (pad->right() <= pad->left() + 1);
+
+ if (same_padding_criterion_1 && same_padding_criterion_2)
+ return tflite::Padding_SAME;
+
+ INTERNAL_EXN("NYI for custom PAD");
+}
+
+tflite::Padding getOpPadding(const locoex::Padding pad)
+{
+ if (pad == locoex::Padding::VALID)
+ return tflite::Padding_VALID;
+ if (pad == locoex::Padding::SAME)
+ return tflite::Padding_SAME;
+
+ INTERNAL_EXN_V("Unknown padding", oops::to_uint32(pad));
+}
+
+void registerGraphIOName(loco::Graph *graph, SerializedModelData &gd)
+{
+ for (uint32_t in = 0; in < graph->inputs()->size(); ++in)
+ {
+ auto pull = loco::pull_node(graph, in);
+ auto name = graph->inputs()->at(in)->name();
+
+ gd._pull_to_name[pull] = name;
+ }
+ for (uint32_t out = 0; out < graph->outputs()->size(); ++out)
+ {
+ auto push = loco::push_node(graph, out);
+ auto name = graph->outputs()->at(out)->name();
+
+ gd._push_to_name[push] = name;
+ }
+}
+
+#include <stdex/Memory.h>
+
+#include <cassert>
+
+namespace
+{
+
+class TFLTensorIndexAnnotation final : public loco::NodeAnnotation
+{
+public:
+ TFLTensorIndexAnnotation(const TFLTensorIndex &index) : _index{index}
+ {
+ // DO NOTHING
+ }
+
+public:
+ const TFLTensorIndex &index(void) const { return _index; }
+
+private:
+ TFLTensorIndex _index;
+};
+
+} // namespace
+
+void set_tensor_index(loco::Node *node, const TFLTensorIndex &tensor_id)
+{
+ assert(node->annot<TFLTensorIndexAnnotation>() == nullptr);
+ node->annot(stdex::make_unique<TFLTensorIndexAnnotation>(tensor_id));
+}
+
+TFLTensorIndex get_tensor_index(loco::Node *node)
+{
+ assert(node->annot<TFLTensorIndexAnnotation>() != nullptr);
+ return node->annot<TFLTensorIndexAnnotation>()->index();
+}
+
+} // namespace tflite_detail
+} // namespace exo
diff --git a/compiler/exo/src/TFLite/TFLExporterUtils.h b/compiler/exo/src/TFLite/TFLExporterUtils.h
new file mode 100644
index 000000000..dbd7a52fb
--- /dev/null
+++ b/compiler/exo/src/TFLite/TFLExporterUtils.h
@@ -0,0 +1,118 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __TFL_EXPORTER_UTILS_H__
+#define __TFL_EXPORTER_UTILS_H__
+
+#include "ExporterUtils.h"
+
+#include "schema_generated.h"
+
+#include "Dialect/IR/TFLNodes.h"
+
+#include <loco.h>
+
+#include <unordered_map>
+
+namespace exo
+{
+namespace tflite_detail
+{
+
+struct OpCode
+{
+ tflite::BuiltinOperator opcode;
+
+ bool operator==(const OpCode &rhs) const { return opcode == rhs.opcode; }
+};
+
+} // namespace tflite_detail
+} // namespace exo
+
+namespace exo
+{
+
+tflite::ActivationFunctionType to_tflite_actfunc(locoex::FusedActFunc func);
+
+} // namespace exo
+
+namespace std
+{
+
+template <> struct hash<exo::tflite_detail::OpCode>
+{
+ size_t operator()(const exo::tflite_detail::OpCode &x) const { return hash<int>()(x.opcode); }
+};
+
+} // namespace std
+
+namespace exo
+{
+namespace tflite_detail
+{
+
+/**
+ * @brief Record the information of a T/F Lite SubGraph and its mapping to loco
+ */
+struct SubGraphContext
+{
+ /// @brief SubGraph input tensor id
+ std::vector<int32_t> _inputs;
+ /// @brief SubGraph output tensor id
+ std::vector<int32_t> _outputs;
+};
+
+// Prerequisites for tflite::Model object creation
+struct SerializedModelData final : public SubGraphContext
+{
+ SerializedModelData() = default;
+ SerializedModelData(const SerializedModelData &) = delete;
+
+ std::unordered_map<OpCode, uint32_t> _operator_codes;
+ std::unordered_map<OpCode, std::string> _custom_operator_codes;
+ std::vector<flatbuffers::Offset<tflite::Operator>> _operators;
+ std::vector<flatbuffers::Offset<tflite::Tensor>> _tensors;
+ std::vector<flatbuffers::Offset<tflite::Buffer>> _buffers;
+
+ // Graph input and output names
+ std::unordered_map<loco::Pull *, std::string> _pull_to_name;
+ std::unordered_map<loco::Push *, std::string> _push_to_name;
+
+ /**
+ * @brief register the opcode if it is not yet in the table of opcodes
+ * @param builtin_code opcode to register
+ * @return index of the opcode in the table of opcodes (see schema)
+ */
+ uint32_t registerBuiltinOpcode(tflite::BuiltinOperator builtin_code);
+ uint32_t registerCustomOpcode(const std::string &custom_op);
+};
+
+tflite::Padding getOpPadding(const loco::Padding2D *pad, const loco::Stride<2> *stride,
+ const ShapeDescription &ifm, const ShapeDescription &ofm);
+tflite::Padding getOpPadding(const locoex::Padding pad);
+
+/// @brief Register graph input and output names to SerializedModelData
+void registerGraphIOName(loco::Graph *graph, SerializedModelData &gd);
+
+using TFLTensorIndex = int32_t;
+
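+// NOTE set_tensor_index must be called exactly once per node (the annotation
+// must not already exist); get_tensor_index then retrieves the recorded index.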
+void set_tensor_index(loco::Node *node, const TFLTensorIndex &tensor_id);
+TFLTensorIndex get_tensor_index(loco::Node *node);
+
+} // namespace tflite_detail
+} // namespace exo
+
+#endif // __TFL_EXPORTER_UTILS_H__
diff --git a/compiler/exo/src/TFLite/TFLExporterUtils.test.cpp b/compiler/exo/src/TFLite/TFLExporterUtils.test.cpp
new file mode 100644
index 000000000..d19f87d25
--- /dev/null
+++ b/compiler/exo/src/TFLite/TFLExporterUtils.test.cpp
@@ -0,0 +1,108 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "TFLExporterUtils.h"
+
+#include <gtest/gtest.h>
+
+using namespace exo::tflite_detail;
+
+TEST(ExporterUtilsTests, getOpPadding)
+{
+ loco::Padding2D pad;
+ loco::Stride<2> stride;
+ exo::ShapeDescription ifm;
+ exo::ShapeDescription ofm;
+
+ ifm._dims.resize(4);
+ ofm._dims.resize(4);
+
+ // VALID padding
+ {
+ pad.top(0);
+ pad.bottom(0);
+ pad.left(0);
+ pad.right(0);
+
+ stride.vertical(2);
+ stride.horizontal(2);
+
+ ifm._dims[1] = 5;
+ ifm._dims[2] = 5;
+
+ ofm._dims[1] = 2;
+ ofm._dims[2] = 2;
+
+ ASSERT_EQ(getOpPadding(&pad, &stride, ifm, ofm), tflite::Padding_VALID);
+ }
+
+ // SAME padding
+ {
+ pad.top(1);
+ pad.bottom(1);
+ pad.left(1);
+ pad.right(1);
+
+ stride.vertical(2);
+ stride.horizontal(2);
+
+ ifm._dims[1] = 5;
+ ifm._dims[2] = 5;
+
+ ofm._dims[1] = 3;
+ ofm._dims[2] = 3;
+
+ ASSERT_EQ(getOpPadding(&pad, &stride, ifm, ofm), tflite::Padding_SAME);
+ }
+
+ // Custom padding 1 - Not supported by tflite
+ {
+ pad.top(2);
+ pad.bottom(0);
+ pad.left(1);
+ pad.right(1);
+
+ stride.vertical(2);
+ stride.horizontal(2);
+
+ ifm._dims[1] = 5;
+ ifm._dims[2] = 5;
+
+ ofm._dims[1] = 3;
+ ofm._dims[2] = 3;
+
+ ASSERT_ANY_THROW(getOpPadding(&pad, &stride, ifm, ofm));
+ }
+
+ // Custom padding 2 - Not supported by tflite
+ {
+ pad.top(2);
+ pad.bottom(2);
+ pad.left(2);
+ pad.right(2);
+
+ stride.vertical(2);
+ stride.horizontal(2);
+
+ ifm._dims[1] = 5;
+ ifm._dims[2] = 5;
+
+ ofm._dims[1] = 4;
+ ofm._dims[2] = 4;
+
+ ASSERT_ANY_THROW(getOpPadding(&pad, &stride, ifm, ofm));
+ }
+}
diff --git a/compiler/exo/src/TFLite/TFLOperationExporter.cpp b/compiler/exo/src/TFLite/TFLOperationExporter.cpp
new file mode 100644
index 000000000..79b5b6287
--- /dev/null
+++ b/compiler/exo/src/TFLite/TFLOperationExporter.cpp
@@ -0,0 +1,1199 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "TFLOperationExporter.h"
+#include "TFLExporterUtils.h"
+#include "ShapeInference.h"
+
+#include "Dialect/IR/TFLNode.h"
+#include "Dialect/IR/TFLNodes.h"
+#include "Dialect/IR/TFLNodeVisitor.h"
+
+#include "Check.h"
+
+#include <loco/IR/CanonicalNode.h>
+#include <loco/IR/CanonicalNodeVisitor.h>
+#include <loco/Service/ShapeInference.h>
+#include <locoex/COpCall.h>
+
+#include <oops/InternalExn.h>
+
+#include <flatbuffers/flexbuffers.h>
+
+using namespace flatbuffers;
+using namespace tflite;
+
+namespace
+{
+
+using namespace exo;
+using namespace exo::tflite_detail;
+
+class OperationExporter final : public locoex::TFLNodeMutableVisitor<void>,
+ public loco::CanonicalNodeMutableVisitor<void>
+{
+public:
+ OperationExporter(FlatBufferBuilder &fbb, SerializedModelData &ctx) : builder{fbb}, gd{ctx}
+ {
+ // DO NOTHING
+ }
+
+public:
+ // FOR TFLNodes
+ void visit(locoex::TFLAdd *) final;
+ void visit(locoex::TFLAveragePool2D *) final;
+ void visit(locoex::TFLConcatenation *) final;
+ void visit(locoex::TFLConst *) final { /* skip, everything is done in exportOpDefinedTensors */ }
+ void visit(locoex::TFLConv2D *) final;
+ void visit(locoex::TFLDepthwiseConv2D *) final;
+ void visit(locoex::TFLDiv *) final;
+ void visit(locoex::TFLFullyConnected *) final;
+ void visit(locoex::TFLMaximum *) final;
+ void visit(locoex::TFLMaxPool2D *) final;
+ void visit(locoex::TFLMean *) final;
+ void visit(locoex::TFLMul *) final;
+ void visit(locoex::TFLRelu *) final;
+ void visit(locoex::TFLRelu6 *) final;
+ // TODO TFLReshape
+ void visit(locoex::TFLRsqrt *) final;
+ // TODO TFLSoftmax
+ void visit(locoex::TFLSqrt *) final;
+ void visit(locoex::TFLSquaredDifference *) final;
+ void visit(locoex::TFLSub *) final;
+ // TODO TFLTanh
+ void visit(locoex::TFLTranspose *) final;
+ void visit(locoex::TFLTransposeConv *) final;
+
+ // FOR canonical nodes. These will be removed later
+ void visit(loco::ReLU *) final;
+ void visit(loco::ReLU6 *) final;
+ void visit(loco::Tanh *) final;
+ void visit(loco::Push *) final { /* DO NOTHING */}
+ void visit(loco::Pull *) final { /* DO NOTHING */}
+ void visit(loco::FeatureEncode *) final;
+ void visit(loco::FeatureDecode *) final;
+ void visit(loco::FilterEncode *) final;
+ void visit(loco::DepthwiseFilterEncode *) final;
+ void visit(loco::ConstGen *) final { /* skip, everything is done in exportOpDefinedTensors */}
+ void visit(loco::MaxPool2D *) final;
+ void visit(loco::AvgPool2D *) final;
+ void visit(loco::Conv2D *) final;
+ void visit(loco::TransposedConv2D *) final;
+ void visit(loco::DepthwiseConv2D *) final;
+ void visit(loco::TensorConcat *) final;
+ void visit(loco::TensorReduce *) final;
+ void visit(loco::TensorSoftmax *) final;
+ void visit(loco::BiasEncode *) final;
+ void visit(loco::TensorBiasAdd *) final;
+ void visit(loco::FeatureBiasAdd *) final;
+ void visit(loco::EltwiseAdd *) final;
+ void visit(loco::EltwiseMax *) final;
+ void visit(loco::EltwiseMul *) final;
+ void visit(loco::EltwiseSub *) final;
+ void visit(loco::EltwiseDiv *) final;
+ void visit(loco::EltwiseSqrt *) final;
+ void visit(loco::FixedReshape *) final;
+ void visit(loco::TensorBroadcast *) final;
+ void visit(loco::TensorConstantPad *) final;
+
+ void visit(locoex::COpCall *);
+
+private:
+ /**
+ * @brief Exports TFLMaxPool2D or TFLAveragePool2D
+ *
+ * @note TFLPool2D should be one of TFLMaxPool2D or TFLAveragePool2D
+ */
+ template <class TFLPool2D>
+ void export_pool_2d(TFLPool2D *node, tflite::BuiltinOperator builtin_op);
+
+private:
+ FlatBufferBuilder &builder;
+ SerializedModelData &gd;
+};
+
+void OperationExporter::visit(locoex::TFLAdd *node)
+{
+ uint32_t op_idx = gd.registerBuiltinOpcode(tflite::BuiltinOperator_ADD);
+ std::vector<int32_t> inputs_vec{get_tensor_index(node->x()), get_tensor_index(node->y())};
+ std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))};
+ auto inputs = builder.CreateVector(inputs_vec);
+ auto outputs = builder.CreateVector(outputs_vec);
+ auto options = CreateAddOptions(builder, to_tflite_actfunc(node->fusedActivationFunction()));
+ auto op_offset = CreateOperator(builder, op_idx, inputs, outputs,
+ tflite::BuiltinOptions_AddOptions, options.Union());
+ gd._operators.push_back(op_offset);
+}
+
+void OperationExporter::visit(locoex::TFLAveragePool2D *node)
+{
+ export_pool_2d<locoex::TFLAveragePool2D>(node, tflite::BuiltinOperator_AVERAGE_POOL_2D);
+}
+
+void OperationExporter::visit(locoex::TFLConcatenation *node)
+{
+ uint32_t op_idx = gd.registerBuiltinOpcode(tflite::BuiltinOperator_CONCATENATION);
+ std::vector<int32_t> inputs_vec;
+ std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))};
+
+ for (uint32_t i = 0; i < node->numValues(); ++i)
+ inputs_vec.push_back(get_tensor_index(node->values(i)));
+
+ auto inputs = builder.CreateVector(inputs_vec);
+ auto outputs = builder.CreateVector(outputs_vec);
+ auto options = CreateConcatenationOptions(builder, node->axis(),
+ to_tflite_actfunc(node->fusedActivationFunction()));
+ auto op_offset = CreateOperator(builder, op_idx, inputs, outputs,
+ tflite::BuiltinOptions_ConcatenationOptions, options.Union());
+ gd._operators.push_back(op_offset);
+}
+
+void OperationExporter::visit(locoex::TFLConv2D *node)
+{
+ uint32_t op_idx = gd.registerBuiltinOpcode(tflite::BuiltinOperator_CONV_2D);
+
+ // Make input, output and options for operator
+ std::vector<int32_t> inputs_vec{get_tensor_index(node->input()), get_tensor_index(node->filter()),
+ get_tensor_index(node->bias())};
+ std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))};
+ auto inputs = builder.CreateVector(inputs_vec);
+ auto outputs = builder.CreateVector(outputs_vec);
+ tflite::Padding padding = getOpPadding(node->padding());
+ auto options = CreateConv2DOptions(builder, padding, node->stride()->w(), node->stride()->h(),
+ to_tflite_actfunc(node->fusedActivationFunction()));
+
+ // Make CONV_2D operator
+ auto op_offset = CreateOperator(builder, op_idx, inputs, outputs,
+ tflite::BuiltinOptions_Conv2DOptions, options.Union());
+ gd._operators.push_back(op_offset);
+}
+
+void OperationExporter::visit(locoex::TFLDepthwiseConv2D *node)
+{
+ uint32_t op_idx = gd.registerBuiltinOpcode(tflite::BuiltinOperator_DEPTHWISE_CONV_2D);
+
+ // Make input, output and options for operator
+ std::vector<int32_t> inputs_vec{get_tensor_index(node->input()), get_tensor_index(node->filter()),
+ get_tensor_index(node->bias())};
+ std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))};
+ auto inputs = builder.CreateVector(inputs_vec);
+ auto outputs = builder.CreateVector(outputs_vec);
+ tflite::Padding padding = getOpPadding(node->padding());
+ auto options = CreateDepthwiseConv2DOptions(builder, padding, node->stride()->w(),
+ node->stride()->h(), node->depthMultiplier(),
+ to_tflite_actfunc(node->fusedActivationFunction()));
+
+ // Make DEPTHWISE_CONV_2D operator
+ auto op_offset = CreateOperator(builder, op_idx, inputs, outputs,
+ tflite::BuiltinOptions_DepthwiseConv2DOptions, options.Union());
+ gd._operators.push_back(op_offset);
+}
+
+void OperationExporter::visit(locoex::TFLDiv *node)
+{
+ uint32_t op_idx = gd.registerBuiltinOpcode(tflite::BuiltinOperator_DIV);
+ std::vector<int32_t> inputs_vec{get_tensor_index(node->x()), get_tensor_index(node->y())};
+ std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))};
+ auto inputs = builder.CreateVector(inputs_vec);
+ auto outputs = builder.CreateVector(outputs_vec);
+ auto options = CreateDivOptions(builder, to_tflite_actfunc(node->fusedActivationFunction()));
+ auto op_offset = CreateOperator(builder, op_idx, inputs, outputs,
+ tflite::BuiltinOptions_DivOptions, options.Union());
+ gd._operators.push_back(op_offset);
+}
+
+void OperationExporter::visit(locoex::TFLFullyConnected *node)
+{
+ uint32_t op_idx = gd.registerBuiltinOpcode(tflite::BuiltinOperator_FULLY_CONNECTED);
+
+ // Make input, output and options for operator
+ std::vector<int32_t> inputs_vec{get_tensor_index(node->input()),
+ get_tensor_index(node->weights()),
+ get_tensor_index(node->bias())};
+ std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))};
+ auto inputs = builder.CreateVector(inputs_vec);
+ auto outputs = builder.CreateVector(outputs_vec);
+ auto options =
+ CreateFullyConnectedOptions(builder, to_tflite_actfunc(node->fusedActivationFunction()));
+
+ // Make FULLY_CONNECTED operator
+ auto op_offset = CreateOperator(builder, op_idx, inputs, outputs,
+ tflite::BuiltinOptions_FullyConnectedOptions, options.Union());
+ gd._operators.push_back(op_offset);
+}
+
+void OperationExporter::visit(locoex::TFLMaximum *node)
+{
+ uint32_t op_idx = gd.registerBuiltinOpcode(tflite::BuiltinOperator_MAXIMUM);
+ std::vector<int32_t> inputs_vec{get_tensor_index(node->x()), get_tensor_index(node->y())};
+ std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))};
+ auto inputs = builder.CreateVector(inputs_vec);
+ auto outputs = builder.CreateVector(outputs_vec);
+ auto options = CreateMaximumMinimumOptions(builder);
+ auto op_offset = CreateOperator(builder, op_idx, inputs, outputs,
+ tflite::BuiltinOptions_MaximumMinimumOptions, options.Union());
+ gd._operators.push_back(op_offset);
+}
+
+void OperationExporter::visit(locoex::TFLMaxPool2D *node)
+{
+ export_pool_2d<locoex::TFLMaxPool2D>(node, tflite::BuiltinOperator_MAX_POOL_2D);
+}
+
+void OperationExporter::visit(locoex::TFLMean *node)
+{
+ uint32_t op_idx = gd.registerBuiltinOpcode(tflite::BuiltinOperator_MEAN);
+ std::vector<int32_t> inputs_vec{get_tensor_index(node->input()),
+ get_tensor_index(node->reduction_indices())};
+ std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))};
+ auto inputs = builder.CreateVector(inputs_vec);
+ auto outputs = builder.CreateVector(outputs_vec);
+ auto options = CreateReducerOptions(builder, node->keep_dims());
+ auto op_offset = CreateOperator(builder, op_idx, inputs, outputs,
+ tflite::BuiltinOptions_ReducerOptions, options.Union());
+ gd._operators.push_back(op_offset);
+}
+
+void OperationExporter::visit(locoex::TFLMul *node)
+{
+ uint32_t op_idx = gd.registerBuiltinOpcode(tflite::BuiltinOperator_MUL);
+ std::vector<int32_t> inputs_vec{get_tensor_index(node->x()), get_tensor_index(node->y())};
+ std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))};
+ auto inputs = builder.CreateVector(inputs_vec);
+ auto outputs = builder.CreateVector(outputs_vec);
+ auto options = CreateMulOptions(builder, to_tflite_actfunc(node->fusedActivationFunction()));
+ auto op_offset = CreateOperator(builder, op_idx, inputs, outputs,
+ tflite::BuiltinOptions_MulOptions, options.Union());
+ gd._operators.push_back(op_offset);
+}
+
+void OperationExporter::visit(locoex::TFLRelu *node)
+{
+ uint32_t op_idx = gd.registerBuiltinOpcode(tflite::BuiltinOperator_RELU);
+ std::vector<int32_t> inputs_vec{get_tensor_index(node->features())};
+ std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))};
+ auto inputs = builder.CreateVector(inputs_vec);
+ auto outputs = builder.CreateVector(outputs_vec);
+ auto op_offset = CreateOperator(builder, op_idx, inputs, outputs);
+ gd._operators.push_back(op_offset);
+}
+
+void OperationExporter::visit(locoex::TFLRelu6 *node)
+{
+ uint32_t op_idx = gd.registerBuiltinOpcode(tflite::BuiltinOperator_RELU6);
+ std::vector<int32_t> inputs_vec{get_tensor_index(node->features())};
+ std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))};
+ auto inputs = builder.CreateVector(inputs_vec);
+ auto outputs = builder.CreateVector(outputs_vec);
+ auto op_offset = CreateOperator(builder, op_idx, inputs, outputs);
+ gd._operators.push_back(op_offset);
+}
+
+// TODO TFLReshape
+
+void OperationExporter::visit(locoex::TFLRsqrt *node)
+{
+ uint32_t op_idx = gd.registerBuiltinOpcode(tflite::BuiltinOperator_RSQRT);
+ std::vector<int32_t> inputs_vec{get_tensor_index(node->x())};
+ std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))};
+ auto inputs = builder.CreateVector(inputs_vec);
+ auto outputs = builder.CreateVector(outputs_vec);
+ auto op_offset = CreateOperator(builder, op_idx, inputs, outputs);
+ gd._operators.push_back(op_offset);
+}
+
+// TODO TFLSoftmax
+
+void OperationExporter::visit(locoex::TFLSqrt *node)
+{
+ uint32_t op_idx = gd.registerBuiltinOpcode(tflite::BuiltinOperator_SQRT);
+ std::vector<int32_t> inputs_vec{get_tensor_index(node->x())};
+ std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))};
+ auto inputs = builder.CreateVector(inputs_vec);
+ auto outputs = builder.CreateVector(outputs_vec);
+ auto op_offset = CreateOperator(builder, op_idx, inputs, outputs);
+ gd._operators.push_back(op_offset);
+}
+
+void OperationExporter::visit(locoex::TFLSquaredDifference *node)
+{
+ uint32_t op_idx = gd.registerBuiltinOpcode(tflite::BuiltinOperator_SQUARED_DIFFERENCE);
+ std::vector<int32_t> inputs_vec{get_tensor_index(node->x()), get_tensor_index(node->y())};
+ std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))};
+ auto inputs = builder.CreateVector(inputs_vec);
+ auto outputs = builder.CreateVector(outputs_vec);
+ auto options = CreateSquaredDifferenceOptions(builder);
+ auto op_offset = CreateOperator(builder, op_idx, inputs, outputs,
+ tflite::BuiltinOptions_SquaredDifferenceOptions, options.Union());
+ gd._operators.push_back(op_offset);
+}
+
+void OperationExporter::visit(locoex::TFLSub *node)
+{
+ uint32_t op_idx = gd.registerBuiltinOpcode(tflite::BuiltinOperator_SUB);
+ std::vector<int32_t> inputs_vec{get_tensor_index(node->x()), get_tensor_index(node->y())};
+ std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))};
+ auto inputs = builder.CreateVector(inputs_vec);
+ auto outputs = builder.CreateVector(outputs_vec);
+ auto options = CreateSubOptions(builder, to_tflite_actfunc(node->fusedActivationFunction()));
+ auto op_offset = CreateOperator(builder, op_idx, inputs, outputs,
+ tflite::BuiltinOptions_SubOptions, options.Union());
+ gd._operators.push_back(op_offset);
+}
+
+// TODO TFLTanh
+
+void OperationExporter::visit(locoex::TFLTranspose *node)
+{
+ uint32_t op_idx = gd.registerBuiltinOpcode(tflite::BuiltinOperator_TRANSPOSE);
+ std::vector<int32_t> inputs_vec{get_tensor_index(node->arg(0)), get_tensor_index(node->arg(1))};
+ std::vector<int32_t> outputs_vec{get_tensor_index(node)};
+
+ auto inputs = builder.CreateVector(inputs_vec);
+ auto outputs = builder.CreateVector(outputs_vec);
+ auto options = CreateTransposeOptions(builder);
+
+ auto op_offset =
+ CreateOperator(builder, op_idx, inputs, outputs,
+ tflite::BuiltinOptions::BuiltinOptions_TransposeOptions, options.Union());
+ gd._operators.push_back(op_offset);
+}
+
+void OperationExporter::visit(locoex::TFLTransposeConv *node)
+{
+ uint32_t op_idx = gd.registerBuiltinOpcode(tflite::BuiltinOperator_TRANSPOSE_CONV);
+
+ // Make input, output and options for operator
+ std::vector<int32_t> inputs_vec{get_tensor_index(node->inputSizes()),
+ get_tensor_index(node->filter()),
+ get_tensor_index(node->outBackprop())};
+ std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))};
+ auto inputs = builder.CreateVector(inputs_vec);
+ auto outputs = builder.CreateVector(outputs_vec);
+ tflite::Padding padding = getOpPadding(node->padding());
+ auto options =
+ CreateTransposeConvOptions(builder, padding, node->stride()->w(), node->stride()->h());
+
+ // Make TRANSPOSE_CONV operator
+ auto op_offset = CreateOperator(builder, op_idx, inputs, outputs,
+ tflite::BuiltinOptions_TransposeConvOptions, options.Union());
+ gd._operators.push_back(op_offset);
+}
+
+template <class TFLPool2D>
+void OperationExporter::export_pool_2d(TFLPool2D *node, tflite::BuiltinOperator builtin_op)
+{
+ EXO_ASSERT(builtin_op == tflite::BuiltinOperator_MAX_POOL_2D ||
+ builtin_op == tflite::BuiltinOperator_AVERAGE_POOL_2D,
+ "should be maxpool or avgpool");
+ EXO_ASSERT(node->padding() != locoex::Padding::UNDEFINED, "Padding is not set");
+
+ uint32_t op_idx = gd.registerBuiltinOpcode(builtin_op);
+ std::vector<int32_t> inputs_vec{get_tensor_index(node->value())};
+ std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))};
+ auto inputs = builder.CreateVector(inputs_vec);
+ auto outputs = builder.CreateVector(outputs_vec);
+
+ tflite::Padding padding = getOpPadding(node->padding());
+
+ auto options = CreatePool2DOptions(builder, padding, node->stride()->w(), node->stride()->h(),
+ node->filter()->w(), node->filter()->h(),
+ to_tflite_actfunc(node->fusedActivationFunction()));
+ auto op_offset = CreateOperator(builder, op_idx, inputs, outputs,
+ tflite::BuiltinOptions_Pool2DOptions, options.Union());
+ gd._operators.push_back(op_offset);
+}
+
+void OperationExporter::visit(loco::ReLU *node)
+{
+ uint32_t op_idx = gd.registerBuiltinOpcode(tflite::BuiltinOperator_RELU);
+ std::vector<int32_t> inputs_vec{get_tensor_index(node->input())};
+ std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))};
+ auto inputs = builder.CreateVector(inputs_vec);
+ auto outputs = builder.CreateVector(outputs_vec);
+ auto op_offset = CreateOperator(builder, op_idx, inputs, outputs);
+ gd._operators.push_back(op_offset);
+}
+
+void OperationExporter::visit(loco::ReLU6 *node)
+{
+ uint32_t op_idx = gd.registerBuiltinOpcode(tflite::BuiltinOperator_RELU6);
+ std::vector<int32_t> inputs_vec{get_tensor_index(node->input())};
+ std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))};
+ auto inputs = builder.CreateVector(inputs_vec);
+ auto outputs = builder.CreateVector(outputs_vec);
+ auto op_offset = CreateOperator(builder, op_idx, inputs, outputs);
+ gd._operators.push_back(op_offset);
+}
+
+void OperationExporter::visit(loco::Tanh *node)
+{
+ uint32_t op_idx = gd.registerBuiltinOpcode(tflite::BuiltinOperator_TANH);
+ std::vector<int32_t> inputs_vec{get_tensor_index(node->input())};
+ std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))};
+ auto inputs = builder.CreateVector(inputs_vec);
+ auto outputs = builder.CreateVector(outputs_vec);
+ auto op_offset = CreateOperator(builder, op_idx, inputs, outputs);
+ gd._operators.push_back(op_offset);
+}
+
+void OperationExporter::visit(loco::MaxPool2D *node)
+{
+ uint32_t op_idx = gd.registerBuiltinOpcode(tflite::BuiltinOperator_MAX_POOL_2D);
+ std::vector<int32_t> inputs_vec{get_tensor_index(node->ifm())};
+ std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))};
+ auto inputs = builder.CreateVector(inputs_vec);
+ auto outputs = builder.CreateVector(outputs_vec);
+ tflite::Padding padding = getOpPadding(
+ node->pad(), node->stride(), ShapeInference::get(node->ifm()), ShapeInference::get(node));
+ auto options = CreatePool2DOptions(builder, padding, node->stride()->horizontal(),
+ node->stride()->vertical(), node->window()->horizontal(),
+ node->window()->vertical());
+ auto op_offset = CreateOperator(builder, op_idx, inputs, outputs,
+ tflite::BuiltinOptions_Pool2DOptions, options.Union());
+ gd._operators.push_back(op_offset);
+}
+
+void OperationExporter::visit(loco::AvgPool2D *node)
+{
+ // TFLite only supports the Valid convention of average pooling
+ assert(node->convention() == loco::AvgPool2D::Convention::Valid);
+
+ uint32_t op_idx = gd.registerBuiltinOpcode(tflite::BuiltinOperator_AVERAGE_POOL_2D);
+ std::vector<int32_t> inputs_vec{get_tensor_index(node->ifm())};
+ std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))};
+ auto inputs = builder.CreateVector(inputs_vec);
+ auto outputs = builder.CreateVector(outputs_vec);
+ tflite::Padding padding = getOpPadding(
+ node->pad(), node->stride(), ShapeInference::get(node->ifm()), ShapeInference::get(node));
+ auto options = CreatePool2DOptions(builder, padding, node->stride()->horizontal(),
+ node->stride()->vertical(), node->window()->horizontal(),
+ node->window()->vertical());
+ auto op_offset = CreateOperator(builder, op_idx, inputs, outputs,
+ tflite::BuiltinOptions_Pool2DOptions, options.Union());
+ gd._operators.push_back(op_offset);
+}
+
+void OperationExporter::visit(loco::Conv2D *node)
+{
+ uint32_t op_idx = gd.registerBuiltinOpcode(tflite::BuiltinOperator_CONV_2D);
+
+ // The third input of tflite CONV_2D must be a bias. We make (and register to gd) a dummy
+ // zero bias: rank 1, with size equal to the output kernel count, and all values zero.
+ auto *ker = dynamic_cast<loco::FilterEncode *>(node->ker());
+ assert(ker);
+ int32_t bias_vec_size = ShapeInference::get(ker)._dims[0]; // output kernel count
+
+ auto bias_vec_shape_offset = builder.CreateVector(std::vector<int32_t>{bias_vec_size});
+ size_t raw_bias_vec_size = bias_vec_size * sizeof(float); // bias data is FLOAT32
+
+ std::vector<float> bias_vec_data(bias_vec_size); // initialized as zero vector
+
+ auto bias_vec_offset =
+ builder.CreateVector(reinterpret_cast<uint8_t *>(bias_vec_data.data()), raw_bias_vec_size);
+
+ auto bias_buffer_offset = CreateBuffer(builder, bias_vec_offset);
+
+ const auto bias_buffer_id = static_cast<uint32_t>(gd._buffers.size());
+
+ gd._buffers.push_back(bias_buffer_offset);
+
+ auto bias_tensor_id = static_cast<int32_t>(gd._tensors.size());
+ auto name_offset = builder.CreateString("t_" + std::to_string(bias_tensor_id));
+
+ auto bias_tensor_offset =
+ CreateTensor(builder, bias_vec_shape_offset, TensorType_FLOAT32, bias_buffer_id, name_offset);
+ gd._tensors.push_back(bias_tensor_offset);
+
+ // Make input, output and options for operator
+ std::vector<int32_t> inputs_vec{get_tensor_index(node->ifm()), get_tensor_index(node->ker()),
+ bias_tensor_id};
+ std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))};
+ auto inputs = builder.CreateVector(inputs_vec);
+ auto outputs = builder.CreateVector(outputs_vec);
+ tflite::Padding padding = getOpPadding(
+ node->pad(), node->stride(), ShapeInference::get(node->ifm()), ShapeInference::get(node));
+ auto options = CreateConv2DOptions(builder, padding, node->stride()->horizontal(),
+ node->stride()->vertical());
+
+ // Make CONV_2D operator
+ auto op_offset = CreateOperator(builder, op_idx, inputs, outputs,
+ tflite::BuiltinOptions_Conv2DOptions, options.Union());
+ gd._operators.push_back(op_offset);
+}
+
+void OperationExporter::visit(loco::TransposedConv2D *node)
+{
+ uint32_t op_idx = gd.registerBuiltinOpcode(tflite::BuiltinOperator_TRANSPOSE_CONV);
+
+ // TRANSPOSE_CONV's first input is output shape array.
+ const int32_t outshape_vec_size = 4;
+ auto outshape_vec_shape_offset = builder.CreateVector(std::vector<int32_t>{outshape_vec_size});
+ size_t raw_outshape_vec_size = outshape_vec_size * sizeof(int32_t);
+
+ std::vector<int32_t> outshape_vec_data(outshape_vec_size);
+ {
+ // Copy inferred output shape of node
+ auto out_feature_shape = loco::shape_get(node).as<loco::FeatureShape>();
+
+ // Feature tensor in TFLite is NHWC
+ outshape_vec_data.at(0) = out_feature_shape.count().value();
+ outshape_vec_data.at(1) = out_feature_shape.height().value();
+ outshape_vec_data.at(2) = out_feature_shape.width().value();
+ outshape_vec_data.at(3) = out_feature_shape.depth().value();
+ }
+
+ auto outshape_vec_offset = builder.CreateVector(
+ reinterpret_cast<uint8_t *>(outshape_vec_data.data()), raw_outshape_vec_size);
+
+ auto outshape_buffer_offset = CreateBuffer(builder, outshape_vec_offset);
+
+ const auto outshape_buffer_id = static_cast<uint32_t>(gd._buffers.size());
+
+ gd._buffers.push_back(outshape_buffer_offset);
+
+ auto outshape_tensor_id = static_cast<int32_t>(gd._tensors.size());
+ auto name_offset = builder.CreateString("t_" + std::to_string(outshape_tensor_id));
+
+ auto outshape_tensor_offset = CreateTensor(builder, outshape_vec_shape_offset, TensorType_INT32,
+ outshape_buffer_id, name_offset);
+ gd._tensors.push_back(outshape_tensor_offset);
+
+ // Make input, output and options for operator
+ std::vector<int32_t> inputs_vec{outshape_tensor_id, get_tensor_index(node->ker()),
+ get_tensor_index(node->ifm())};
+ std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))};
+ auto inputs = builder.CreateVector(inputs_vec);
+ auto outputs = builder.CreateVector(outputs_vec);
+ // NOTE input and output are swapped here: for TRANSPOSE_CONV, the forward-convolution
+ // padding relation holds with ifm and ofm exchanged, which lets us reuse getOpPadding
+ tflite::Padding padding = getOpPadding(node->pad(), node->stride(), ShapeInference::get(node),
+ ShapeInference::get(node->ifm()));
+ auto options = CreateTransposeConvOptions(builder, padding, node->stride()->horizontal(),
+ node->stride()->vertical());
+
+ // Make TRANSPOSE_CONV operator
+ auto op_offset = CreateOperator(builder, op_idx, inputs, outputs,
+ tflite::BuiltinOptions_TransposeConvOptions, options.Union());
+ gd._operators.push_back(op_offset);
+}
+
+void OperationExporter::visit(loco::DepthwiseConv2D *node)
+{
+ uint32_t op_idx = gd.registerBuiltinOpcode(tflite::BuiltinOperator_DEPTHWISE_CONV_2D);
+
+ // The third input of tflite DEPTHWISE_CONV_2D must be a bias. We make (and register to gd)
+ // a dummy zero bias: rank 1, with size equal to the output channel count (C*M), and all
+ // values zero.
+ auto *ker = dynamic_cast<loco::DepthwiseFilterEncode *>(node->ker());
+ assert(ker);
+
+ int32_t bias_vec_size = ShapeInference::get(ker)._dims[3]; // output_size(C*M)
+ auto bias_vec_shape_offset = builder.CreateVector(std::vector<int32_t>{bias_vec_size});
+
+ size_t raw_bias_vec_size = bias_vec_size * sizeof(float); // bias data is FLOAT32
+ std::vector<float> bias_vec_data(bias_vec_size);
+ auto bias_vec_offset =
+ builder.CreateVector(reinterpret_cast<uint8_t *>(bias_vec_data.data()), raw_bias_vec_size);
+
+ auto bias_buffer_offset = CreateBuffer(builder, bias_vec_offset);
+
+ const auto bias_buffer_id = static_cast<uint32_t>(gd._buffers.size());
+
+ gd._buffers.push_back(bias_buffer_offset);
+
+ auto bias_tensor_id = static_cast<int32_t>(gd._tensors.size());
+ auto name_offset = builder.CreateString("t_" + std::to_string(bias_tensor_id));
+
+ auto bias_tensor_offset =
+ CreateTensor(builder, bias_vec_shape_offset, TensorType_FLOAT32, bias_buffer_id, name_offset);
+ gd._tensors.push_back(bias_tensor_offset);
+
+ std::vector<int32_t> inputs_vec{get_tensor_index(node->ifm()), get_tensor_index(node->ker()),
+ bias_tensor_id};
+ std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))};
+ auto inputs = builder.CreateVector(inputs_vec);
+ auto outputs = builder.CreateVector(outputs_vec);
+ tflite::Padding padding = getOpPadding(
+ node->pad(), node->stride(), ShapeInference::get(node->ifm()), ShapeInference::get(node));
+
+ int32_t ifm_channel_size = ShapeInference::get(node->ifm())._dims[3];
+ // multiplier = bias_vec_size(output_size)/ifm_channel_size
+ auto options =
+ CreateDepthwiseConv2DOptions(builder, padding, node->stride()->horizontal(),
+ node->stride()->vertical(), bias_vec_size / ifm_channel_size);
+
+ auto op_offset = CreateOperator(builder, op_idx, inputs, outputs,
+ tflite::BuiltinOptions_DepthwiseConv2DOptions, options.Union());
+ gd._operators.push_back(op_offset);
+}
+
+void OperationExporter::visit(loco::TensorReduce *node)
+{
+ uint32_t op_idx;
+
+ switch (node->func())
+ {
+ case loco::ReduceFunc::Mean:
+ op_idx = gd.registerBuiltinOpcode(tflite::BuiltinOperator_MEAN);
+ break;
+
+ // TODO Support more reduce type operation
+ default:
+ INTERNAL_EXN_V("Not supported reduce type", oops::to_uint32(node->func()));
+ }
+
+ // Create a vector for axes data
+ std::vector<int32_t> axes_vec;
+ auto rank = ShapeInference::get(node->input())._dims.size();
+ for (uint32_t i = 0; i < rank; ++i)
+ if (node->axes()->defined(i))
+ axes_vec.push_back(i);
+
+ int32_t axes_vec_size = axes_vec.size();
+ auto axes_vec_shape_offset = builder.CreateVector(std::vector<int32_t>{axes_vec_size});
+
+ size_t raw_axes_vec_size = axes_vec_size * sizeof(int32_t);
+ auto axes_vec_offset =
+ builder.CreateVector(reinterpret_cast<uint8_t *>(axes_vec.data()), raw_axes_vec_size);
+
+ auto axes_buffer_offset = CreateBuffer(builder, axes_vec_offset);
+
+ const auto axes_buffer_id = static_cast<uint32_t>(gd._buffers.size());
+
+ gd._buffers.push_back(axes_buffer_offset);
+
+ auto axes_tensor_id = static_cast<int32_t>(gd._tensors.size());
+ auto name_offset = builder.CreateString("t_" + std::to_string(axes_tensor_id));
+
+ auto axes_tensor_offset =
+ CreateTensor(builder, axes_vec_shape_offset, TensorType_INT32, axes_buffer_id, name_offset);
+ gd._tensors.push_back(axes_tensor_offset);
+
+ std::vector<int32_t> inputs_vec{get_tensor_index(node->input()), axes_tensor_id};
+ std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))};
+ auto inputs = builder.CreateVector(inputs_vec);
+ auto outputs = builder.CreateVector(outputs_vec);
+ auto options = CreateReducerOptions(builder, true); // true is for keep_dims option
+ auto op_offset = CreateOperator(builder, op_idx, inputs, outputs,
+ tflite::BuiltinOptions_ReducerOptions, options.Union());
+ gd._operators.push_back(op_offset);
+}
+
+void OperationExporter::visit(loco::TensorSoftmax *node)
+{
+ // TODO Support when the input rank of TensorSoftmax is not 2
+ assert(ShapeInference::get(node->input())._dims.size() == 2);
+
+ // NOTE TFLite only accepts Softmax when the axis is the last dimension
+ assert(node->axis() == ShapeInference::get(node->input())._dims.size() - 1);
+
+ uint32_t op_idx = gd.registerBuiltinOpcode(tflite::BuiltinOperator_SOFTMAX);
+ std::vector<int32_t> inputs_vec{get_tensor_index(node->input())};
+ std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))};
+ auto inputs = builder.CreateVector(inputs_vec);
+ auto outputs = builder.CreateVector(outputs_vec);
+ auto options = CreateSoftmaxOptions(builder, 1.0f);
+ auto op_offset = CreateOperator(builder, op_idx, inputs, outputs,
+ tflite::BuiltinOptions_SoftmaxOptions, options.Union());
+ gd._operators.push_back(op_offset);
+}
+
+/// @brief Export the given node as an identity, i.e. a CONCATENATION with a single input
+template <typename NodeT>
+void exportIdentity(NodeT *node, FlatBufferBuilder &builder, SerializedModelData &gd)
+{
+ uint32_t op_idx = gd.registerBuiltinOpcode(tflite::BuiltinOperator_CONCATENATION);
+ std::vector<int32_t> inputs_vec{get_tensor_index(node->arg(0))};
+ std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))};
+ auto inputs = builder.CreateVector(inputs_vec);
+ auto outputs = builder.CreateVector(outputs_vec);
+ auto options = CreateConcatenationOptions(builder); // use dummy 0 axis and NONE activation
+ auto op_offset = CreateOperator(builder, op_idx, inputs, outputs,
+ tflite::BuiltinOptions_ConcatenationOptions, options.Union());
+
+ gd._operators.push_back(op_offset);
+}
+
+/// @brief Export loco nodes as TRANSPOSE
+void exportAsTranspose(loco::Node *node, FlatBufferBuilder &builder,
+ std::vector<int32_t> &perm_vec_data, SerializedModelData &gd)
+{
+ uint32_t op_idx = gd.registerBuiltinOpcode(tflite::BuiltinOperator_TRANSPOSE);
+
+ auto options = CreateTransposeOptions(builder);
+
+ // Create constant tensor with perm vector
+ constexpr int perm_vec_size = 4;
+ assert(perm_vec_data.size() == perm_vec_size);
+ auto perm_vec_shape_offset = builder.CreateVector(std::vector<int32_t>{perm_vec_size});
+ constexpr size_t raw_perm_vec_size = perm_vec_size * sizeof(int32_t);
+
+ auto perm_vec_offset =
+ builder.CreateVector(reinterpret_cast<uint8_t *>(perm_vec_data.data()), raw_perm_vec_size);
+
+ auto perm_buffer_offset = CreateBuffer(builder, perm_vec_offset);
+
+ const auto perm_buffer_id = static_cast<uint32_t>(gd._buffers.size());
+
+ gd._buffers.push_back(perm_buffer_offset);
+
+ auto perm_tensor_id = static_cast<int32_t>(gd._tensors.size());
+ auto name_offset = builder.CreateString("t_" + std::to_string(perm_tensor_id));
+
+ auto perm_tensor_offset =
+ CreateTensor(builder, perm_vec_shape_offset, TensorType_INT32, perm_buffer_id, name_offset);
+ gd._tensors.push_back(perm_tensor_offset);
+
+ // Create permutation node
+
+ std::vector<int32_t> inputs_vec{get_tensor_index(node->arg(0)), perm_tensor_id};
+ std::vector<int32_t> outputs_vec{get_tensor_index(node)};
+
+ auto inputs = builder.CreateVector(inputs_vec);
+ auto outputs = builder.CreateVector(outputs_vec);
+
+ constexpr auto options_type = tflite::BuiltinOptions::BuiltinOptions_TransposeOptions;
+
+ auto transpose_offset =
+ CreateOperator(builder, op_idx, inputs, outputs, options_type, options.Union());
+ gd._operators.push_back(transpose_offset);
+}
+
+void OperationExporter::visit(loco::FeatureEncode *node)
+{
+ auto encoder = dynamic_cast<loco::PermutingEncoder<loco::Domain::Feature> *>(node->encoder());
+ auto perm = encoder->perm();
+
+ if (isNHWC(perm))
+ {
+ // Note that tflite represents feature as NHWC
+ exportIdentity(node, builder, gd);
+ }
+ else
+ {
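+ // TFLite TRANSPOSE takes output axis i from input axis perm_vec_data[i]; the NHWC
+ // output axes 0..3 thus come from the tensor axes holding Count/Height/Width/Depth.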
+ std::vector<int32_t> perm_vec_data(4);
+ perm_vec_data[0] = perm->axis(loco::FeatureAxis::Count);
+ perm_vec_data[1] = perm->axis(loco::FeatureAxis::Height);
+ perm_vec_data[2] = perm->axis(loco::FeatureAxis::Width);
+ perm_vec_data[3] = perm->axis(loco::FeatureAxis::Depth);
+
+ exportAsTranspose(node, builder, perm_vec_data, gd);
+ }
+}
+
+void OperationExporter::visit(loco::FeatureDecode *node)
+{
+ auto decoder = dynamic_cast<loco::PermutingDecoder<loco::Domain::Feature> *>(node->decoder());
+ auto perm = decoder->perm();
+
+ if (isNHWC(perm))
+ {
+ // Note that tflite represents feature as NHWC
+ exportIdentity(node, builder, gd);
+ }
+ else
+ {
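+ // Decode is the inverse of encode: the canonical NHWC input axis for FeatureAxis X
+ // lands on output axis perm->axis(X), so this builds the inverse permutation of the
+ // encode case above.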
+ std::vector<int32_t> perm_vec_data(4);
+ perm_vec_data[perm->axis(loco::FeatureAxis::Count)] = 0;
+ perm_vec_data[perm->axis(loco::FeatureAxis::Height)] = 1;
+ perm_vec_data[perm->axis(loco::FeatureAxis::Width)] = 2;
+ perm_vec_data[perm->axis(loco::FeatureAxis::Depth)] = 3;
+
+ exportAsTranspose(node, builder, perm_vec_data, gd);
+ }
+}
+
+void OperationExporter::visit(loco::FilterEncode *node)
+{
+ auto encoder = dynamic_cast<loco::PermutingEncoder<loco::Domain::Filter> *>(node->encoder());
+ auto perm = encoder->perm();
+
+ if (isNHWC(perm))
+ {
+ // Note that tflite represents filter as [Count(O), Height, Width, Depth(I)], the filter analogue of NHWC
+ exportIdentity(node, builder, gd);
+ }
+ else
+ {
+ std::vector<int32_t> perm_vec_data(4);
+ // NOTE In tflite, all tensors means NHWC, so 0 = N, 1 = H, 2 = W, 3 = C
+ perm_vec_data[0] = perm->axis(loco::FilterAxis::Count);
+ perm_vec_data[1] = perm->axis(loco::FilterAxis::Height);
+ perm_vec_data[2] = perm->axis(loco::FilterAxis::Width);
+ perm_vec_data[3] = perm->axis(loco::FilterAxis::Depth);
+
+ exportAsTranspose(node, builder, perm_vec_data, gd);
+ }
+}
+
+void exportAsReshape(loco::Node *node, FlatBufferBuilder &builder,
+ std::vector<int32_t> &new_shape_vec, SerializedModelData &gd)
+{
+ // NOTE TFLite has two ways to receive the new shape parameter:
+ // one is the attribute 'new_shape' and the other is the input 'shape'.
+ // The TFLite interpreter computes the Reshape operation correctly
+ // if either of them is valid.
+ // However, since NN runtimes usually read the new shape parameter from
+ // the input 'shape', passing the new shape only via the attribute can
+ // cause problems. Of course, the opposite situation could occur in the
+ // future. To prevent such problems, we pass the new shape parameter
+ // not only via the attribute but also via the input.
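+ // In code: the constant tensor built below is the 'shape' input, and
+ // CreateReshapeOptions(...) carries the same values as 'new_shape'.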
+
+ auto input_shape_shape_vec_offset =
+ builder.CreateVector(std::vector<int32_t>{(int32_t)new_shape_vec.size()});
+
+ size_t input_shape_vec_size = new_shape_vec.size() * sizeof(int32_t);
+ auto input_shape_input_vec_offset =
+ builder.CreateVector(reinterpret_cast<uint8_t *>(new_shape_vec.data()), input_shape_vec_size);
+ auto input_shape_buffer_offset = CreateBuffer(builder, input_shape_input_vec_offset);
+
+ const auto input_shape_buffer_id = static_cast<uint32_t>(gd._buffers.size());
+ gd._buffers.push_back(input_shape_buffer_offset);
+
+ auto input_shape_tensor_id = static_cast<int32_t>(gd._tensors.size());
+ auto name_offset = builder.CreateString("t_" + std::to_string(input_shape_tensor_id));
+ auto input_shape_tensor_offset = CreateTensor(
+ builder, input_shape_shape_vec_offset, TensorType_INT32, input_shape_buffer_id, name_offset);
+ gd._tensors.push_back(input_shape_tensor_offset);
+
+ uint32_t op_idx = gd.registerBuiltinOpcode(tflite::BuiltinOperator_RESHAPE);
+
+ std::vector<int32_t> inputs_vec{get_tensor_index(node->arg(0)), input_shape_tensor_id};
+ std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))};
+ auto inputs = builder.CreateVector(inputs_vec);
+ auto outputs = builder.CreateVector(outputs_vec);
+
+ auto new_shape_vec_offset = builder.CreateVector(new_shape_vec);
+ auto options = CreateReshapeOptions(builder, new_shape_vec_offset);
+
+ auto op_offset = CreateOperator(builder, op_idx, inputs, outputs,
+ tflite::BuiltinOptions_ReshapeOptions, options.Union());
+
+ gd._operators.push_back(op_offset);
+}
+
+void OperationExporter::visit(loco::DepthwiseFilterEncode *node)
+{
+ auto ker = node->input(); // [H, W, C, M]
+
+ // tflite represents filter as [1, H, W, C*M] where M is multiplier.
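+ // For example, a depthwise kernel of shape [3, 3, 8, 2] (H=3, W=3, C=8, M=2)
+ // is reshaped to [1, 3, 3, 16].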
+ std::vector<int32_t> new_shape_vec(4);
+ new_shape_vec[0] = 1;
+ new_shape_vec[1] = ShapeInference::get(ker)._dims[0];
+ new_shape_vec[2] = ShapeInference::get(ker)._dims[1];
+ new_shape_vec[3] = ShapeInference::get(ker)._dims[2] * ShapeInference::get(ker)._dims[3];
+
+ exportAsReshape(node, builder, new_shape_vec, gd);
+}
+
+void OperationExporter::visit(loco::BiasAdd<loco::Domain::Tensor> *node)
+{
+ uint32_t op_idx = gd.registerBuiltinOpcode(tflite::BuiltinOperator_ADD);
+ std::vector<int32_t> inputs_vec{get_tensor_index(node->value()), get_tensor_index(node->bias())};
+ std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))};
+ auto inputs = builder.CreateVector(inputs_vec);
+ auto outputs = builder.CreateVector(outputs_vec);
+ auto options = CreateAddOptions(builder); // dummy option
+ auto op_offset = CreateOperator(builder, op_idx, inputs, outputs,
+ tflite::BuiltinOptions_AddOptions, options.Union());
+ gd._operators.push_back(op_offset);
+}
+
+void OperationExporter::visit(loco::FeatureBiasAdd *node)
+{
+ uint32_t op_idx = gd.registerBuiltinOpcode(tflite::BuiltinOperator_ADD);
+ std::vector<int32_t> inputs_vec{get_tensor_index(node->value()), get_tensor_index(node->bias())};
+ std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))};
+ auto inputs = builder.CreateVector(inputs_vec);
+ auto outputs = builder.CreateVector(outputs_vec);
+ auto options = CreateAddOptions(builder); // dummy option
+ auto op_offset = CreateOperator(builder, op_idx, inputs, outputs,
+ tflite::BuiltinOptions_AddOptions, options.Union());
+ gd._operators.push_back(op_offset);
+}
+
+/// @brief Export CONCATENATION of **TWO** tensors only
+void OperationExporter::visit(loco::TensorConcat *node)
+{
+ uint32_t op_idx = gd.registerBuiltinOpcode(tflite::BuiltinOperator_CONCATENATION);
+ std::vector<int32_t> inputs_vec{get_tensor_index(node->lhs()), get_tensor_index(node->rhs())};
+ std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))};
+ auto inputs = builder.CreateVector(inputs_vec);
+ auto outputs = builder.CreateVector(outputs_vec);
+ auto options = CreateConcatenationOptions(builder, node->axis());
+ auto op_offset = CreateOperator(builder, op_idx, inputs, outputs,
+ tflite::BuiltinOptions_ConcatenationOptions, options.Union());
+
+ gd._operators.push_back(op_offset);
+}
+
+void OperationExporter::visit(loco::BiasEncode *encode) { exportIdentity(encode, builder, gd); }
+
+void OperationExporter::visit(loco::EltwiseAdd *node)
+{
+ uint32_t op_idx = gd.registerBuiltinOpcode(tflite::BuiltinOperator_ADD);
+ std::vector<int32_t> inputs_vec{get_tensor_index(node->lhs()), get_tensor_index(node->rhs())};
+ std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))};
+ auto inputs = builder.CreateVector(inputs_vec);
+ auto outputs = builder.CreateVector(outputs_vec);
+ auto options = CreateAddOptions(builder); // dummy option
+ auto op_offset = CreateOperator(builder, op_idx, inputs, outputs,
+ tflite::BuiltinOptions_AddOptions, options.Union());
+ gd._operators.push_back(op_offset);
+}
+
+void OperationExporter::visit(loco::EltwiseMax *node)
+{
+ uint32_t op_idx = gd.registerBuiltinOpcode(tflite::BuiltinOperator_MAXIMUM);
+ std::vector<int32_t> inputs_vec{get_tensor_index(node->lhs()), get_tensor_index(node->rhs())};
+ std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))};
+ auto inputs = builder.CreateVector(inputs_vec);
+ auto outputs = builder.CreateVector(outputs_vec);
+ auto options = CreateMaximumMinimumOptions(builder); // dummy option
+ auto op_offset = CreateOperator(builder, op_idx, inputs, outputs,
+ tflite::BuiltinOptions_MaximumMinimumOptions, options.Union());
+ gd._operators.push_back(op_offset);
+}
+
+void OperationExporter::visit(loco::EltwiseMul *node)
+{
+ uint32_t op_idx = gd.registerBuiltinOpcode(tflite::BuiltinOperator_MUL);
+ std::vector<int32_t> inputs_vec{get_tensor_index(node->lhs()), get_tensor_index(node->rhs())};
+ std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))};
+ auto inputs = builder.CreateVector(inputs_vec);
+ auto outputs = builder.CreateVector(outputs_vec);
+ auto options = CreateMulOptions(builder); // dummy option
+ auto op_offset = CreateOperator(builder, op_idx, inputs, outputs,
+ tflite::BuiltinOptions_MulOptions, options.Union());
+ gd._operators.push_back(op_offset);
+}
+
+void OperationExporter::visit(loco::EltwiseSub *node)
+{
+ uint32_t op_idx = gd.registerBuiltinOpcode(tflite::BuiltinOperator_SUB);
+ std::vector<int32_t> inputs_vec{get_tensor_index(node->lhs()), get_tensor_index(node->rhs())};
+ std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))};
+ auto inputs = builder.CreateVector(inputs_vec);
+ auto outputs = builder.CreateVector(outputs_vec);
+ auto options = CreateSubOptions(builder); // dummy option
+ auto op_offset = CreateOperator(builder, op_idx, inputs, outputs,
+ tflite::BuiltinOptions_SubOptions, options.Union());
+ gd._operators.push_back(op_offset);
+}
+
+void OperationExporter::visit(loco::EltwiseDiv *node)
+{
+ uint32_t op_idx = gd.registerBuiltinOpcode(tflite::BuiltinOperator_DIV);
+ std::vector<int32_t> inputs_vec{get_tensor_index(node->lhs()), get_tensor_index(node->rhs())};
+ std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))};
+ auto inputs = builder.CreateVector(inputs_vec);
+ auto outputs = builder.CreateVector(outputs_vec);
+ auto options = CreateDivOptions(builder); // dummy option
+ auto op_offset = CreateOperator(builder, op_idx, inputs, outputs,
+ tflite::BuiltinOptions_DivOptions, options.Union());
+ gd._operators.push_back(op_offset);
+}
+
+void OperationExporter::visit(loco::EltwiseSqrt *node)
+{
+ uint32_t op_idx = gd.registerBuiltinOpcode(tflite::BuiltinOperator_SQRT);
+ std::vector<int32_t> inputs_vec{get_tensor_index(node->input())};
+ std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))};
+ auto inputs = builder.CreateVector(inputs_vec);
+ auto outputs = builder.CreateVector(outputs_vec);
+ auto op_offset = CreateOperator(builder, op_idx, inputs, outputs);
+ gd._operators.push_back(op_offset);
+}
+
+void OperationExporter::visit(loco::FixedReshape *node)
+{
+ std::vector<int32_t> new_shape_vec;
+ for (uint32_t axis = 0; axis < node->rank(); ++axis)
+ {
+ assert(node->dim(axis).known());
+ new_shape_vec.push_back(node->dim(axis).value());
+ }
+
+ exportAsReshape(node, builder, new_shape_vec, gd);
+}
+
+void OperationExporter::visit(loco::TensorBroadcast *)
+{
+ INTERNAL_EXN("TensorBroadcast should not exist in the graph");
+}
+
+void OperationExporter::visit(loco::TensorConstantPad *node)
+{
+ uint32_t op_idx = gd.registerBuiltinOpcode(tflite::BuiltinOperator_PAD);
+
+ // make padding attribute an input
+ auto padding = node->padding();
+ // get padding vector size
+ int32_t padding_vec_size = padding->rank();
+ // get byte size of vector
+ size_t padding_vec_byte_size = padding_vec_size * sizeof(int32_t) * 2; // [rank, 2]
+ // create vector for data
+ std::vector<int32_t> padding_vec_data(padding_vec_size * 2);
+ // set data
+ for (int32_t i = 0; i < padding_vec_size; i++)
+ {
+ padding_vec_data.at(i * 2) = padding->front(i);
+ padding_vec_data.at(i * 2 + 1) = padding->back(i);
+ }
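+ // For example, a rank-2 padding with front = {1, 0} and back = {2, 3} yields
+ // padding_vec_data = {1, 2, 0, 3}, i.e. the row-major [rank, 2] paddings
+ // tensor [[1, 2], [0, 3]] that tflite PAD expects.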
+ // create FlatBuffer vector
+ auto padding_vec_ptr = builder.CreateVector(reinterpret_cast<uint8_t *>(padding_vec_data.data()),
+ padding_vec_byte_size);
+
+ // create buffer
+ auto padding_buffer_ptr = CreateBuffer(builder, padding_vec_ptr);
+ // get buffer id
+ const auto padding_buffer_id = static_cast<uint32_t>(gd._buffers.size());
+
+ gd._buffers.push_back(padding_buffer_ptr);
+
+ // create padding shape vector
+ auto padding_shape_vec_ptr = builder.CreateVector(std::vector<int32_t>{padding_vec_size, 2});
+ // create tensor
+ auto padding_tensor_ptr =
+ CreateTensor(builder, padding_shape_vec_ptr, TensorType_INT32, padding_buffer_id);
+ // get tensor id
+ const auto padding_tensor_id = static_cast<int32_t>(gd._tensors.size());
+
+ gd._tensors.push_back(padding_tensor_ptr);
+
+ std::vector<int32_t> inputs_vec{get_tensor_index(node->input()), padding_tensor_id};
+ std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))};
+ auto inputs = builder.CreateVector(inputs_vec);
+ auto outputs = builder.CreateVector(outputs_vec);
+ auto op_offset = CreateOperator(builder, op_idx, inputs, outputs);
+ gd._operators.push_back(op_offset);
+}
+
+inline flatbuffers::Offset<flatbuffers::Vector<uint8_t>>
+CreateCOpCallOptions(flatbuffers::FlatBufferBuilder &fbb, locoex::COpCall *copCall)
+{
+ // read attrs in FlexBuffer format and pass them to FlatBuffer builder
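+ //
+ // For illustration (attribute names are hypothetical): a COpCall carrying
+ // {"alpha": 0.2f (Float), "axis": 1 (Int)} serializes to the FlexBuffer map
+ // { alpha: 0.2, axis: 1 }.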
+ flexbuffers::Builder flexbuf;
+ {
+ size_t map_start = flexbuf.StartMap();
+
+ // Note: among the attrs of COpCall, 'op' and 'name' won't be included in the tflite file
+ auto names = copCall->attr_names();
+ for (auto name : names)
+ {
+ if (auto int_val = copCall->attr<locoex::COpAttrType::Int>(name))
+ flexbuf.Int(name.c_str(), int_val->val());
+ else if (auto float_val = copCall->attr<locoex::COpAttrType::Float>(name))
+ flexbuf.Float(name.c_str(), float_val->val());
+ else
+ // TODO Support more attribute types
+ INTERNAL_EXN("Not supported type while writing flexbuffer");
+ }
+
+ flexbuf.EndMap(map_start);
+ flexbuf.Finish();
+ }
+
+ auto offset = fbb.CreateVector(flexbuf.GetBuffer());
+
+ return offset;
+}
+
+void OperationExporter::visit(locoex::COpCall *call)
+{
+ // Registering this custom op name into tflite Operator Codes table
+ uint32_t op_idx = gd.registerCustomOpcode(call->op());
+
+ std::vector<int32_t> inputs_vec;
+ {
+ inputs_vec.resize(call->arity());
+ for (uint32_t i = 0; i < call->arity(); i++)
+ inputs_vec[i] = get_tensor_index(call->arg(i));
+ }
+
+ std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(call))};
+
+ auto inputs = builder.CreateVector(inputs_vec);
+ auto outputs = builder.CreateVector(outputs_vec);
+
+ auto custom_options = CreateCOpCallOptions(builder, call);
+ auto op_offset = CreateOperator(builder, op_idx, inputs, outputs,
+ tflite::BuiltinOptions_NONE, // builtin_options_type
+ 0, // built-in option
+ custom_options, // custom options
+ tflite::CustomOptionsFormat_FLEXBUFFERS);
+
+ gd._operators.push_back(op_offset);
+}
+
+void exportNode(loco::Node *node, flatbuffers::FlatBufferBuilder &builder,
+ SerializedModelData &data)
+{
+ // TODO Use explicit tagging to prevent possible mistakes
+ auto isNoOp = [](loco::Node *node) {
+ if (node->arity() == 1)
+ {
+ assert(node->arg(0) != nullptr);
+ return get_tensor_index(node) == get_tensor_index(node->arg(0));
+ }
+ return false;
+ };
+
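+ // e.g. an NHWC FeatureDecode was given its input's tensor index by
+ // allocateTFLiteTensor (see TFLTensorExporter.cpp), so it is skipped here.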
+ if (isNoOp(node))
+ {
+ // Skip a node that was previously marked as NoOp (an op with no effect)
+ return;
+ }
+
+ if (auto canonical_node = dynamic_cast<loco::CanonicalNode *>(node))
+ { // TODO Consider removing this later
+ OperationExporter exporter{builder, data};
+ canonical_node->accept(&exporter);
+ }
+ else if (auto tfl_node = dynamic_cast<locoex::TFLNode *>(node))
+ {
+ OperationExporter exporter{builder, data};
+ tfl_node->accept(&exporter);
+ }
+ else if (auto cop_call = dynamic_cast<locoex::COpCall *>(node))
+ {
+ OperationExporter exporter{builder, data};
+ exporter.visit(cop_call);
+ }
+ else
+ {
+ assert(false && "unsupported node found");
+ }
+}
+
+} // namespace
+
+namespace exo
+{
+namespace tflite_detail
+{
+
+void exportNodes(loco::Graph *g, FlatBufferBuilder &builder, SerializedModelData &gd)
+{
+ for (auto node : loco::postorder_traversal(loco::output_nodes(g)))
+ {
+ exportNode(node, builder, gd);
+ }
+}
+
+} // namespace tflite_detail
+} // namespace exo
diff --git a/compiler/exo/src/TFLite/TFLOperationExporter.h b/compiler/exo/src/TFLite/TFLOperationExporter.h
new file mode 100644
index 000000000..60f2b5eb2
--- /dev/null
+++ b/compiler/exo/src/TFLite/TFLOperationExporter.h
@@ -0,0 +1,39 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __TFL_OPERATION_EXPORTER_H__
+#define __TFL_OPERATION_EXPORTER_H__
+
+#include "TFLExporterUtils.h"
+
+#include <loco/IR/Graph.h>
+
+namespace exo
+{
+namespace tflite_detail
+{
+
+/**
+ * @brief Create Operators corresponding to model nodes
+ * @param g graph whose nodes are exported
+ * @param builder FlatBuffer builder used for serialization
+ * @param gd information about serialized parts of the model
+ */
+void exportNodes(loco::Graph *g, flatbuffers::FlatBufferBuilder &builder, SerializedModelData &gd);
+
+} // namespace tflite_detail
+} // namespace exo
+
+#endif // __TFL_OPERATION_EXPORTER_H__
diff --git a/compiler/exo/src/TFLite/TFLTensorExporter.cpp b/compiler/exo/src/TFLite/TFLTensorExporter.cpp
new file mode 100644
index 000000000..66854ef87
--- /dev/null
+++ b/compiler/exo/src/TFLite/TFLTensorExporter.cpp
@@ -0,0 +1,249 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "TFLTensorExporter.h"
+#include "TFLTypeInference.h"
+#include "ShapeInference.h"
+
+// TODO Fix include style
+#include "loco/IR/Algorithm.h"
+#include "loco/IR/CanonicalNode.h"
+#include "loco/IR/CanonicalNodeVisitor.h"
+#include "loco/IR/DataTypeTraits.h"
+
+#include "Dialect/IR/TFLNodes.h"
+
+#include <oops/InternalExn.h>
+
+using namespace tflite;
+using namespace flatbuffers;
+
+namespace
+{
+
+using namespace exo;
+using namespace exo::tflite_detail;
+
+class TFLTensorInfo
+{
+public:
+ TFLTensorInfo() = default;
+
+public:
+ void name(const std::string &name) { _name = name; }
+ const std::string &name(void) const { return _name; }
+
+public:
+ const tflite::TensorType &dtype(void) const { return _dtype; }
+ void dtype(const tflite::TensorType &dtype) { _dtype = dtype; }
+
+ const ShapeDescription &shape(void) const { return _shape; }
+ void shape(const ShapeDescription &shape) { _shape = shape; }
+
+public:
+ locoex::TFLConst *tfl_content(void) const { return _tfl_content; }
+ void tfl_content(locoex::TFLConst *c) { _tfl_content = c; }
+
+private:
+ std::string _name;
+
+ tflite::TensorType _dtype;
+ ShapeDescription _shape;
+
+ // TODO Find a better design
+ loco::ConstGen *_content = nullptr; // TODO deprecate
+ locoex::TFLConst *_tfl_content = nullptr;
+};
+
+using TFLTensorContext = std::vector<TFLTensorInfo>;
+
+struct NoOpDetector final : public loco::CanonicalNodeMutableVisitor<bool>
+{
+ bool visit(loco::BiasEncode *) final
+ {
+ // BiasEncode is always noop
+ return true;
+ }
+
+ bool visit(loco::FilterEncode *node) final
+ {
+ auto encoder = dynamic_cast<loco::PermutingEncoder<loco::Domain::Filter> *>(node->encoder());
+ if (encoder == nullptr)
+ return false;
+
+ return isNHWC(encoder->perm());
+ }
+
+ bool visit(loco::FeatureEncode *node) final
+ {
+ auto encoder = dynamic_cast<loco::PermutingEncoder<loco::Domain::Feature> *>(node->encoder());
+ if (encoder == nullptr)
+ return false;
+ return isNHWC(encoder->perm());
+ }
+
+ bool visit(loco::FeatureDecode *node) final
+ {
+ auto decoder = dynamic_cast<loco::PermutingDecoder<loco::Domain::Feature> *>(node->decoder());
+ if (decoder == nullptr)
+ return false;
+ return isNHWC(decoder->perm());
+ }
+
+ // Return false by default
+ bool visit(loco::Node *) final { return false; }
+};
+
+bool isNoOp(loco::Node *node)
+{
+ if (auto canonical_node = dynamic_cast<loco::CanonicalNode *>(node))
+ {
+ NoOpDetector d;
+ return canonical_node->accept(&d);
+ }
+ return false;
+}
+
+void allocateTFLiteTensor(loco::Node *node, TFLTensorContext &ctx)
+{
+ if (isNoOp(node))
+ {
+ assert(node->arity() == 1 && node->arg(0) != nullptr);
+ set_tensor_index(node, get_tensor_index(node->arg(0)));
+ return;
+ }
+
+ auto tensor_index = static_cast<TFLTensorIndex>(ctx.size());
+ // TODO Use Graph-level metadata for Input & Output
+ auto tensor_name = "t_" + std::to_string(tensor_index);
+
+ TFLTensorInfo tensor_info;
+
+ tensor_info.name(tensor_name);
+ tensor_info.dtype(TypeInference::get(node));
+ tensor_info.shape(ShapeInference::get(node));
+
+ tensor_info.tfl_content(dynamic_cast<locoex::TFLConst *>(node));
+
+ set_tensor_index(node, tensor_index);
+
+ ctx.emplace_back(tensor_info);
+}
+
+} // namespace
+
+namespace
+{
+
+flatbuffers::Offset<Vector<int32_t>> encodeShape(FlatBufferBuilder &builder,
+ const ShapeDescription &shape)
+{
+ assert(shape._rank_known && "unknown number of dimensions is not supported");
+ return builder.CreateVector(shape._dims);
+}
+
+flatbuffers::Offset<tflite::Buffer> encodeOpBuffer(FlatBufferBuilder &builder)
+{
+ return CreateBuffer(builder);
+}
+
+template <typename NodeT>
+flatbuffers::Offset<tflite::Buffer> encodeOpBuffer(FlatBufferBuilder &builder, NodeT *)
+{
+ return CreateBuffer(builder);
+}
+
+template <loco::DataType DT>
+flatbuffers::Offset<tflite::Buffer> encodeOpBufferByDType(FlatBufferBuilder &builder,
+ locoex::TFLConst *c)
+{
+ using NativeType = typename loco::DataTypeImpl<DT>::Type;
+
+ std::vector<NativeType> raw_data;
+ const uint32_t size = c->size<DT>();
+ raw_data.reserve(size);
+ for (uint32_t i = 0; i < size; ++i)
+ {
+ raw_data.push_back(c->at<DT>(i));
+ }
+ const size_t raw_size = size * sizeof(NativeType);
+ auto array_offset = builder.CreateVector(reinterpret_cast<uint8_t *>(raw_data.data()), raw_size);
+ return CreateBuffer(builder, array_offset);
+}
+
+template <>
+flatbuffers::Offset<tflite::Buffer> encodeOpBuffer(FlatBufferBuilder &builder, locoex::TFLConst *c)
+{
+ if (c->dtype() == loco::DataType::FLOAT32)
+ {
+ return encodeOpBufferByDType<loco::DataType::FLOAT32>(builder, c);
+ }
+ else if (c->dtype() == loco::DataType::S32)
+ {
+ return encodeOpBufferByDType<loco::DataType::S32>(builder, c);
+ }
+
+ INTERNAL_EXN_V("Unsupported datatype", oops::to_uint32(c->dtype()));
+}
+
+} // namespace
+
+namespace exo
+{
+namespace tflite_detail
+{
+
+void exportOpDefinedTensor(const TFLTensorInfo &info, FlatBufferBuilder &builder,
+ SerializedModelData &gd)
+{
+ // Create and register output tensor shape
+ auto shape_offset = encodeShape(builder, info.shape());
+
+ // encode and register output tensor buffer
+ auto buffer = info.tfl_content() == nullptr ? encodeOpBuffer(builder)
+ : encodeOpBuffer(builder, info.tfl_content());
+
+ auto buffer_id = static_cast<uint32_t>(gd._buffers.size());
+ gd._buffers.push_back(buffer);
+
+ auto name_offset = builder.CreateString(info.name());
+ auto tensor_offset = CreateTensor(builder, shape_offset, info.dtype(), buffer_id, name_offset,
+ /*quantization*/ 0, /*is_variable*/ false);
+ gd._tensors.push_back(tensor_offset);
+}
+
+void exportOpDefinedTensors(loco::Graph *g, FlatBufferBuilder &builder, SerializedModelData &gd)
+{
+ TFLTensorContext tensor_ctx;
+
+ for (auto node : loco::postorder_traversal(loco::output_nodes(g)))
+ {
+ allocateTFLiteTensor(node, tensor_ctx);
+ }
+
+ // Add one empty buffer first.
+ // Note: the tflite schema (.fbs) file says:
+ // "The 0th entry of this array must be an empty buffer (sentinel).
+ // This is a convention so that tensors without a buffer can provide 0 as
+ // their buffer."
+ auto buffer = encodeOpBuffer(builder);
+ gd._buffers.push_back(buffer);
+
+ for (const auto &tensor_info : tensor_ctx)
+ {
+ exportOpDefinedTensor(tensor_info, builder, gd);
+ }
+}
+
+} // namespace tflite_detail
+} // namespace exo
diff --git a/compiler/exo/src/TFLite/TFLTensorExporter.h b/compiler/exo/src/TFLite/TFLTensorExporter.h
new file mode 100644
index 000000000..97e702665
--- /dev/null
+++ b/compiler/exo/src/TFLite/TFLTensorExporter.h
@@ -0,0 +1,42 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __TFL_TENSOR_EXPORTER_H__
+#define __TFL_TENSOR_EXPORTER_H__
+
+#include "TFLExporterUtils.h"
+
+#include <loco/IR/Graph.h>
+
+#include <flatbuffers/flatbuffers.h>
+
+namespace exo
+{
+namespace tflite_detail
+{
+
+/**
+ * @brief Create Tensors corresponding to the results of all nodes in the graph
+ * @param g computational graph
+ * @param builder FlatBuffer builder used for serialization
+ * @param gd information about serialized parts of the model
+ */
+void exportOpDefinedTensors(loco::Graph *g, flatbuffers::FlatBufferBuilder &builder,
+ SerializedModelData &gd);
+
+} // namespace tflite_detail
+} // namespace exo
+
+#endif // __TFL_TENSOR_EXPORTER_H__
diff --git a/compiler/exo/src/TFLite/TFLTypeInference.cpp b/compiler/exo/src/TFLite/TFLTypeInference.cpp
new file mode 100644
index 000000000..8d6bb8d8c
--- /dev/null
+++ b/compiler/exo/src/TFLite/TFLTypeInference.cpp
@@ -0,0 +1,82 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "TFLTypeInference.h"
+
+#include "schema_generated.h"
+
+#include "Dialect/Service/TFLTypeInferenceRule.h"
+#include "Dialect/IR/TFLDialect.h"
+
+#include <loco/IR/CanonicalNode.h>
+#include <loco/IR/CanonicalNodeVisitor.h>
+#include <loco/IR/CanonicalDialect.h>
+#include <loco/Service/TypeInference.h>
+
+#include <locoex/COpDialect.h>
+#include <locoex/Service/COpTypeInference.h>
+
+#include <oops/InternalExn.h>
+
+#include <stdex/Memory.h>
+
+#include <stdexcept>
+#include <type_traits>
+
+namespace
+{
+
+tflite::TensorType translateLocoTypeToTFLite(loco::DataType dtype)
+{
+ switch (dtype)
+ {
+ case loco::DataType::U8:
+ return tflite::TensorType_UINT8;
+ // case loco::DataType::U16: unsupported
+ // case loco::DataType::U32: unsupported
+ // case loco::DataType::U64: unsupported
+ case loco::DataType::S8:
+ return tflite::TensorType_INT8;
+ case loco::DataType::S16:
+ return tflite::TensorType_INT16;
+ case loco::DataType::S32:
+ return tflite::TensorType_INT32;
+ case loco::DataType::S64:
+ return tflite::TensorType_INT64;
+ case loco::DataType::FLOAT16:
+ return tflite::TensorType_FLOAT16;
+ case loco::DataType::FLOAT32:
+ return tflite::TensorType_FLOAT32;
+ // case loco::DataType::FLOAT64: unsupported
+ default:
+ break;
+ }
+
+ INTERNAL_EXN_V("Trying to converte unsupported loco dtype", oops::to_uint32(dtype));
+}
+
+} // namespace
+
+namespace exo
+{
+
+tflite::TensorType TypeInference::get(loco::Node *node)
+{
+ assert(loco::dtype_known(node));
+ return translateLocoTypeToTFLite(loco::dtype_get(node));
+}
+
+} // namespace exo
diff --git a/compiler/exo/src/TFLite/TFLTypeInference.h b/compiler/exo/src/TFLite/TFLTypeInference.h
new file mode 100644
index 000000000..3d3a2e480
--- /dev/null
+++ b/compiler/exo/src/TFLite/TFLTypeInference.h
@@ -0,0 +1,42 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __TFL_TYPE_INFERENCE_H__
+#define __TFL_TYPE_INFERENCE_H__
+
+#include "TFLExporterUtils.h"
+
+#include <loco/IR/Nodes.h>
+
+namespace exo
+{
+
+/**
+ * @brief Get the tflite::TensorType of a loco node
+ *
+ * HOW TO USE
+ *
+ * TypeInference::get(g->nodes()->at(0));
+ * TypeInference::get(g->nodes()->at(...));
+ */
+struct TypeInference
+{
+ static tflite::TensorType get(loco::Node *node);
+};
+
+} // namespace exo
+
+#endif // __TFL_TYPE_INFERENCE_H__
diff --git a/compiler/exo/src/TFLite/TFLTypeInference.test.cpp b/compiler/exo/src/TFLite/TFLTypeInference.test.cpp
new file mode 100644
index 000000000..0712f0a25
--- /dev/null
+++ b/compiler/exo/src/TFLite/TFLTypeInference.test.cpp
@@ -0,0 +1,118 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "TFLTypeInference.h"
+#include "Pass/TypeInferencePass.h"
+
+#include <loco/IR/PermutingCodec.h>
+#include <stdex/Memory.h>
+
+#include <gtest/gtest.h>
+
+using stdex::make_unique;
+
+namespace
+{
+
+class Sequential
+{
+public:
+ loco::Pull *addPullLayer(const loco::DataType &dtype = loco::DataType::FLOAT32)
+ {
+ loco::Pull *pull = _graph.nodes()->create<loco::Pull>();
+
+ auto graph_input = _graph.inputs()->create();
+ graph_input->name("graph_input");
+ loco::link(graph_input, pull);
+
+ pull->dtype(dtype);
+ setSampleShape(pull);
+
+ return last(pull);
+ }
+
+ loco::ReLU *addReLULayer(void)
+ {
+ loco::ReLU *relu = _graph.nodes()->create<loco::ReLU>();
+
+ relu->input(_last);
+
+ return last(relu);
+ }
+
+ loco::Push *addPushLayer(void)
+ {
+ loco::Push *push = _graph.nodes()->create<loco::Push>();
+
+ auto graph_output = _graph.outputs()->create();
+ graph_output->name("graph_output");
+ loco::link(graph_output, push);
+
+ push->from(_last);
+
+ return last(push);
+ }
+
+ loco::Graph *graph() { return &_graph; }
+
+private:
+ template <typename T> uint32_t setSampleShape(T *op)
+ {
+ const uint32_t n = 1;
+ const uint32_t h = 100;
+ const uint32_t w = 100;
+ const uint32_t c = 3;
+ op->rank(4);
+ op->dim(0).set(n);
+ op->dim(1).set(c);
+ op->dim(2).set(h);
+ op->dim(3).set(w);
+ return n * h * w * c;
+ }
+
+ template <typename T> T *last(T *node)
+ {
+ _last = node;
+ return node;
+ }
+
+private:
+ loco::Graph _graph;
+ loco::Node *_last;
+};
+
+struct TypeInferenceTest : public Sequential, public ::testing::Test
+{
+ virtual ~TypeInferenceTest() = default;
+};
+
+} // namespace
+
+// TypeInference SHOULD PROPAGATE type information properly
+TEST_F(TypeInferenceTest, Regression_0000)
+{
+ auto pull = addPullLayer(loco::DataType::S8);
+ auto relu = addReLULayer();
+ auto push = addPushLayer();
+
+ using namespace exo;
+
+ TypeInferencePass type_inf_pass;
+ type_inf_pass.run(graph());
+
+ ASSERT_EQ(TypeInference::get(relu), tflite::TensorType_INT8);
+ ASSERT_EQ(TypeInference::get(push), tflite::TensorType_INT8);
+}
diff --git a/compiler/exo/src/TestGraph.h b/compiler/exo/src/TestGraph.h
new file mode 100644
index 000000000..f919cc9ae
--- /dev/null
+++ b/compiler/exo/src/TestGraph.h
@@ -0,0 +1,315 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __TEST_GRAPH_H__
+#define __TEST_GRAPH_H__
+
+#include "Dialect/IR/TFLNodes.h"
+#include "GraphBlock.h"
+#include "TestHelper.h"
+
+#include <loco.h>
+
+#include <stdex/Memory.h>
+
+#include <cassert>
+
+namespace exo
+{
+namespace test
+{
+
+class TestGraph
+{
+public:
+ std::unique_ptr<loco::Graph> g;
+ loco::Pull *pull;
+ loco::Push *push;
+
+ TestGraph() // creates Pull and Push
+ {
+ g = loco::make_graph();
+
+ pull = g->nodes()->create<loco::Pull>();
+
+ push = g->nodes()->create<loco::Push>();
+
+ auto input = g->inputs()->create();
+ {
+ input->name("input");
+ loco::link(input, pull);
+ }
+ auto output = g->outputs()->create();
+ {
+ output->name("output");
+ loco::link(output, push);
+ }
+
+ _next_input = pull;
+ }
+
+ loco::Graph *graph() { return g.get(); }
+
+ /// @brief Creates node with NO arg and appends it to graph
+ template <class T> T *append()
+ {
+ auto node = g->nodes()->create<T>();
+ _next_input = node;
+
+ return node;
+ }
+
+ /// @brief Creates op T (arity=1) with arg1 as an input and appends it to graph
+ template <class T> T *append(loco::Node *arg1)
+ {
+ auto node = g->nodes()->create<T>();
+ setInput(node, arg1);
+ _next_input = node;
+
+ return node;
+ }
+
+ /// @brief Creates op T (arity=2) with arg1, arg2 as inputs and appends it to graph
+ template <class T> T *append(loco::Node *arg1, loco::Node *arg2)
+ {
+ auto node = g->nodes()->create<T>();
+ setInput(node, arg1, arg2);
+ _next_input = node;
+
+ return node;
+ }
+
+ /// @brief Creates op T (arity=3) with arg1, arg2, arg3 as inputs and appends it to graph
+ template <class T> T *append(loco::Node *arg1, loco::Node *arg2, loco::Node *arg3)
+ {
+ auto node = g->nodes()->create<T>();
+ setInput(node, arg1, arg2, arg3);
+ _next_input = node;
+
+ return node;
+ }
+
+ // push will get the last appended node
+ void complete() { push->from(_next_input); }
+
+ void complete(loco::Node *last_node) { push->from(last_node); }
+
+private:
+ // arity 1
+ void setInput(loco::Node *node, loco::Node *) { assert(false && "NYI"); };
+
+ void setInput(loco::AvgPool2D *node, loco::Node *input) { node->ifm(input); }
+ void setInput(loco::BiasDecode *node, loco::Node *input) { node->input(input); };
+ void setInput(loco::BiasEncode *node, loco::Node *input) { node->input(input); };
+ void setInput(loco::FeatureDecode *node, loco::Node *input) { node->input(input); };
+ void setInput(loco::FeatureEncode *node, loco::Node *input) { node->input(input); };
+ void setInput(loco::MaxPool2D *node, loco::Node *input) { node->ifm(input); }
+ void setInput(loco::Push *node, loco::Node *input) { node->from(input); };
+ void setInput(loco::ReLU *node, loco::Node *input) { node->input(input); };
+ void setInput(loco::ReLU6 *node, loco::Node *input) { node->input(input); };
+ void setInput(loco::Tanh *node, loco::Node *input) { node->input(input); };
+ void setInput(loco::TensorTranspose *node, loco::Node *input) { node->input(input); };
+
+ void setInput(locoex::TFLAveragePool2D *node, loco::Node *input) { node->value(input); };
+ void setInput(locoex::TFLMaxPool2D *node, loco::Node *input) { node->value(input); };
+ void setInput(locoex::TFLRelu *node, loco::Node *input) { node->features(input); };
+ void setInput(locoex::TFLRelu6 *node, loco::Node *input) { node->features(input); };
+
+ // arity 2
+ void setInput(loco::Node *node, loco::Node *, loco::Node *) { assert(false && "NYI"); };
+
+ void setInput(loco::Conv2D *node, loco::Node *input, loco::Node *filter)
+ {
+ node->ifm(input);
+ node->ker(filter);
+ }
+
+ void setInput(loco::EltwiseAdd *node, loco::Node *arg1, loco::Node *arg2)
+ {
+ node->lhs(arg1);
+ node->rhs(arg2);
+ };
+
+ void setInput(loco::FeatureBiasAdd *node, loco::Node *arg1, loco::Node *arg2)
+ {
+ node->value(arg1);
+ node->bias(arg2);
+ };
+
+ void setInput(locoex::TFLAdd *node, loco::Node *arg1, loco::Node *arg2)
+ {
+ node->x(arg1);
+ node->y(arg2);
+ };
+
+ void setInput(locoex::TFLMul *node, loco::Node *arg1, loco::Node *arg2)
+ {
+ node->x(arg1);
+ node->y(arg2);
+ };
+
+ void setInput(locoex::TFLSub *node, loco::Node *arg1, loco::Node *arg2)
+ {
+ node->x(arg1);
+ node->y(arg2);
+ };
+
+ void setInput(locoex::TFLTranspose *node, loco::Node *arg1, loco::Node *arg2)
+ {
+ node->a(arg1);
+ node->perm(arg2);
+ };
+
+ // arity 3
+ void setInput(loco::Node *node, loco::Node *, loco::Node *, loco::Node *)
+ {
+ assert(false && "NYI");
+ };
+
+ void setInput(locoex::TFLConv2D *node, loco::Node *input, loco::Node *filter, loco::Node *bias)
+ {
+ node->input(input);
+ node->filter(filter);
+ node->bias(bias);
+ }
+
+private:
+ loco::Node *_next_input;
+};
+
+enum class ExampleGraphType
+{
+ FeatureBiasAdd,
+ ConstGen_ReLU,
+ FilterEncode_FilterDecode,
+ Transpose,
+
+ TFLTranspose,
+};
+
+template <ExampleGraphType T> class ExampleGraph;
+
+/**
+ * @brief Class to create the following:
+ *
+ * Pull - FeatureEncoder - FeatureBiasAdd - FeatureDecode - Push
+ * |
+ * ConstGen - BiasEncode --+
+ */
+template <> class ExampleGraph<ExampleGraphType::FeatureBiasAdd> : public TestGraph
+{
+public:
+ loco::FeatureEncode *fea_enc = nullptr;
+ loco::ConstGen *constgen = nullptr;
+ loco::BiasEncode *bias_enc = nullptr;
+ loco::FeatureBiasAdd *fea_bias_add = nullptr;
+ loco::FeatureDecode *fea_dec = nullptr;
+
+public:
+ ExampleGraph()
+ {
+ fea_enc = exo::make_feature_encode<exo::FeatureLayout::NHWC>(pull);
+ constgen = append<loco::ConstGen>();
+ bias_enc = append<loco::BiasEncode>(constgen);
+ fea_bias_add = append<loco::FeatureBiasAdd>(fea_enc, bias_enc);
+ fea_dec = exo::make_feature_decode<exo::FeatureLayout::NHWC>(fea_bias_add);
+ complete(fea_dec);
+ }
+};
+
+/**
+ * @brief Class to create the following:
+ *
+ * ConstGen -- ReLU -- Push
+ */
+template <> class ExampleGraph<ExampleGraphType::ConstGen_ReLU> : public TestGraph
+{
+public:
+ loco::ConstGen *constgen = nullptr;
+ loco::ReLU *relu = nullptr;
+
+public:
+ ExampleGraph()
+ {
+ constgen = append<loco::ConstGen>();
+ relu = append<loco::ReLU>(constgen);
+ complete(relu);
+ }
+};
+
+/**
+ * @brief Class to create the following:
+ *
+ * Pull -- Transpose -- Push
+ */
+template <> class ExampleGraph<ExampleGraphType::Transpose> : public TestGraph
+{
+public:
+ loco::TensorTranspose *transpose = nullptr;
+
+public:
+ ExampleGraph()
+ {
+ transpose = append<loco::TensorTranspose>(pull);
+ complete(transpose);
+ }
+};
+
+/**
+ * @brief Class to create the following:
+ *
+ * Pull -- FilterEncode -- FilterDecode -- Push
+ */
+template <> class ExampleGraph<ExampleGraphType::FilterEncode_FilterDecode> : public TestGraph
+{
+public:
+ loco::FilterEncode *filterEncode = nullptr;
+ loco::FilterDecode *filterDecode = nullptr;
+
+public:
+ ExampleGraph()
+ {
+ filterEncode = exo::make_filter_encode<exo::FilterLayout::HWIO>(pull); // from Tensorflow
+ filterDecode =
+ exo::make_filter_decode<exo::FilterLayout::OHWI>(filterEncode); // to Tensorflow Lite
+ complete(filterDecode);
+ }
+};
+
+/**
+ * @brief Class to create the following:
+ *
+ * Pull -- TFLTranspose -- Push
+ */
+template <> class ExampleGraph<ExampleGraphType::TFLTranspose> : public TestGraph
+{
+public:
+ loco::ConstGen *const_perm = nullptr;
+ locoex::TFLTranspose *tfl_transpose = nullptr;
+
+public:
+ ExampleGraph()
+ {
+ const_perm = append<loco::ConstGen>();
+ tfl_transpose = append<locoex::TFLTranspose>(pull, const_perm);
+ complete(tfl_transpose);
+ }
+};
+
+} // namespace test
+} // namespace exo
+
+#endif // __TEST_GRAPH_H__
diff --git a/compiler/exo/src/TestHelper.h b/compiler/exo/src/TestHelper.h
new file mode 100644
index 000000000..1a3de50f5
--- /dev/null
+++ b/compiler/exo/src/TestHelper.h
@@ -0,0 +1,110 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __TEST_HELPER_H__
+#define __TEST_HELPER_H__
+
+#include "Check.h"
+#include "ProgressReporter.h"
+#include "Passes.h"
+
+#include <logo/Pass.h>
+#include <logo/Phase.h>
+
+#include <loco.h>
+
+#include <stdex/Memory.h>
+
+#include <gtest/gtest.h>
+
+/**
+ * @brief Check the number of nodes in a graph starting from OUTPUTS
+ */
+#define EXO_TEST_ASSERT_NODE_COUNT(OUTPUTS, COUNT) \
+ { \
+ auto v = loco::postorder_traversal(OUTPUTS); \
+ ASSERT_EQ(v.size(), (COUNT)); \
+ }
+
+namespace exo
+{
+namespace test
+{
+
+/**
+ * @brief Phase for tests, used to exercise a pass. This phase initially adds TypeInferencePass
+ * and ShapeInferencePass.
+ */
+class TypeShapeReadyPhase
+{
+public:
+ TypeShapeReadyPhase()
+ {
+ // Type and shape inference are prerequisites for running the other passes under test
+ _phase.emplace_back(stdex::make_unique<::exo::TypeInferencePass>());
+ _phase.emplace_back(stdex::make_unique<::exo::ShapeInferencePass>());
+ }
+
+ template <typename PassT> void add_pass() { _phase.emplace_back(stdex::make_unique<PassT>()); }
+
+ void run(loco::Graph *g)
+ {
+ const auto restart = logo::PhaseStrategy::Restart;
+ logo::PhaseRunner<restart> phase_runner{g};
+
+ ::exo::ProgressReporter prog(g, restart);
+ phase_runner.attach(&prog);
+ phase_runner.run(_phase);
+ }
+
+private:
+ logo::Phase _phase;
+};
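+
+// Typical use (a sketch; the pass name is hypothetical):
+//
+//   exo::test::TypeShapeReadyPhase phase;
+//   phase.add_pass<MyPassUnderTest>();
+//   phase.run(g);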
+
+/**
+ * @brief Get the only successor (succ) of type LocoNodeT. (The name `only succ` is by
+ * analogy with the phrase `only child`.)
+ * parent must have exactly one succ.
+ * When that succ is not of type LocoNodeT, nullptr is returned.
+ */
+template <typename LocoNodeT> inline LocoNodeT *get_only_succ(loco::Node *parent)
+{
+ auto succs = loco::succs(parent);
+ EXO_ASSERT(succs.size() == 1, "parent must have exactly one succ.");
+
+ return dynamic_cast<LocoNodeT *>(*succs.begin());
+}
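+
+// e.g. given a graph Pull -- ReLU -- Push,
+// get_only_succ<loco::ReLU>(pull) returns the ReLU node, while
+// get_only_succ<loco::Push>(pull) returns nullptr.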
+
+template <typename T> inline T *find_first_node_bytype(loco::Graph *g)
+{
+ T *first_node = nullptr;
+ loco::Graph::NodeContext *nodes = g->nodes();
+ uint32_t count = nodes->size();
+
+ for (uint32_t i = 0; i < count; ++i)
+ {
+ first_node = dynamic_cast<T *>(nodes->at(i));
+ if (first_node != nullptr)
+ break;
+ }
+
+ return first_node;
+}
+
+} // namespace test
+} // namespace exo
+
+#endif // __TEST_HELPER_H__