diff options
-rw-r--r-- | CMakeLists.txt | 4 | ||||
-rw-r--r-- | packaging/inference-engine-tflite.spec | 2 | ||||
-rw-r--r-- | src/inference_engine_tflite.cpp | 39 | ||||
-rw-r--r-- | src/inference_engine_tflite_private.h | 13 |
4 files changed, 40 insertions, 18 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt index 49e6c16..9b70548 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -18,7 +18,7 @@ FOREACH(flag ${${fw_name}_CFLAGS}) ENDFOREACH(flag) FOREACH(flag ${${fw_name}_LDFLAGS}) - SET(EXTRA_LDFLAGS "${EXTRA_LDFLAGS} ${flag}") + SET(EXTRA_LDFLAGS "${EXTRA_LDFLAGS} ${flag} -lEGL -lGLESv2 -ltensorflowlite -ltensorflowlite_gpu_delegate") ENDFOREACH(flag) #Remove leading whitespace POLICY CMP0004 STRING(REGEX REPLACE "^ " "" EXTRA_LDFLAGS ${EXTRA_LDFLAGS}) @@ -26,7 +26,7 @@ STRING(REGEX REPLACE "^ " "" EXTRA_LDFLAGS ${EXTRA_LDFLAGS}) SET(CMAKE_C_FLAGS "-I./include -I./include/headers ${CMAKE_C_FLAGS} ${EXTRA_CFLAGS} -fPIC -Wall -w") SET(CMAKE_C_FLAGS_DEBUG "-O0 -g") -SET(CMAKE_CXX_FLAGS "-I./include -I./include/headers ${CMAKE_CXX_FLAGS} ${EXTRA_CXXFLAGS} -fPIC") +SET(CMAKE_CXX_FLAGS "-I./include -I./include/headers -I/usr/include/tensorflow2/tensorflow ${CMAKE_CXX_FLAGS} ${EXTRA_CXXFLAGS} -fPIC") SET(CMAKE_CXX_FLAGS_DEBUG "-O0 -g --w") ADD_DEFINITIONS("-DPREFIX=\"${CMAKE_INSTALL_PREFIX}\"") diff --git a/packaging/inference-engine-tflite.spec b/packaging/inference-engine-tflite.spec index ab41e13..086ced9 100644 --- a/packaging/inference-engine-tflite.spec +++ b/packaging/inference-engine-tflite.spec @@ -11,7 +11,9 @@ BuildRequires: cmake BuildRequires: python BuildRequires: pkgconfig(dlog) BuildRequires: pkgconfig(inference-engine-interface-common) +BuildRequires: coregl-devel BuildRequires: tensorflow-lite-devel +BuildRequires: tensorflow2-lite-devel %description Tensorflow-Lite based implementation of inference-engine-interface diff --git a/src/inference_engine_tflite.cpp b/src/inference_engine_tflite.cpp index 78e4f64..4c265fd 100644 --- a/src/inference_engine_tflite.cpp +++ b/src/inference_engine_tflite.cpp @@ -52,6 +52,21 @@ namespace TFLiteImpl { LOGI("ENTER"); + switch (types) { + case INFERENCE_TARGET_CPU: + LOGI("Device type is CPU."); + break; + case INFERENCE_TARGET_GPU: + LOGI("Device type is GPU."); + break; + case INFERENCE_TARGET_CUSTOM: + case INFERENCE_TARGET_NONE: + default: + LOGW("Not supported device type [%d], Set CPU mode", + (int) mTargetTypes); + return INFERENCE_ENGINE_ERROR_INVALID_PARAMETER; + } + mTargetTypes = types; LOGI("LEAVE"); @@ -89,18 +104,18 @@ namespace TFLiteImpl LOGI("Inferece targets are: [%d]", mTargetTypes); - switch (mTargetTypes) { - case INFERENCE_TARGET_CPU: - mInterpreter->UseNNAPI(false); - break; - case INFERENCE_TARGET_GPU: - mInterpreter->UseNNAPI(true); - break; - case INFERENCE_TARGET_CUSTOM: - case INFERENCE_TARGET_NONE: - default: - LOGW("Not supported device type [%d], Set CPU mode", - (int) mTargetTypes); + if (mTargetTypes == INFERENCE_TARGET_GPU) { + TfLiteDelegate *delegate = TfLiteGpuDelegateV2Create(nullptr); + if (!delegate){ + LOGE("Failed to GPU delegate"); + return INFERENCE_ENGINE_ERROR_INVALID_OPERATION; + } + + if (mInterpreter->ModifyGraphWithDelegate(delegate) != kTfLiteOk) + { + LOGE("Failed to construct GPU delegate"); + return INFERENCE_ENGINE_ERROR_INVALID_OPERATION; + } } mInterpreter->SetNumThreads(MV_INFERENCE_TFLITE_MAX_THREAD_NUM); diff --git a/src/inference_engine_tflite_private.h b/src/inference_engine_tflite_private.h index 0c665e0..ce92ee9 100644 --- a/src/inference_engine_tflite_private.h +++ b/src/inference_engine_tflite_private.h @@ -19,10 +19,15 @@ #include <inference_engine_common.h> -#include "tensorflow/contrib/lite/string.h" -#include "tensorflow/contrib/lite/kernels/register.h" -#include "tensorflow/contrib/lite/model.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow1/contrib/lite/string.h" +#include "tensorflow1/contrib/lite/kernels/register.h" +#include "tensorflow1/contrib/lite/model.h" +#include "tensorflow1/contrib/lite/context.h" + +#include "tensorflow2/lite/delegates/gpu/delegate.h" +#include "tensorflow2/lite/kernels/register.h" +#include "tensorflow2/lite/model.h" +#include "tensorflow2/lite/optional_debug_tools.h" #include <memory> #include <dlog.h> |