diff options
Diffstat (limited to 'compiler/tflchef/tflite/src/TFliteImport.h')
-rw-r--r-- | compiler/tflchef/tflite/src/TFliteImport.h | 65 |
1 files changed, 3 insertions, 62 deletions
diff --git a/compiler/tflchef/tflite/src/TFliteImport.h b/compiler/tflchef/tflite/src/TFliteImport.h index 5b46f4501..9d0a642ab 100644 --- a/compiler/tflchef/tflite/src/TFliteImport.h +++ b/compiler/tflchef/tflite/src/TFliteImport.h @@ -19,6 +19,8 @@ #include <mio/tflite/schema_generated.h> +#include <souschef/TensorFiller.h> + #include <tflchef.pb.h> #include <map> @@ -40,7 +42,7 @@ bool is_custom(const tflite::OperatorCode *opcode); /** * @brief Loads TF lite file and provides helpers to access attributes */ -class TFliteImport +class TFliteImport : public souschef::TensorFiller { public: TFliteImport(const tflite::Model *model); @@ -63,63 +65,6 @@ public: std::string opcode_name(const tflite::Operator *op) const; size_t buffer_info(const tflite::Tensor *tensor, const uint8_t **buff_data); - /** - * @brief This will record the tensor by index, if it needs filler option, - * such as kernel, bias. - */ - void set_tensor_filler(uint32_t tensor_index) { _tensor_filler[tensor_index] = true; } - - /** - * @brief This will store int32 filler values such as reshape information for the tensor - */ - void set_tensor_filler(uint32_t tensor_index, std::vector<int32_t> &expvalues) - { - _tensor_filler_vint32[tensor_index] = expvalues; - } - - void set_tensor_filler(uint32_t tensor_index, std::vector<float> &expvalues) - { - _tensor_filler_vfloat[tensor_index] = expvalues; - } - - /** - * @brief This will return true if the tensor by index, needs a filler option. - */ - bool get_tensor_filler(uint32_t tensor_index) - { - auto it = _tensor_filler.find(tensor_index); - if (it != _tensor_filler.end()) - { - return it->second; - } - return false; - } - - /** - * @brief This will return true if the tensor by index, needs a int array filler option. - */ - bool get_tensor_filler(uint32_t tensor_index, std::vector<int32_t> &expvalues) - { - auto it = _tensor_filler_vint32.find(tensor_index); - if (it != _tensor_filler_vint32.end()) - { - expvalues = it->second; - return true; - } - return false; - } - - bool get_tensor_filler(uint32_t tensor_index, std::vector<float> &expvalues) - { - auto it = _tensor_filler_vfloat.find(tensor_index); - if (it != _tensor_filler_vfloat.end()) - { - expvalues = it->second; - return true; - } - return false; - } - private: const TFliteSubGraphs_t *_subgraphs{nullptr}; const TFliteBuffers_t *_buffers{nullptr}; @@ -129,10 +74,6 @@ private: std::vector<const tflite::OperatorCode *> _op_codes{}; std::vector<int32_t> _inputs{}; std::vector<int32_t> _outputs{}; - - std::map<uint32_t, bool> _tensor_filler{}; - std::map<uint32_t, std::vector<int32_t>> _tensor_filler_vint32{}; - std::map<uint32_t, std::vector<float>> _tensor_filler_vfloat{}; }; } // namespace tflchef |