summaryrefslogtreecommitdiff
path: root/compiler/tflite2circle/src/CircleModel.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/tflite2circle/src/CircleModel.cpp')
-rw-r--r--compiler/tflite2circle/src/CircleModel.cpp238
1 files changed, 238 insertions, 0 deletions
diff --git a/compiler/tflite2circle/src/CircleModel.cpp b/compiler/tflite2circle/src/CircleModel.cpp
new file mode 100644
index 000000000..3a569323c
--- /dev/null
+++ b/compiler/tflite2circle/src/CircleModel.cpp
@@ -0,0 +1,238 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <iostream>
+
+#include "CircleModel.h"
+#include "DataLookup.h"
+
+namespace tflite2circle
+{
+
+template <>
+Offset<MetaDataBufferLink>::Offset(FlatBufBuilder &fb, const TFLFlatBufVec *tflite_flatbuffer_vec)
+{
+ if (tflite_flatbuffer_vec == nullptr)
+ return;
+ std::vector<int32_t> metadata_buffer_vec{tflite_flatbuffer_vec->begin(),
+ tflite_flatbuffer_vec->end()};
+ _circle_flatbuffer_vec_offset = fb->CreateVector(metadata_buffer_vec);
+}
+
+template <>
+Offset<BufferLink>::Offset(FlatBufBuilder &fb, const TFLFlatBufVec *tflite_flatbuffer_vec)
+{
+ std::vector<flatbuffers::Offset<circle::Buffer>> buffers_vec;
+
+ for (auto it : *tflite_flatbuffer_vec)
+ {
+ flatbuffers::Offset<flatbuffers::Vector<uint8_t>> buffer_data;
+ if (it->data())
+ {
+ std::vector<uint8_t> data_vec{it->data()->begin(), it->data()->end()};
+ buffer_data = fb->CreateVector(data_vec);
+ }
+ circle::BufferBuilder circle_buffer_builder{*fb};
+ circle_buffer_builder.add_data(buffer_data);
+ auto circle_buffers = circle_buffer_builder.Finish();
+ buffers_vec.emplace_back(circle_buffers);
+ }
+ _circle_flatbuffer_vec_offset = fb->CreateVector(buffers_vec);
+}
+
+template <>
+Offset<SubGraphLink>::Offset(FlatBufBuilder &fb, const TFLFlatBufVec *tflite_flatbuffer_vec)
+{
+ std::vector<flatbuffers::Offset<circle::SubGraph>> subgprahs_vec;
+
+ for (auto it_sg : *tflite_flatbuffer_vec)
+ {
+ // tensors of subgraph
+ std::vector<flatbuffers::Offset<circle::Tensor>> tensor_vec;
+
+ auto tflite_tensors = it_sg->tensors();
+ for (auto it : *tflite_tensors)
+ {
+ // shape
+ std::vector<int32_t> shape_vec{it->shape()->begin(), it->shape()->end()};
+ auto shape = fb->CreateVector(shape_vec);
+ // name
+ flatbuffers::Offset<flatbuffers::String> name;
+ if (it->name())
+ name = fb->CreateString(it->name()->str());
+ // quantization
+ flatbuffers::Offset<circle::QuantizationParameters> quantization;
+ if (it->quantization())
+ {
+ std::vector<float> tfmin;
+ std::vector<float> tfmax;
+ std::vector<float> tfscale;
+ std::vector<int64_t> tfzerop;
+ flatbuffers::Offset<flatbuffers::Vector<float>> min;
+ flatbuffers::Offset<flatbuffers::Vector<float>> max;
+ flatbuffers::Offset<flatbuffers::Vector<float>> scale;
+ flatbuffers::Offset<flatbuffers::Vector<int64_t>> zero_point;
+
+ if (it->quantization()->min() && it->quantization()->max())
+ {
+ auto rmin = it->quantization()->min();
+ auto rmax = it->quantization()->max();
+ tfmin = std::vector<float>{rmin->begin(), rmin->end()};
+ tfmax = std::vector<float>{rmax->begin(), rmax->end()};
+ min = fb->CreateVector(tfmin);
+ max = fb->CreateVector(tfmax);
+ }
+
+ if (it->quantization()->scale() && it->quantization()->zero_point())
+ {
+ auto rs = it->quantization()->scale();
+ auto rz = it->quantization()->zero_point();
+ tfscale = std::vector<float>{rs->begin(), rs->end()};
+ tfzerop = std::vector<int64_t>{rz->begin(), rz->end()};
+ scale = fb->CreateVector(tfscale);
+ zero_point = fb->CreateVector(tfzerop);
+ }
+
+ quantization = circle::CreateQuantizationParameters(*fb, min, max, scale, zero_point);
+ }
+ // is_variable
+ bool is_variable = it->is_variable();
+
+ circle::TensorBuilder tensor_builder{*fb};
+ tensor_builder.add_shape(shape);
+ tensor_builder.add_type(get_circle_tensortype(it->type()));
+ tensor_builder.add_buffer(it->buffer());
+ tensor_builder.add_name(name);
+ tensor_builder.add_quantization(quantization);
+ tensor_builder.add_is_variable(is_variable);
+ auto tensor = tensor_builder.Finish();
+ tensor_vec.emplace_back(tensor);
+ }
+ auto circle_tensors = fb->CreateVector(tensor_vec);
+
+ // inputs of subgraph
+ auto tflite_inputs = it_sg->inputs();
+ std::vector<int32_t> input_vec{tflite_inputs->begin(), tflite_inputs->end()};
+
+ auto circle_inputs = fb->CreateVector(input_vec);
+
+ // outputs of subgraph
+ auto tflite_outputs = it_sg->outputs();
+ std::vector<int32_t> output_vec{tflite_outputs->begin(), tflite_outputs->end()};
+
+ auto circle_outputs = fb->CreateVector(output_vec);
+
+ // operators of subgraph
+ std::vector<flatbuffers::Offset<circle::Operator>> operator_vec;
+
+ auto tflite_operators = it_sg->operators();
+ for (auto it : *tflite_operators)
+ {
+ // inputs
+ std::vector<int32_t> input_vec{it->inputs()->begin(), it->inputs()->end()};
+ auto circle_inputs = fb->CreateVector(input_vec);
+ // outputs
+ std::vector<int32_t> output_vec{it->outputs()->begin(), it->outputs()->end()};
+ auto circle_outputs = fb->CreateVector(output_vec);
+ // builtin options
+ auto circle_builtin_options = get_circle_builtin_options(*fb, it);
+ auto circle_builtin_options_type = get_circle_builtin_options_type(it);
+
+ circle::OperatorBuilder operator_builder{*fb};
+ operator_builder.add_opcode_index(it->opcode_index());
+ operator_builder.add_inputs(circle_inputs);
+ operator_builder.add_outputs(circle_outputs);
+ operator_builder.add_builtin_options(circle_builtin_options);
+ operator_builder.add_builtin_options_type(circle_builtin_options_type);
+ // TODO custom_options, mutating_variable_inputs
+ auto opeartor = operator_builder.Finish();
+ operator_vec.emplace_back(opeartor);
+ }
+ auto circle_operators = fb->CreateVector(operator_vec);
+
+ // name of subgraph
+ auto subgraphs_name = fb->CreateString(it_sg->name());
+
+ // subgraphs
+ auto circle_subgraph_builder = circle::SubGraphBuilder{*fb};
+
+ circle_subgraph_builder.add_tensors(circle_tensors);
+ circle_subgraph_builder.add_inputs(circle_inputs);
+ circle_subgraph_builder.add_outputs(circle_outputs);
+ circle_subgraph_builder.add_operators(circle_operators);
+ circle_subgraph_builder.add_name(subgraphs_name);
+ circle_subgraph_builder.add_data_format(circle::DataFormat_CHANNELS_LAST);
+
+ auto circle_subgraph = circle_subgraph_builder.Finish();
+ subgprahs_vec.emplace_back(circle_subgraph);
+ }
+ _circle_flatbuffer_vec_offset = fb->CreateVector(subgprahs_vec);
+}
+
+template <>
+Offset<OperatorCodeLink>::Offset(FlatBufBuilder &fb, const TFLFlatBufVec *tflite_flatbuffer_vec)
+{
+ std::vector<flatbuffers::Offset<circle::OperatorCode>> operator_code_vec;
+
+ for (auto it : *tflite_flatbuffer_vec)
+ {
+ auto custom_code = fb->CreateString(it->custom_code());
+ circle::OperatorCodeBuilder operator_code_builder{*fb};
+ operator_code_builder.add_builtin_code(get_circle_builtin_code(it->builtin_code()));
+ operator_code_builder.add_custom_code(custom_code);
+ operator_code_builder.add_version(it->version());
+ auto code = operator_code_builder.Finish();
+ operator_code_vec.emplace_back(code);
+ }
+ _circle_flatbuffer_vec_offset = fb->CreateVector(operator_code_vec);
+}
+
+CircleModel::CircleModel(FlatBufBuilder &fb, TFLModel &model)
+ : _version{0}, _description{fb->CreateString("nnpackage")}, _fb{fb}
+{
+ const tflite::Model *tfl_model = model.load_model();
+ _operator_codes_offset =
+ stdex::make_unique<Offset<OperatorCodeLink>>(fb, tfl_model->operator_codes());
+ _subGraphs_offset = stdex::make_unique<Offset<SubGraphLink>>(fb, tfl_model->subgraphs());
+ _buffers_offset = stdex::make_unique<Offset<BufferLink>>(fb, tfl_model->buffers());
+ _metadata_buffer_offset =
+ stdex::make_unique<Offset<MetaDataBufferLink>>(fb, tfl_model->metadata_buffer());
+ model_build();
+}
+
+void CircleModel::model_build(void) const
+{
+ circle::ModelBuilder model_builder{*_fb};
+
+ model_builder.add_version(_version);
+ model_builder.add_description(_description);
+ model_builder.add_operator_codes(_operator_codes_offset->offset());
+ model_builder.add_subgraphs(_subGraphs_offset->offset());
+ model_builder.add_buffers(_buffers_offset->offset());
+ model_builder.add_metadata_buffer(_metadata_buffer_offset->offset());
+
+ auto model = model_builder.Finish();
+ circle::FinishModelBuffer(*_fb, model);
+}
+
+const char *CircleModel::base(void) const
+{
+ return reinterpret_cast<const char *>(_fb->GetBufferPointer());
+}
+
+size_t CircleModel::size(void) const { return _fb->GetSize(); }
+
+} // namespace tflite2circle