summaryrefslogtreecommitdiff
path: root/runtime/onert/frontend/circle/src/circle_loader.cc
diff options
context:
space:
mode:
Diffstat (limited to 'runtime/onert/frontend/circle/src/circle_loader.cc')
-rw-r--r--runtime/onert/frontend/circle/src/circle_loader.cc134
1 files changed, 134 insertions, 0 deletions
diff --git a/runtime/onert/frontend/circle/src/circle_loader.cc b/runtime/onert/frontend/circle/src/circle_loader.cc
new file mode 100644
index 000000000..49aaccc4c
--- /dev/null
+++ b/runtime/onert/frontend/circle/src/circle_loader.cc
@@ -0,0 +1,134 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "circle_loader.h"
+#include "base_loader.h"
+#include "circle_schema_generated.h"
+
+namespace onert
+{
+namespace circle_loader
+{
+
+namespace
+{
+
+ir::Layout convertDataFormat(circle::DataFormat data_format)
+{
+ switch (data_format)
+ {
+ case circle::DataFormat::DataFormat_CHANNELS_FIRST:
+ return ir::Layout::NCHW;
+ case circle::DataFormat::DataFormat_CHANNELS_LAST:
+ return ir::Layout::NHWC;
+ default:
+ throw std::runtime_error("Unsupported DataFormat");
+ }
+}
+
+struct LoaderDomain
+{
+ using Verifier = flatbuffers::Verifier;
+ using ActivationFunctionType = circle::ActivationFunctionType;
+ using Buffer = circle::Buffer;
+ using BuiltinOperator = circle::BuiltinOperator;
+ using CustomOptionsFormat = circle::CustomOptionsFormat;
+ using Model = circle::Model;
+ using Operator = circle::Operator;
+ using Padding = circle::Padding;
+ using Pool2DOptions = circle::Pool2DOptions;
+ using Tensor = circle::Tensor;
+ using TensorType = circle::TensorType;
+ using SubGraph = circle::SubGraph;
+
+ static const char *EnumNameBuiltinOperator(BuiltinOperator e)
+ {
+ return circle::EnumNameBuiltinOperator(e);
+ }
+ static const char *EnumNameActivationFunctionType(ActivationFunctionType e)
+ {
+ return circle::EnumNameActivationFunctionType(e);
+ }
+ static const char *EnumNameTensorType(TensorType e) { return circle::EnumNameTensorType(e); }
+ static const Model *GetModel(const void *buf) { return circle::GetModel(buf); }
+ static bool VerifyModelBuffer(Verifier &verifier) { return circle::VerifyModelBuffer(verifier); }
+};
+
+class CircleLoader final : public base_loader::BaseLoader<LoaderDomain, CircleLoader>
+{
+public:
+ using BaseLoader::BaseLoader;
+
+ std::unique_ptr<ir::Graph> loadSubgraph(const circle::SubGraph *circle_subg)
+ {
+ auto subg = std::make_unique<ir::Graph>();
+ // Load tensors
+ _tensor_to_operand.resize(circle_subg->tensors()->size());
+ for (flatbuffers::uoffset_t i = 0; i < circle_subg->tensors()->size(); ++i)
+ {
+ _tensor_to_operand[i] = loadOperand(circle_subg->tensors()->Get(i), *subg);
+ }
+ // Set inputs
+ for (const std::int32_t input_ind : *circle_subg->inputs())
+ {
+ subg->addInput(_tensor_to_operand[input_ind]);
+ }
+ // Set outputs
+ for (const std::int32_t output_ind : *circle_subg->outputs())
+ {
+ subg->addOutput(_tensor_to_operand[output_ind]);
+ }
+ // Create operations
+ for (const auto *op : *circle_subg->operators())
+ {
+ CircleLoader::loadOperation(op, *subg);
+ }
+
+ subg->setLayout(convertDataFormat(circle_subg->data_format()));
+
+ subg->finishBuilding();
+
+ return subg;
+ }
+
+ void loadOperation(const circle::Operator *op, ir::Graph &subg)
+ {
+ const auto builtin_op = _model->operator_codes()->Get(op->opcode_index())->builtin_code();
+
+ switch (builtin_op)
+ {
+ case circle::BuiltinOperator::BuiltinOperator_INSTANCE_NORM:
+ loadInstanceNorm(op, subg);
+ return;
+ default:
+ BaseLoader::loadOperation(op, subg);
+ return;
+ }
+ }
+};
+
+} // namespace
+
+std::unique_ptr<ir::Graph> loadModel(const char *filename)
+{
+ auto primary_subgraph = std::make_unique<ir::Graph>();
+ CircleLoader loader(primary_subgraph);
+ loader.loadFromFile(filename);
+ return primary_subgraph;
+}
+
+} // namespace circle_loader
+} // namespace onert