diff options
author | Inki Dae <inki.dae@samsung.com> | 2020-05-12 07:14:23 +0000 |
---|---|---|
committer | Gerrit Code Review <gerrit@review> | 2020-05-12 07:14:23 +0000 |
commit | 7dd18f4892581efee237005654d3c90e7a8c0e47 (patch) | |
tree | ed25dc20c0dceaf6151a64561dc882304071db32 | |
parent | 9e9aaef2b324feb977846843a54345dca4ce42be (diff) | |
parent | ce5664914b7d537c505132d6282087ed3e67763a (diff) | |
download | inference-engine-tflite-7dd18f4892581efee237005654d3c90e7a8c0e47.tar.gz inference-engine-tflite-7dd18f4892581efee237005654d3c90e7a8c0e47.tar.bz2 inference-engine-tflite-7dd18f4892581efee237005654d3c90e7a8c0e47.zip |
Merge "Get tensor infos from interpreter" into tizensubmit/tizen/20200602.011936
-rw-r--r-- | packaging/inference-engine-tflite.spec | 2 | ||||
-rw-r--r-- | src/inference_engine_tflite.cpp | 59 | ||||
-rw-r--r-- | src/inference_engine_tflite_private.h | 2 |
3 files changed, 56 insertions, 7 deletions
diff --git a/packaging/inference-engine-tflite.spec b/packaging/inference-engine-tflite.spec index 9f2a99a..d9f9b72 100644 --- a/packaging/inference-engine-tflite.spec +++ b/packaging/inference-engine-tflite.spec @@ -1,7 +1,7 @@ Name: inference-engine-tflite Summary: Tensorflow-Lite based implementation of inference-engine-interface Version: 0.0.1 -Release: 10 +Release: 11 Group: Multimedia/Libraries License: Apache-2.0 Source0: %{name}-%{version}.tar.gz diff --git a/src/inference_engine_tflite.cpp b/src/inference_engine_tflite.cpp index bd44e06..11a7380 100644 --- a/src/inference_engine_tflite.cpp +++ b/src/inference_engine_tflite.cpp @@ -95,6 +95,7 @@ int InferenceTFLite::Load(std::vector<std::string> model_paths, inference_model_ } mInterpreter->SetNumThreads(MV_INFERENCE_TFLITE_MAX_THREAD_NUM); + LOGI("mInterpreter->tensors_size() :[%d]",mInterpreter->tensors_size()); // input tensor if (mInterpreter->inputs().size()) { @@ -151,8 +152,7 @@ int InferenceTFLite::GetInputTensorBuffers(std::vector<inference_engine_tensor_b LOGI("ENTER"); if (mInputTensorInfo.empty()) { - LOGE("InputTensorInfo is empty. Do SetInputLayerProperty first."); - return INFERENCE_ENGINE_ERROR_INVALID_OPERATION; + SetInterpreterInfo(); } mInputData.clear(); @@ -221,14 +221,12 @@ int InferenceTFLite::GetInputLayerProperty(inference_engine_layer_property &prop { LOGI("ENTER"); - if (mInputLayer.empty()) { - return INFERENCE_ENGINE_ERROR_INVALID_OPERATION; - } - + SetInterpreterInfo(); property.layer_names = mInputLayer; property.tensor_infos = mInputTensorInfo; LOGI("LEAVE"); + return INFERENCE_ENGINE_ERROR_NONE; } @@ -317,6 +315,7 @@ int InferenceTFLite::SetInputLayerProperty(inference_engine_layer_property &prop int InferenceTFLite::SetOutputLayerProperty(inference_engine_layer_property &property) { LOGI("ENTER"); + std::vector<std::string>::iterator iter; for (iter = property.layer_names.begin(); iter != property.layer_names.end(); iter++) { std::string name = *iter; @@ -364,6 +363,54 @@ int InferenceTFLite::Run(std::vector<inference_engine_tensor_buffer> &input_buff return INFERENCE_ENGINE_ERROR_NONE; } +int InferenceTFLite::SetInterpreterInfo() +{ + if (mInputLayer.empty() || mInputTensorInfo.empty()) { + LOGI("mInputLayer is empty. layers and tensors that mInterpreter has will be returned."); + + mInputLayer.clear(); + std::vector<std::string>().swap(mInputLayer); + + mInputTensorInfo.clear(); + std::vector<inference_engine_tensor_info>().swap(mInputTensorInfo); + + for (auto iter = mInputLayerId.begin(); iter != mInputLayerId.end(); ++iter) { + mInputLayer.push_back(mInterpreter->tensor((*iter))->name); + + std::vector<size_t> shape_nhwc; + + for (int idx = 0; idx <mInterpreter->tensor((*iter))->dims->size; idx++) { + shape_nhwc.push_back(mInterpreter->tensor((*iter))->dims->data[idx]); + } + + inference_engine_tensor_info tensor_info { + shape_nhwc, INFERENCE_TENSOR_SHAPE_NHWC, INFERENCE_TENSOR_DATA_TYPE_NONE, 1 + }; + + if (mInterpreter->tensor((*iter))->type == kTfLiteUInt8) { + LOGI("type is kTfLiteUInt8"); + tensor_info.data_type = INFERENCE_TENSOR_DATA_TYPE_UINT8; + } + else if (mInterpreter->tensor((*iter))->type == kTfLiteFloat32) { + LOGI("type is kTfLiteFloat32"); + tensor_info.data_type = INFERENCE_TENSOR_DATA_TYPE_FLOAT32; + } + else { + LOGE("Not supported"); + return INFERENCE_ENGINE_ERROR_NOT_SUPPORTED_FORMAT; + } + + for (auto iter2 : tensor_info.shape) + { + tensor_info.size *= iter2; + } + mInputTensorInfo.push_back(tensor_info); + } + } + + return INFERENCE_ENGINE_ERROR_NONE; +} + extern "C" { class IInferenceEngineCommon* EngineCommonInit(void) diff --git a/src/inference_engine_tflite_private.h b/src/inference_engine_tflite_private.h index 1ab36c1..bfe2d3b 100644 --- a/src/inference_engine_tflite_private.h +++ b/src/inference_engine_tflite_private.h @@ -71,6 +71,8 @@ public: std::vector<inference_engine_tensor_buffer> &output_buffers) override; private: + int SetInterpreterInfo(); + std::unique_ptr<tflite::Interpreter> mInterpreter; std::unique_ptr<tflite::FlatBufferModel> mFlatBuffModel; std::vector<void *> mInputData; |