summaryrefslogtreecommitdiff
path: root/compiler/luci/import/src/CircleReader.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/luci/import/src/CircleReader.cpp')
-rw-r--r--compiler/luci/import/src/CircleReader.cpp295
1 files changed, 222 insertions, 73 deletions
diff --git a/compiler/luci/import/src/CircleReader.cpp b/compiler/luci/import/src/CircleReader.cpp
index bc7f39762..a42c3f913 100644
--- a/compiler/luci/import/src/CircleReader.cpp
+++ b/compiler/luci/import/src/CircleReader.cpp
@@ -16,6 +16,9 @@
#include "luci/Import/CircleReader.h"
+#include <mio_circle/Helper.h>
+
+#include <algorithm>
#include <memory>
#include <sstream>
#include <string>
@@ -23,52 +26,20 @@
namespace luci
{
-bool is_valid(const circle::OperatorCodeT &opcode)
-{
- circle::BuiltinOperator code = opcode.builtin_code;
- return (circle::BuiltinOperator_MIN <= code && code <= circle::BuiltinOperator_MAX);
-}
-
-bool is_custom(const circle::OperatorCodeT &opcode)
-{
- circle::BuiltinOperator code = opcode.builtin_code;
- return (code == circle::BuiltinOperator_CUSTOM);
-}
-
-std::string opcode_name(const circle::OperatorCodeT &opcode)
-{
- if (!is_valid(opcode))
- {
- std::ostringstream oss;
- oss << "(invalid)";
- return oss.str();
- }
-
- if (is_custom(opcode))
- {
- if (opcode.custom_code.empty())
- return "(invalid custom)";
-
- return opcode.custom_code;
- }
-
- circle::BuiltinOperator code = opcode.builtin_code;
- return circle::EnumNameBuiltinOperator(code);
-}
-
-const char *tensor_name(const circle::TensorT &tensor)
+const char *tensor_name(const circle::Tensor *tensor)
{
- static const char *kEmptyTensorName = "(noname)";
+ assert(tensor != nullptr);
- if (!tensor.name.empty())
- return tensor.name.c_str();
+ if (tensor->name() == nullptr || std::string(tensor->name()->c_str()).empty())
+ return "(noname)";
- return kEmptyTensorName;
+ return tensor->name()->c_str();
}
-const circle::QuantizationParametersT *tensor_quantization(const circle::TensorT &tensor)
+const circle::QuantizationParameters *tensor_quantization(const circle::Tensor *tensor)
{
- return tensor.quantization.get();
+ assert(tensor != nullptr);
+ return tensor->quantization();
}
loco::DataType luci_datatype(const circle::TensorType type)
@@ -86,7 +57,7 @@ loco::DataType luci_datatype(const circle::TensorType type)
case circle::TensorType_INT64:
return loco::DataType::S64;
case circle::TensorType_STRING:
- break;
+ return loco::DataType::STRING;
case circle::TensorType_BOOL:
return loco::DataType::BOOL;
case circle::TensorType_INT16:
@@ -115,7 +86,9 @@ FusedActFunc luci_actfunc(const circle::ActivationFunctionType type)
case circle::ActivationFunctionType::ActivationFunctionType_RELU6:
return luci::FusedActFunc::RELU6;
case circle::ActivationFunctionType::ActivationFunctionType_TANH:
- break;
+ return luci::FusedActFunc::TANH;
+ case circle::ActivationFunctionType::ActivationFunctionType_SIGN_BIT:
+ return luci::FusedActFunc::SIGN_BIT;
default:
break;
}
@@ -149,6 +122,65 @@ MirrorPadMode luci_mirrorpad_mode(const circle::MirrorPadMode mode)
return MirrorPadMode::UNDEFINED;
}
+luci::CircleFullyConnected::WeightsFormat
+luci_weights_format(const circle::FullyConnectedOptionsWeightsFormat weights_format)
+{
+ switch (weights_format)
+ {
+ case circle::FullyConnectedOptionsWeightsFormat_DEFAULT:
+ return luci::CircleFullyConnected::WeightsFormat::DEFAULT;
+ case circle::FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8:
+ return luci::CircleFullyConnected::WeightsFormat::SHUFFLED4x16INT8;
+ case circle::FullyConnectedOptionsWeightsFormat_SHUFFLED16x1FLOAT32:
+ return luci::CircleFullyConnected::WeightsFormat::SHUFFLED16x1FLOAT32;
+ default:
+ throw std::runtime_error("Invalid FullyConnectedOptionsWeightsFormat");
+ }
+}
+
+DimensionType luci_dim_type(const circle::DimensionType dim_type)
+{
+ switch (dim_type)
+ {
+ case circle::DimensionType_DENSE:
+ return DimensionType::DENSE;
+ case circle::DimensionType_SPARSE_CSR:
+ return DimensionType::SPARSE_CSR;
+ default:
+ throw std::runtime_error("Invalid DimensionType");
+ }
+}
+
+SparseIndexVector
+luci_sparse_index_vector(const circle::SparseIndexVectorUnion &sparse_index_vector)
+{
+ switch (sparse_index_vector.type)
+ {
+ case circle::SparseIndexVector_NONE:
+ return SparseIndexVector{SparseIndexVectorType::NONE, nullptr};
+ case circle::SparseIndexVector_Int32Vector:
+ {
+ const auto const_vec_ptr =
+ static_cast<const void *>(&(sparse_index_vector.AsInt32Vector()->values));
+ return SparseIndexVector{SparseIndexVectorType::I32, const_vec_ptr};
+ }
+ case circle::SparseIndexVector_Uint16Vector:
+ {
+ const auto const_vec_ptr =
+ static_cast<const void *>(&(sparse_index_vector.AsUint16Vector()->values));
+ return SparseIndexVector{SparseIndexVectorType::U16, const_vec_ptr};
+ }
+ case circle::SparseIndexVector_Uint8Vector:
+ {
+ const auto const_vec_ptr =
+ static_cast<const void *>(&(sparse_index_vector.AsUint8Vector()->values));
+ return SparseIndexVector{SparseIndexVectorType::U8, const_vec_ptr};
+ }
+ default:
+ throw std::runtime_error("Invalid SparseIndexVector type");
+ }
+}
+
std::unique_ptr<CircleQuantParam>
luci_quantparam(const circle::QuantizationParametersT *quantization)
{
@@ -174,83 +206,200 @@ luci_quantparam(const circle::QuantizationParametersT *quantization)
return nullptr;
}
-void copy_tensor_attributes(const circle::TensorT &tensor, CircleNode *node)
+std::unique_ptr<CircleQuantParam> luci_quantparam(const circle::QuantizationParameters *qparams)
+{
+ // create temporary unpacked API object
+ assert(qparams != nullptr);
+ circle::QuantizationParametersT quantization;
+ qparams->UnPackTo(&quantization);
+
+ return luci_quantparam(&quantization);
+}
+
+std::unique_ptr<SparsityParam> luci_sparsityparam(const circle::SparsityParametersT *sparsity)
{
+ assert(sparsity);
+ const auto &traversal_order = sparsity->traversal_order;
+ const auto &block_map = sparsity->block_map;
+ const auto &dim_metadata = sparsity->dim_metadata;
+
+ // TODO find a condition that should return nullptr
+ auto sparsityparam = std::make_unique<SparsityParam>();
+
+ sparsityparam->traversal_order = traversal_order;
+ sparsityparam->block_map = block_map;
+ for (const auto &dm : dim_metadata)
+ {
+ sparsityparam->dim_metadata.emplace_back(luci_dim_type(dm->format), dm->dense_size,
+ luci_sparse_index_vector(dm->array_segments),
+ luci_sparse_index_vector(dm->array_indices));
+ }
+
+ return sparsityparam;
+}
+
+std::unique_ptr<SparsityParam> luci_sparsityparam(const circle::SparsityParameters *sparparam)
+{
+ // create temporary unpacked API object
+ assert(sparparam != nullptr);
+ circle::SparsityParametersT sparsity;
+ sparparam->UnPackTo(&sparsity);
+
+ return luci_sparsityparam(&sparsity);
+}
+
+void copy_tensor_attributes(const circle::Tensor *tensor, CircleNode *node)
+{
+ assert(tensor != nullptr);
+
node->name(tensor_name(tensor));
- node->dtype(luci_datatype(tensor.type));
+ node->dtype(luci_datatype(tensor->type()));
+
+ const auto tensor_shape_signature = wrap(tensor->shape_signature());
+ const auto tensor_shape = wrap(tensor->shape());
+ assert(tensor_shape_signature.size() == 0 ||
+ tensor_shape_signature.size() == tensor_shape.size());
- std::vector<int32_t> dims = tensor.shape; // in NHWC
+ const auto dims = tensor_shape; // in NHWC
node->rank(dims.size());
for (uint32_t r = 0; r < dims.size(); ++r)
{
- node->dim(r) = loco::Dimension(dims[r]);
+ if (tensor_shape_signature.size() > 0 && tensor_shape_signature.at(r) == -1)
+ node->dim(r).unset();
+ else
+ node->dim(r).set(dims[r]);
}
- const auto *quantization = tensor.quantization.get();
+ const auto quantization = tensor->quantization();
if (quantization != nullptr)
{
auto quantparam = luci_quantparam(quantization);
if (quantparam)
node->quantparam(std::move(quantparam));
}
+
+ const auto sparsity = tensor->sparsity();
+ if (sparsity != nullptr)
+ {
+ auto sparsityparam = luci_sparsityparam(sparsity);
+ if (sparsityparam)
+ node->sparsityparam(std::move(sparsityparam));
+ }
}
-circle::BuiltinOperator CircleReader::builtin_code(const circle::OperatorT &op) const
+std::string fb_string2std_string(const flatbuffers::String *fb_str)
{
- const auto &op_codes = opcodes();
- uint32_t index = op.opcode_index;
+ return fb_str == nullptr ? "" : fb_str->str();
+}
+
+circle::BuiltinOperator CircleReader::builtin_code(const circle::Operator *op) const
+{
+ assert(op != nullptr);
+
+ const auto op_codes = opcodes();
+ uint32_t index = op->opcode_index();
assert(index < op_codes.size());
- const circle::OperatorCodeT &opcode = *op_codes[index];
+ const auto opcode = op_codes[index];
+ assert(opcode != nullptr);
- return opcode.builtin_code;
+ return mio::circle::builtin_code_neutral(opcode);
}
-std::string CircleReader::opcode_name(const circle::OperatorT &op) const
+std::string CircleReader::opcode_name(const circle::Operator *op) const
{
- const auto &op_codes = opcodes();
- uint32_t index = op.opcode_index;
- assert(index < op_codes.size());
- const circle::OperatorCodeT &opcode = *op_codes[index];
+ assert(op != nullptr);
- if (!is_valid(opcode))
- {
- std::ostringstream oss;
- oss << "(invalid: " << index << ")";
- return oss.str();
- }
+ const auto op_codes = opcodes();
+ uint32_t index = op->opcode_index();
+ assert(index < op_codes.size());
+ const auto opcode = op_codes[index];
- return ::luci::opcode_name(opcode);
+ return mio::circle::opcode_name(opcode);
}
bool CircleReader::parse(const circle::Model *model)
{
assert(model != nullptr);
- _model.reset(model->UnPack());
-
// for direct pointer access
- _model_ptr = model;
+ _model = model;
return true;
}
bool CircleReader::select_subgraph(uint32_t sgindex)
{
- if (_model->subgraphs.size() <= sgindex)
+ if (num_subgraph() <= sgindex)
{
assert(false);
return false;
}
- _current_subgraph = _model->subgraphs[sgindex].get();
-
// for direct pointer access
- auto subgraphs = _model_ptr->subgraphs();
- const circle::SubGraph *subgraph = (*subgraphs)[sgindex];
+ auto subgraphs = _model->subgraphs();
+ assert(subgraphs != nullptr);
- _tensors_ptr = subgraph->tensors();
+ _current_subgraph = subgraphs->Get(sgindex);
+ assert(_current_subgraph != nullptr);
return true;
}
+template <typename T>
+VectorWrapper<T>::VectorWrapper(const flatbuffers::Vector<T> *ptr) : _vector(ptr)
+{
+ // Do nothing
+}
+
+template <typename T> uint32_t VectorWrapper<T>::size() const
+{
+ return null() ? 0 : _vector->size();
+}
+
+template <typename T> const T *VectorWrapper<T>::data() const
+{
+ return null() ? nullptr : _vector->data();
+}
+
+template <typename T> typename VectorWrapper<T>::iterator VectorWrapper<T>::begin() const
+{
+ return null() ? iterator(nullptr, 0) : _vector->begin();
+}
+
+template <typename T> typename VectorWrapper<T>::iterator VectorWrapper<T>::end() const
+{
+ return null() ? begin() : _vector->end();
+}
+
+template <typename T> typename VectorWrapper<T>::value_type VectorWrapper<T>::at(uint32_t i) const
+{
+ if (i >= size())
+ {
+ // TODO find better error message
+ throw std::range_error("Access to prohibited vector element");
+ }
+
+ return _vector->Get(i);
+}
+
+template <typename T>
+typename VectorWrapper<T>::value_type VectorWrapper<T>::operator[](uint32_t i) const
+{
+ return at(i);
+}
+
+template <typename T> bool VectorWrapper<T>::null() const { return _vector == nullptr; }
+template <typename T> bool VectorWrapper<T>::empty() const { return size() == 0; }
+
+#define REGISTER_WRAPPER(T) template class VectorWrapper<T>
+REGISTER_WRAPPER(flatbuffers::Offset<circle::SubGraph>);
+REGISTER_WRAPPER(flatbuffers::Offset<circle::Buffer>);
+REGISTER_WRAPPER(flatbuffers::Offset<circle::Tensor>);
+REGISTER_WRAPPER(flatbuffers::Offset<circle::Operator>);
+REGISTER_WRAPPER(flatbuffers::Offset<circle::OperatorCode>);
+REGISTER_WRAPPER(flatbuffers::Offset<circle::Metadata>);
+REGISTER_WRAPPER(int32_t);
+REGISTER_WRAPPER(uint8_t);
+#undef REGISTER_WRAPPER
+
} // namespace luci