diff options
Diffstat (limited to 'src/inference_engine_tflite_private.h')
-rw-r--r-- | src/inference_engine_tflite_private.h | 58 |
1 files changed, 10 insertions, 48 deletions
diff --git a/src/inference_engine_tflite_private.h b/src/inference_engine_tflite_private.h index 9de8c5b..2c388f7 100644 --- a/src/inference_engine_tflite_private.h +++ b/src/inference_engine_tflite_private.h @@ -17,15 +17,13 @@ #ifndef __INFERENCE_ENGINE_IMPL_TFLite_H__ #define __INFERENCE_ENGINE_IMPL_TFLite_H__ -#include <inference_engine_vision.h> +#include <inference_engine_common.h> #include "tensorflow/contrib/lite/string.h" #include "tensorflow/contrib/lite/kernels/register.h" #include "tensorflow/contrib/lite/model.h" #include "tensorflow/contrib/lite/context.h" -#include <opencv2/core.hpp> -#include <opencv2/imgproc.hpp> #include <memory> #include <dlog.h> @@ -41,17 +39,15 @@ #define LOG_TAG "INFERENCE_ENGINE_TFLITE" -using namespace InferenceEngineInterface::Vision; using namespace InferenceEngineInterface::Common; namespace InferenceEngineImpl { namespace TFLiteImpl { -class InferenceTFLite : public IInferenceEngineVision { +class InferenceTFLite : public IInferenceEngineCommon { public: InferenceTFLite(std::string protoFile, - std::string weightFile, - std::string userFile); + std::string weightFile); ~InferenceTFLite(); @@ -60,19 +56,9 @@ public: int SetInputTensorParamNode(std::string node = "input") override; - int SetInputTensorParamInput(int width, int height, int dim, int ch) override; - - int SetInputTensorParamNorm(double deviation = 1.0, double mean = 0.0) override; - // Output Tensor Params int SetOutputTensorParam() override; - int SetOutputTensorParamThresHold(double threshold) override; - - int SetOutputTensorParamNumbers(int number) override; - - int SetOutputTensorParamType(int type) override; - int SetOutputTensorParamNodes(std::vector<std::string> nodes) override; int SetTargetDevice(inference_target_type_e type) override; @@ -82,29 +68,18 @@ public: int CreateInputLayerPassage() override; - int PrepareInputLayerPassage(inference_input_type_e type) override; - - int Run(cv::Mat tensor) override; - - int Run(std::vector<float> tensor) override; + int GetInputLayerAttrType() override; - int GetInferenceResult(ImageClassificationResults& results); + void * GetInputDataPtr() override; - int GetInferenceResult(ObjectDetectionResults& results); + int SetInputDataBuffer(tensor_t data) override; - int GetInferenceResult(FaceDetectionResults& results); + int Run() override; - int GetInferenceResult(FacialLandMarkDetectionResults& results); - - int GetInferenceResult(std::vector<std::vector<int>>& dimInfo, std::vector<float*>& results); - - int GetNumberOfOutputs() override; + int Run(std::vector<float> tensor) override; - void SetUserListName(std::string userList) override; + int GetInferenceResult(tensor_t& results); -public: - int SetUserFile(); - int setInput(cv::Mat cvImg); private: std::unique_ptr<tflite::Interpreter> mInterpreter; @@ -115,26 +90,13 @@ private: int mInputLayerId; std::vector<int> mOutputLayerId; - int mMatType; + TfLiteType mInputAttrType; void *mInputData; - cv::Mat mInputBuffer; - - int mCh; - int mDim; - cv::Size mInputSize; - - double mDeviation; - double mMean; - double mThreshold; - int mOutputNumbers; - cv::Size mSourceSize; std::string mConfigFile; std::string mWeightFile; - std::string mUserFile; - std::vector<std::string> mUserListName; }; } /* InferenceEngineImpl */ |