summaryrefslogtreecommitdiff
path: root/src/inference_engine_tflite_private.h
diff options
context:
space:
mode:
authorTae-Young Chung <ty83.chung@samsung.com>2019-09-18 17:36:11 +0900
committerTae-Young Chung <ty83.chung@samsung.com>2019-09-23 15:09:15 +0900
commit9d443679de2152941e805ca04840b58da838ae9c (patch)
tree84758d441e04ec254d2112de54430a934e9f64cc /src/inference_engine_tflite_private.h
parent2b34612a686057e0b80e3ffd17cc7f36020355f7 (diff)
downloadinference-engine-tflite-9d443679de2152941e805ca04840b58da838ae9c.tar.gz
inference-engine-tflite-9d443679de2152941e805ca04840b58da838ae9c.tar.bz2
inference-engine-tflite-9d443679de2152941e805ca04840b58da838ae9c.zip
inference-engine-tflite is a plugin to provide inference only. Thus, domain-specific functions such as vision should be removed. Instead, add APIs GetInputLayerAttrType(), SetInputDataBuffer(), and GetInputDataPtr(), which can be used to access memory. Change-Id: I408a95c86bc2477465e5a08dab192bb6f3813ad1 Signed-off-by: Tae-Young Chung <ty83.chung@samsung.com>
Diffstat (limited to 'src/inference_engine_tflite_private.h')
-rw-r--r--src/inference_engine_tflite_private.h58
1 files changed, 10 insertions, 48 deletions
diff --git a/src/inference_engine_tflite_private.h b/src/inference_engine_tflite_private.h
index 9de8c5b..2c388f7 100644
--- a/src/inference_engine_tflite_private.h
+++ b/src/inference_engine_tflite_private.h
@@ -17,15 +17,13 @@
#ifndef __INFERENCE_ENGINE_IMPL_TFLite_H__
#define __INFERENCE_ENGINE_IMPL_TFLite_H__
-#include <inference_engine_vision.h>
+#include <inference_engine_common.h>
#include "tensorflow/contrib/lite/string.h"
#include "tensorflow/contrib/lite/kernels/register.h"
#include "tensorflow/contrib/lite/model.h"
#include "tensorflow/contrib/lite/context.h"
-#include <opencv2/core.hpp>
-#include <opencv2/imgproc.hpp>
#include <memory>
#include <dlog.h>
@@ -41,17 +39,15 @@
#define LOG_TAG "INFERENCE_ENGINE_TFLITE"
-using namespace InferenceEngineInterface::Vision;
using namespace InferenceEngineInterface::Common;
namespace InferenceEngineImpl {
namespace TFLiteImpl {
-class InferenceTFLite : public IInferenceEngineVision {
+class InferenceTFLite : public IInferenceEngineCommon {
public:
InferenceTFLite(std::string protoFile,
- std::string weightFile,
- std::string userFile);
+ std::string weightFile);
~InferenceTFLite();
@@ -60,19 +56,9 @@ public:
int SetInputTensorParamNode(std::string node = "input") override;
- int SetInputTensorParamInput(int width, int height, int dim, int ch) override;
-
- int SetInputTensorParamNorm(double deviation = 1.0, double mean = 0.0) override;
-
// Output Tensor Params
int SetOutputTensorParam() override;
- int SetOutputTensorParamThresHold(double threshold) override;
-
- int SetOutputTensorParamNumbers(int number) override;
-
- int SetOutputTensorParamType(int type) override;
-
int SetOutputTensorParamNodes(std::vector<std::string> nodes) override;
int SetTargetDevice(inference_target_type_e type) override;
@@ -82,29 +68,18 @@ public:
int CreateInputLayerPassage() override;
- int PrepareInputLayerPassage(inference_input_type_e type) override;
-
- int Run(cv::Mat tensor) override;
-
- int Run(std::vector<float> tensor) override;
+ int GetInputLayerAttrType() override;
- int GetInferenceResult(ImageClassificationResults& results);
+ void * GetInputDataPtr() override;
- int GetInferenceResult(ObjectDetectionResults& results);
+ int SetInputDataBuffer(tensor_t data) override;
- int GetInferenceResult(FaceDetectionResults& results);
+ int Run() override;
- int GetInferenceResult(FacialLandMarkDetectionResults& results);
-
- int GetInferenceResult(std::vector<std::vector<int>>& dimInfo, std::vector<float*>& results);
-
- int GetNumberOfOutputs() override;
+ int Run(std::vector<float> tensor) override;
- void SetUserListName(std::string userList) override;
+ int GetInferenceResult(tensor_t& results);
-public:
- int SetUserFile();
- int setInput(cv::Mat cvImg);
private:
std::unique_ptr<tflite::Interpreter> mInterpreter;
@@ -115,26 +90,13 @@ private:
int mInputLayerId;
std::vector<int> mOutputLayerId;
- int mMatType;
+
TfLiteType mInputAttrType;
void *mInputData;
- cv::Mat mInputBuffer;
-
- int mCh;
- int mDim;
- cv::Size mInputSize;
-
- double mDeviation;
- double mMean;
- double mThreshold;
- int mOutputNumbers;
- cv::Size mSourceSize;
std::string mConfigFile;
std::string mWeightFile;
- std::string mUserFile;
- std::vector<std::string> mUserListName;
};
} /* InferenceEngineImpl */