summaryrefslogtreecommitdiff
path: root/compiler/luci/import/src/Importer.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/luci/import/src/Importer.cpp')
-rw-r--r--compiler/luci/import/src/Importer.cpp253
1 files changed, 253 insertions, 0 deletions
diff --git a/compiler/luci/import/src/Importer.cpp b/compiler/luci/import/src/Importer.cpp
new file mode 100644
index 000000000..964c47633
--- /dev/null
+++ b/compiler/luci/import/src/Importer.cpp
@@ -0,0 +1,253 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Importer.h"
+
+#include "luci/Import/GraphBuilder.h"
+#include "luci/Import/GraphBuilderContext.h"
+#include "luci/Import/GraphBuilderRegistry.h"
+#include "luci/Import/CircleReader.h"
+#include "luci/Import/Nodes/CircleConst.h"
+
+#include <luci/IR/Module.h>
+#include <luci/IR/CircleNodes.h>
+#include <luci/Log.h>
+#include <luci/LogHelper.h>
+
+#include <oops/UserExn.h>
+
+#include <memory>
+
+namespace
+{
+
+void convert_graph(const luci::GraphBuilderSource &source, luci::CircleReader &reader,
+ loco::Graph *graph)
+{
+ LOGGER(l);
+
+ auto nodefinder = std::make_unique<luci::IndexNodeFinder>();
+
+ luci::GraphBuilderContext gb_context(graph, &reader, nodefinder.get());
+
+ const auto &operators = reader.operators();
+ const auto &tensors = reader.tensors();
+
+ // graph inputs; there are no input nodes in TFlite but just Tensors
+ // creating virtual input nodes will make possible to connect nodes that uses them
+ // all attributes of tensor should be copied to CircleInput node
+ for (const auto input : reader.inputs())
+ {
+ auto input_node = graph->nodes()->create<luci::CircleInput>();
+ assert(input_node != nullptr);
+ const circle::TensorT &tensor = *tensors[input];
+
+ auto tname = luci::tensor_name(tensor);
+ input_node->name(tname);
+ auto quantization = luci::tensor_quantization(tensor);
+ if (quantization)
+ {
+ auto quantparam = luci::luci_quantparam(quantization);
+ if (quantparam.get())
+ input_node->quantparam(std::move(quantparam));
+ }
+
+ INFO(l) << "[luci] NodeFinder INPUT(" << input << ") = " << input_node << std::endl;
+ nodefinder->enroll(input, input_node);
+
+ // Shape of Input
+ const std::vector<int32_t> &input_dims = tensor.shape; // in NHWC
+ input_node->rank(input_dims.size());
+ for (uint32_t r = 0; r < input_dims.size(); ++r)
+ input_node->dim(r) = loco::Dimension(input_dims[r]);
+
+ // Data type of Input
+ auto dtype = luci::luci_datatype(tensor.type);
+ input_node->dtype(dtype);
+
+ // Name
+ auto graph_input = graph->inputs()->create();
+ graph_input->name(tname);
+
+ // Set GraphInputOutputIndex for graph
+ input_node->index(graph_input->index());
+
+ // Data type
+ graph_input->dtype(dtype);
+ }
+
+ // Create CircleConst nodes for constant tensors.
+ const auto &buffers = reader.buffers();
+ for (uint32_t i = 0; i < tensors.size(); ++i)
+ {
+ const circle::TensorT &tensor = *tensors[i];
+ const std::vector<uint8_t> &buffer = buffers[tensor.buffer]->data;
+ if (!buffer.empty())
+ {
+ luci::CircleConst *const_node = luci::create_circleconst(&gb_context, i);
+ nodefinder->enroll(i, const_node);
+ }
+ }
+
+ // Import the operators.
+ // Note that operators in model are stored in execution order. This means that when importing
+ // an operator, its input operators have already been imported. We exploit this fact to set up
+ // node's inputs right after creating the node.
+ for (uint32_t i = 0; i < operators.size(); ++i)
+ {
+ const circle::OperatorT &op = *operators[i];
+ circle::BuiltinOperator builtincode = reader.builtin_code(op);
+
+ if (const auto *builder = source.lookup(builtincode))
+ {
+ luci::GraphBuilder::ValidateArgs args(op, reader);
+ if (!builder->validate(args))
+ {
+ throw oops::UserExn("Invalid operator", reader.opcode_name(op));
+ }
+
+ builder->build(op, &gb_context);
+ }
+ else
+ {
+ throw oops::UserExn("Not supported", reader.opcode_name(op));
+ }
+ }
+
+ // graph outputs
+ for (auto output : reader.outputs())
+ {
+ auto output_node = graph->nodes()->create<luci::CircleOutput>();
+ assert(output_node != nullptr);
+ output_node->from(nodefinder->node(output));
+
+ INFO(l) << "[luci] NodeFinder OUTPUT(" << output << ") = " << output_node << std::endl;
+
+ // set the graph output name and node object
+ const circle::TensorT &tensor = *tensors[output];
+ auto graph_output = graph->outputs()->create();
+ std::string tname = luci::tensor_name(tensor);
+ graph_output->name("output_" + tname);
+
+ // Set GraphInputOutputIndex for graph
+ output_node->index(graph_output->index());
+
+ // Shape of Output
+ auto output_shape = std::make_unique<loco::TensorShape>();
+ const std::vector<int32_t> &output_dims = tensor.shape; // in NHWC
+ output_shape->rank(output_dims.size());
+ for (uint32_t r = 0; r < output_dims.size(); ++r)
+ output_shape->dim(r) = loco::Dimension(output_dims[r]);
+ graph_output->shape(std::move(output_shape));
+
+ // Data type
+ auto dtype = luci::luci_datatype(tensor.type);
+ graph_output->dtype(dtype);
+ }
+}
+
+class ValidateCollector final : public loco::ErrorListener
+{
+public:
+ void notify(const loco::ErrorDetail<loco::ErrorCategory::MissingArgument> &d) override
+ {
+ LOGGER(l);
+ INFO(l) << "[luci] GraphValidate error " << d.node() << "(" << d.index() << ")" << std::endl;
+ }
+};
+
+} // namespace
+
+namespace luci
+{
+
+Importer::Importer()
+{
+ // DO NOTHING
+}
+
+std::unique_ptr<loco::Graph> Importer::import(const circle::Model *model) const
+{
+ auto graph = loco::make_graph();
+
+ const GraphBuilderSource *source_ptr = &GraphBuilderRegistry::get();
+
+ if (_source != nullptr)
+ {
+ // Use user-defined GraphBuilderSource
+ source_ptr = _source;
+ }
+
+ CircleReader reader;
+ if (!reader.parse(model))
+ return nullptr;
+
+ // TODO support multiple subgraph when Circle supports
+ assert(reader.num_subgraph() == 1);
+ if (!reader.select_subgraph(0))
+ return nullptr;
+
+ // Convert circle::Model to loco::Graph
+ convert_graph(*source_ptr, reader, graph.get());
+
+ LOGGER(l);
+ INFO(l) << fmt(graph.get());
+
+ assert(loco::valid(graph.get(), std::make_unique<ValidateCollector>()));
+
+ return std::move(graph);
+}
+
+std::unique_ptr<Module> Importer::importModule(const circle::Model *model) const
+{
+ auto module = make_module();
+
+ const GraphBuilderSource *source_ptr = &GraphBuilderRegistry::get();
+
+ if (_source != nullptr)
+ {
+ // Use user-defined GraphBuilderSource
+ source_ptr = _source;
+ }
+
+ CircleReader reader;
+ if (!reader.parse(model))
+ return nullptr;
+
+ for (uint32_t g = 0; g < reader.num_subgraph(); ++g)
+ {
+ auto graph = loco::make_graph();
+
+ if (!reader.select_subgraph(g))
+ return nullptr;
+
+ graph->name(reader.name());
+
+ // Convert circle::Model to loco::Graph
+ convert_graph(*source_ptr, reader, graph.get());
+
+ LOGGER(l);
+ INFO(l) << fmt(graph.get());
+
+ assert(loco::valid(graph.get(), std::make_unique<ValidateCollector>()));
+
+ module->add(std::move(graph));
+ }
+
+ return std::move(module);
+}
+
+} // namespace luci