diff options
-rw-r--r-- | compiler/exo-tflite/src/Dialect/IR/TFLNodes.cpp | 48 | ||||
-rw-r--r-- | compiler/exo-tflite/src/Dialect/IR/TFLNodes.h | 22 | ||||
-rw-r--r-- | compiler/exo-tflite/src/Dialect/IR/TFLNodes.lst | 1 | ||||
-rw-r--r-- | compiler/exo-tflite/src/TFLFormattedGraph.cpp | 6 |
4 files changed, 76 insertions, 1 deletions
diff --git a/compiler/exo-tflite/src/Dialect/IR/TFLNodes.cpp b/compiler/exo-tflite/src/Dialect/IR/TFLNodes.cpp index 12f57819e..7287fc505 100644 --- a/compiler/exo-tflite/src/Dialect/IR/TFLNodes.cpp +++ b/compiler/exo-tflite/src/Dialect/IR/TFLNodes.cpp @@ -16,4 +16,50 @@ #include "TFLNodes.h" -// This file exists to make compiler build TFLNodes.h +#include <loco.h> + +#include <cassert> + +namespace locoex +{ + +template <loco::DataType DT> uint32_t TFLConst::size(void) const +{ + assert(dtype() == DT); + assert(_data.size() % sizeof(typename loco::DataTypeImpl<DT>::Type) == 0); + return _data.size() / sizeof(typename loco::DataTypeImpl<DT>::Type); +} + +template <loco::DataType DT> void TFLConst::size(uint32_t l) +{ + assert(dtype() == DT); + _data.resize(l * sizeof(typename loco::DataTypeImpl<DT>::Type)); +} + +template <loco::DataType DT> +const typename loco::DataTypeImpl<DT>::Type &TFLConst::at(uint32_t n) const +{ + assert(dtype() == DT); + assert(n < size<DT>()); + return *(reinterpret_cast<const typename loco::DataTypeImpl<DT>::Type *>(_data.data()) + n); +} + +template <loco::DataType DT> typename loco::DataTypeImpl<DT>::Type &TFLConst::at(uint32_t n) +{ + assert(dtype() == DT); + assert(n < size<DT>()); + return *(reinterpret_cast<typename loco::DataTypeImpl<DT>::Type *>(_data.data()) + n); +} + +#define INSTANTIATE(DT) \ + template uint32_t TFLConst::size<DT>(void) const; \ + template void TFLConst::size<DT>(uint32_t); \ + template const typename loco::DataTypeImpl<DT>::Type &TFLConst::at<DT>(uint32_t) const; \ + template typename loco::DataTypeImpl<DT>::Type &TFLConst::at<DT>(uint32_t); + +INSTANTIATE(loco::DataType::S32); +INSTANTIATE(loco::DataType::FLOAT32); + +#undef INSTANTIATE + +} // namespace locoex diff --git a/compiler/exo-tflite/src/Dialect/IR/TFLNodes.h b/compiler/exo-tflite/src/Dialect/IR/TFLNodes.h index 42de74806..f12b62d89 100644 --- a/compiler/exo-tflite/src/Dialect/IR/TFLNodes.h +++ b/compiler/exo-tflite/src/Dialect/IR/TFLNodes.h @@ -22,6 +22,7 @@ #include <loco/IR/Node.h> #include <loco/IR/NodeMixins.h> +#include <loco/IR/DataTypeTraits.h> #include <array> @@ -162,6 +163,27 @@ private: // TODO TFLConcatenation +/** + * @brief Class to build tensor data + * @note This will not be exported as a specific op + */ +class TFLConst final : public FixedArityNode<0, TFLNodeImpl<TFLOpcode::NOP_CONSTGEN>>, + public loco::NodeMixin<loco::NodeTrait::DataType>, + public loco::NodeMixin<loco::NodeTrait::TensorShape> +{ +public: + TFLConst() = default; + +public: + template <loco::DataType DT> uint32_t size(void) const; + template <loco::DataType DT> void size(uint32_t size); + template <loco::DataType DT> const typename loco::DataTypeImpl<DT>::Type &at(uint32_t n) const; + template <loco::DataType DT> typename loco::DataTypeImpl<DT>::Type &at(uint32_t n); + +private: + std::vector<uint8_t> _data; +}; + // TODO TFLConv2D // TODO TFLDepthwiseConv2D diff --git a/compiler/exo-tflite/src/Dialect/IR/TFLNodes.lst b/compiler/exo-tflite/src/Dialect/IR/TFLNodes.lst index a77ac80eb..9e9c410f1 100644 --- a/compiler/exo-tflite/src/Dialect/IR/TFLNodes.lst +++ b/compiler/exo-tflite/src/Dialect/IR/TFLNodes.lst @@ -8,6 +8,7 @@ TFL_NODE(ADD, locoex::TFLAdd) TFL_NODE(AVERAGE_POOL_2D, locoex::TFLAveragePool2D) // TODO TFLConcatenation +TFL_NODE(NOP_CONSTGEN, locoex::TFLConst) // TODO TFLConv2D // TODO TFLDepthwiseConv2D TFL_NODE(DIV, locoex::TFLDiv) diff --git a/compiler/exo-tflite/src/TFLFormattedGraph.cpp b/compiler/exo-tflite/src/TFLFormattedGraph.cpp index 52ce0ff34..d01fe7bcc 100644 --- a/compiler/exo-tflite/src/TFLFormattedGraph.cpp +++ b/compiler/exo-tflite/src/TFLFormattedGraph.cpp @@ -127,6 +127,12 @@ bool TFLNodeSummaryBuilder::summary(const locoex::TFLAveragePool2D *node, // TODO TFLConcatenation +bool TFLNodeSummaryBuilder::summary(const locoex::TFLConst *, locop::NodeSummary &s) const +{ + s.state(locop::NodeSummary::State::PartiallyKnown); + return true; +} + // TODO TFLConv2D // TODO TFLDepthwiseConv2D |