summaryrefslogtreecommitdiff
path: root/src/inference_engine_tflite_private.h
diff options
context:
space:
mode:
Diffstat (limited to 'src/inference_engine_tflite_private.h')
-rw-r--r--src/inference_engine_tflite_private.h58
1 files changed, 10 insertions, 48 deletions
diff --git a/src/inference_engine_tflite_private.h b/src/inference_engine_tflite_private.h
index 9de8c5b..2c388f7 100644
--- a/src/inference_engine_tflite_private.h
+++ b/src/inference_engine_tflite_private.h
@@ -17,15 +17,13 @@
#ifndef __INFERENCE_ENGINE_IMPL_TFLite_H__
#define __INFERENCE_ENGINE_IMPL_TFLite_H__
-#include <inference_engine_vision.h>
+#include <inference_engine_common.h>
#include "tensorflow/contrib/lite/string.h"
#include "tensorflow/contrib/lite/kernels/register.h"
#include "tensorflow/contrib/lite/model.h"
#include "tensorflow/contrib/lite/context.h"
-#include <opencv2/core.hpp>
-#include <opencv2/imgproc.hpp>
#include <memory>
#include <dlog.h>
@@ -41,17 +39,15 @@
#define LOG_TAG "INFERENCE_ENGINE_TFLITE"
-using namespace InferenceEngineInterface::Vision;
using namespace InferenceEngineInterface::Common;
namespace InferenceEngineImpl {
namespace TFLiteImpl {
-class InferenceTFLite : public IInferenceEngineVision {
+class InferenceTFLite : public IInferenceEngineCommon {
public:
InferenceTFLite(std::string protoFile,
- std::string weightFile,
- std::string userFile);
+ std::string weightFile);
~InferenceTFLite();
@@ -60,19 +56,9 @@ public:
int SetInputTensorParamNode(std::string node = "input") override;
- int SetInputTensorParamInput(int width, int height, int dim, int ch) override;
-
- int SetInputTensorParamNorm(double deviation = 1.0, double mean = 0.0) override;
-
// Output Tensor Params
int SetOutputTensorParam() override;
- int SetOutputTensorParamThresHold(double threshold) override;
-
- int SetOutputTensorParamNumbers(int number) override;
-
- int SetOutputTensorParamType(int type) override;
-
int SetOutputTensorParamNodes(std::vector<std::string> nodes) override;
int SetTargetDevice(inference_target_type_e type) override;
@@ -82,29 +68,18 @@ public:
int CreateInputLayerPassage() override;
- int PrepareInputLayerPassage(inference_input_type_e type) override;
-
- int Run(cv::Mat tensor) override;
-
- int Run(std::vector<float> tensor) override;
+ int GetInputLayerAttrType() override;
- int GetInferenceResult(ImageClassificationResults& results);
+ void * GetInputDataPtr() override;
- int GetInferenceResult(ObjectDetectionResults& results);
+ int SetInputDataBuffer(tensor_t data) override;
- int GetInferenceResult(FaceDetectionResults& results);
+ int Run() override;
- int GetInferenceResult(FacialLandMarkDetectionResults& results);
-
- int GetInferenceResult(std::vector<std::vector<int>>& dimInfo, std::vector<float*>& results);
-
- int GetNumberOfOutputs() override;
+ int Run(std::vector<float> tensor) override;
- void SetUserListName(std::string userList) override;
+ int GetInferenceResult(tensor_t& results);
-public:
- int SetUserFile();
- int setInput(cv::Mat cvImg);
private:
std::unique_ptr<tflite::Interpreter> mInterpreter;
@@ -115,26 +90,13 @@ private:
int mInputLayerId;
std::vector<int> mOutputLayerId;
- int mMatType;
+
TfLiteType mInputAttrType;
void *mInputData;
- cv::Mat mInputBuffer;
-
- int mCh;
- int mDim;
- cv::Size mInputSize;
-
- double mDeviation;
- double mMean;
- double mThreshold;
- int mOutputNumbers;
- cv::Size mSourceSize;
std::string mConfigFile;
std::string mWeightFile;
- std::string mUserFile;
- std::vector<std::string> mUserListName;
};
} /* InferenceEngineImpl */