-rw-r--r--  CMakeLists.txt                          |  4
-rw-r--r--  packaging/inference-engine-tflite.spec  |  2
-rw-r--r--  src/inference_engine_tflite.cpp         | 39
-rw-r--r--  src/inference_engine_tflite_private.h   | 13
4 files changed, 40 insertions(+), 18 deletions(-)
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 49e6c16..9b70548 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -18,7 +18,7 @@ FOREACH(flag ${${fw_name}_CFLAGS})
ENDFOREACH(flag)
FOREACH(flag ${${fw_name}_LDFLAGS})
- SET(EXTRA_LDFLAGS "${EXTRA_LDFLAGS} ${flag}")
+ SET(EXTRA_LDFLAGS "${EXTRA_LDFLAGS} ${flag} -lEGL -lGLESv2 -ltensorflowlite -ltensorflowlite_gpu_delegate")
ENDFOREACH(flag)
#Remove leading whitespace POLICY CMP0004
STRING(REGEX REPLACE "^ " "" EXTRA_LDFLAGS ${EXTRA_LDFLAGS})
@@ -26,7 +26,7 @@ STRING(REGEX REPLACE "^ " "" EXTRA_LDFLAGS ${EXTRA_LDFLAGS})
SET(CMAKE_C_FLAGS "-I./include -I./include/headers ${CMAKE_C_FLAGS} ${EXTRA_CFLAGS} -fPIC -Wall -w")
SET(CMAKE_C_FLAGS_DEBUG "-O0 -g")
-SET(CMAKE_CXX_FLAGS "-I./include -I./include/headers ${CMAKE_CXX_FLAGS} ${EXTRA_CXXFLAGS} -fPIC")
+SET(CMAKE_CXX_FLAGS "-I./include -I./include/headers -I/usr/include/tensorflow2/tensorflow ${CMAKE_CXX_FLAGS} ${EXTRA_CXXFLAGS} -fPIC")
SET(CMAKE_CXX_FLAGS_DEBUG "-O0 -g --w")
ADD_DEFINITIONS("-DPREFIX=\"${CMAKE_INSTALL_PREFIX}\"")
diff --git a/packaging/inference-engine-tflite.spec b/packaging/inference-engine-tflite.spec
index ab41e13..086ced9 100644
--- a/packaging/inference-engine-tflite.spec
+++ b/packaging/inference-engine-tflite.spec
@@ -11,7 +11,9 @@ BuildRequires: cmake
BuildRequires: python
BuildRequires: pkgconfig(dlog)
BuildRequires: pkgconfig(inference-engine-interface-common)
+BuildRequires: coregl-devel
BuildRequires: tensorflow-lite-devel
+BuildRequires: tensorflow2-lite-devel
%description
Tensorflow-Lite based implementation of inference-engine-interface
diff --git a/src/inference_engine_tflite.cpp b/src/inference_engine_tflite.cpp
index 78e4f64..4c265fd 100644
--- a/src/inference_engine_tflite.cpp
+++ b/src/inference_engine_tflite.cpp
@@ -52,6 +52,21 @@ namespace TFLiteImpl
{
LOGI("ENTER");
+ switch (types) {
+ case INFERENCE_TARGET_CPU:
+ LOGI("Device type is CPU.");
+ break;
+ case INFERENCE_TARGET_GPU:
+ LOGI("Device type is GPU.");
+ break;
+ case INFERENCE_TARGET_CUSTOM:
+ case INFERENCE_TARGET_NONE:
+ default:
+ LOGW("Not supported device type [%d], Set CPU mode",
+ (int) mTargetTypes);
+ return INFERENCE_ENGINE_ERROR_INVALID_PARAMETER;
+ }
+
mTargetTypes = types;
LOGI("LEAVE");
@@ -89,18 +104,18 @@ namespace TFLiteImpl
LOGI("Inferece targets are: [%d]", mTargetTypes);
- switch (mTargetTypes) {
- case INFERENCE_TARGET_CPU:
- mInterpreter->UseNNAPI(false);
- break;
- case INFERENCE_TARGET_GPU:
- mInterpreter->UseNNAPI(true);
- break;
- case INFERENCE_TARGET_CUSTOM:
- case INFERENCE_TARGET_NONE:
- default:
- LOGW("Not supported device type [%d], Set CPU mode",
- (int) mTargetTypes);
+ if (mTargetTypes == INFERENCE_TARGET_GPU) {
+ TfLiteDelegate *delegate = TfLiteGpuDelegateV2Create(nullptr);
+ if (!delegate) {
+ LOGE("Failed to create GPU delegate");
+ return INFERENCE_ENGINE_ERROR_INVALID_OPERATION;
+ }
+
+ if (mInterpreter->ModifyGraphWithDelegate(delegate) != kTfLiteOk) {
+ LOGE("Failed to modify graph with GPU delegate");
+ return INFERENCE_ENGINE_ERROR_INVALID_OPERATION;
+ }
}
mInterpreter->SetNumThreads(MV_INFERENCE_TFLITE_MAX_THREAD_NUM);
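
For reference, the GPU delegate path added above follows the standard TensorFlow Lite 2.x delegate lifecycle. The sketch below is a minimal, self-contained illustration of that lifecycle, not this module's code: ApplyGpuDelegate is a hypothetical helper, and the includes use the upstream "tensorflow/lite/..." layout rather than the Tizen-packaged "tensorflow2/..." prefix used in this patch.

#include "tensorflow/lite/delegates/gpu/delegate.h"
#include "tensorflow/lite/interpreter.h"

// Creates a GPU delegate and attaches it to an already built interpreter.
// On success the caller owns the returned delegate: it must stay alive for
// as long as the interpreter uses it and is released afterwards with
// TfLiteGpuDelegateV2Delete(). Returns nullptr on failure.
TfLiteDelegate *ApplyGpuDelegate(tflite::Interpreter &interpreter)
{
    // Default options; passing nullptr to the create call, as the hunk
    // above does, requests the same defaults.
    TfLiteGpuDelegateOptionsV2 options = TfLiteGpuDelegateOptionsV2Default();

    TfLiteDelegate *delegate = TfLiteGpuDelegateV2Create(&options);
    if (delegate == nullptr)
        return nullptr;

    // Re-plans the graph so that supported ops run on the GPU.
    if (interpreter.ModifyGraphWithDelegate(delegate) != kTfLiteOk) {
        TfLiteGpuDelegateV2Delete(delegate);
        return nullptr;
    }

    return delegate;
}

Note that ModifyGraphWithDelegate() only offloads the ops the delegate supports; unsupported ops keep running on the CPU.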
diff --git a/src/inference_engine_tflite_private.h b/src/inference_engine_tflite_private.h
index 0c665e0..ce92ee9 100644
--- a/src/inference_engine_tflite_private.h
+++ b/src/inference_engine_tflite_private.h
@@ -19,10 +19,15 @@
#include <inference_engine_common.h>
-#include "tensorflow/contrib/lite/string.h"
-#include "tensorflow/contrib/lite/kernels/register.h"
-#include "tensorflow/contrib/lite/model.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow1/contrib/lite/string.h"
+#include "tensorflow1/contrib/lite/kernels/register.h"
+#include "tensorflow1/contrib/lite/model.h"
+#include "tensorflow1/contrib/lite/context.h"
+
+#include "tensorflow2/lite/delegates/gpu/delegate.h"
+#include "tensorflow2/lite/kernels/register.h"
+#include "tensorflow2/lite/model.h"
+#include "tensorflow2/lite/optional_debug_tools.h"
#include <memory>
#include <dlog.h>
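
For context, the tensorflow2 headers added above cover the standard TensorFlow Lite 2.x interpreter construction path. The following is a minimal sketch of that path, using the upstream "tensorflow/lite/..." include layout rather than the Tizen-packaged "tensorflow2/lite/..." prefix, with a placeholder model path:

#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/optional_debug_tools.h"

#include <memory>

int main()
{
    // Load a .tflite flatbuffer from disk (placeholder path).
    std::unique_ptr<tflite::FlatBufferModel> model =
            tflite::FlatBufferModel::BuildFromFile("model.tflite");
    if (model == nullptr)
        return 1;

    // register.h: maps builtin op codes to their kernel implementations.
    tflite::ops::builtin::BuiltinOpResolver resolver;

    // model.h: builds an interpreter for the loaded flatbuffer.
    std::unique_ptr<tflite::Interpreter> interpreter;
    if (tflite::InterpreterBuilder(*model, resolver)(&interpreter) != kTfLiteOk)
        return 1;

    if (interpreter->AllocateTensors() != kTfLiteOk)
        return 1;

    // optional_debug_tools.h: dumps tensor and node state for debugging.
    tflite::PrintInterpreterState(interpreter.get());
    return 0;
}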