diff options
Diffstat (limited to 'compiler/tflchef/tflite/src/TFliteImport.h')
-rw-r--r-- | compiler/tflchef/tflite/src/TFliteImport.h | 66 |
1 files changed, 54 insertions, 12 deletions
diff --git a/compiler/tflchef/tflite/src/TFliteImport.h b/compiler/tflchef/tflite/src/TFliteImport.h index 9d0a642ab..ade8fc810 100644 --- a/compiler/tflchef/tflite/src/TFliteImport.h +++ b/compiler/tflchef/tflite/src/TFliteImport.h @@ -17,9 +17,7 @@ #ifndef __TFLITE_IMPORT_H__ #define __TFLITE_IMPORT_H__ -#include <mio/tflite/schema_generated.h> - -#include <souschef/TensorFiller.h> +#include <tflite_generated.h> #include <tflchef.pb.h> @@ -42,7 +40,7 @@ bool is_custom(const tflite::OperatorCode *opcode); /** * @brief Loads TF lite file and provides helpers to access attributes */ -class TFliteImport : public souschef::TensorFiller +class TFliteImport { public: TFliteImport(const tflite::Model *model); @@ -65,15 +63,59 @@ public: std::string opcode_name(const tflite::Operator *op) const; size_t buffer_info(const tflite::Tensor *tensor, const uint8_t **buff_data); + /** + * @brief This will record the tensor by index, if it needs filler option, + * such as kernel, bias. + */ + void set_tensor_filler(uint32_t tensor_index) { _tensor_filler[tensor_index] = true; } + + /** + * @brief This will store int32 filler values such as reshape information for the tensor + */ + void set_tensor_filler(uint32_t tensor_index, std::vector<int32_t> &expvalues) + { + _tensor_filler_vint32[tensor_index] = expvalues; + } + + /** + * @brief This will return true if the tensor by index, needs a filler option. + */ + bool get_tensor_filler(uint32_t tensor_index) + { + auto it = _tensor_filler.find(tensor_index); + if (it != _tensor_filler.end()) + { + return it->second; + } + return false; + } + + /** + * @brief This will return true if the tensor by index, needs a int array filler option. + */ + bool get_tensor_filler(uint32_t tensor_index, std::vector<int32_t> &expvalues) + { + auto it = _tensor_filler_vint32.find(tensor_index); + if (it != _tensor_filler_vint32.end()) + { + expvalues = it->second; + return true; + } + return false; + } + private: - const TFliteSubGraphs_t *_subgraphs{nullptr}; - const TFliteBuffers_t *_buffers{nullptr}; - const TFliteTensors_t *_tensors{nullptr}; - const TFliteOperators_t *_operators{nullptr}; - - std::vector<const tflite::OperatorCode *> _op_codes{}; - std::vector<int32_t> _inputs{}; - std::vector<int32_t> _outputs{}; + const TFliteSubGraphs_t *_subgraphs; + const TFliteBuffers_t *_buffers; + const TFliteTensors_t *_tensors; + const TFliteOperators_t *_operators; + + std::vector<const tflite::OperatorCode *> _op_codes; + std::vector<int32_t> _inputs; + std::vector<int32_t> _outputs; + + std::map<uint32_t, bool> _tensor_filler; + std::map<uint32_t, std::vector<int32_t>> _tensor_filler_vint32; }; } // namespace tflchef |