summaryrefslogtreecommitdiff
path: root/runtime/onert/frontend/tflite/src/tflite_loader.cc
diff options
context:
space:
mode:
Diffstat (limited to 'runtime/onert/frontend/tflite/src/tflite_loader.cc')
-rw-r--r--runtime/onert/frontend/tflite/src/tflite_loader.cc110
1 files changed, 110 insertions, 0 deletions
diff --git a/runtime/onert/frontend/tflite/src/tflite_loader.cc b/runtime/onert/frontend/tflite/src/tflite_loader.cc
new file mode 100644
index 000000000..7ede4441a
--- /dev/null
+++ b/runtime/onert/frontend/tflite/src/tflite_loader.cc
@@ -0,0 +1,110 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "tflite_loader.h"
+#include "base_loader.h"
+#include "tflite_schema_generated.h"
+
+namespace onert
+{
+namespace tflite_loader
+{
+
+namespace
+{
+
+struct LoaderDomain
+{
+ using Verifier = flatbuffers::Verifier;
+ using ActivationFunctionType = onert_tflite::ActivationFunctionType;
+ using Buffer = onert_tflite::Buffer;
+ using BuiltinOperator = onert_tflite::BuiltinOperator;
+ using CustomOptionsFormat = onert_tflite::CustomOptionsFormat;
+ using Model = onert_tflite::Model;
+ using Operator = onert_tflite::Operator;
+ using Padding = onert_tflite::Padding;
+ using Pool2DOptions = onert_tflite::Pool2DOptions;
+ using Tensor = onert_tflite::Tensor;
+ using TensorType = onert_tflite::TensorType;
+ using SubGraph = onert_tflite::SubGraph;
+
+ static const char *EnumNameBuiltinOperator(BuiltinOperator e)
+ {
+ return onert_tflite::EnumNameBuiltinOperator(e);
+ }
+ static const char *EnumNameActivationFunctionType(ActivationFunctionType e)
+ {
+ return onert_tflite::EnumNameActivationFunctionType(e);
+ }
+ static const char *EnumNameTensorType(TensorType e)
+ {
+ return onert_tflite::EnumNameTensorType(e);
+ }
+ static const Model *GetModel(const void *buf) { return onert_tflite::GetModel(buf); }
+ static bool VerifyModelBuffer(Verifier &verifier)
+ {
+ return onert_tflite::VerifyModelBuffer(verifier);
+ }
+};
+
+class TFLiteLoader final : public base_loader::BaseLoader<LoaderDomain, TFLiteLoader>
+{
+public:
+ using BaseLoader::BaseLoader;
+
+ std::unique_ptr<ir::Graph> loadSubgraph(const onert_tflite::SubGraph *tflite_subg)
+ {
+ auto subg = std::make_unique<ir::Graph>();
+ // Load tensors
+ _tensor_to_operand.resize(tflite_subg->tensors()->size());
+ for (flatbuffers::uoffset_t i = 0; i < tflite_subg->tensors()->size(); ++i)
+ {
+ _tensor_to_operand[i] = loadOperand(tflite_subg->tensors()->Get(i), *subg);
+ }
+ // Set inputs
+ for (const std::int32_t input_ind : *tflite_subg->inputs())
+ {
+ subg->addInput(_tensor_to_operand[input_ind]);
+ }
+ // Set outputs
+ for (const std::int32_t output_ind : *tflite_subg->outputs())
+ {
+ subg->addOutput(_tensor_to_operand[output_ind]);
+ }
+ // Create operations
+ for (const auto *op : *tflite_subg->operators())
+ {
+ loadOperation(op, *subg);
+ }
+
+ subg->finishBuilding();
+
+ return subg;
+ }
+};
+
+} // namespace
+
+std::unique_ptr<ir::Graph> loadModel(const char *filename)
+{
+ auto primary_subgraph = std::make_unique<ir::Graph>();
+ TFLiteLoader loader(primary_subgraph);
+ loader.loadFromFile(filename);
+ return primary_subgraph;
+}
+
+} // namespace tflite_loader
+} // namespace onert