/**
 * Copyright (c) 2019 Samsung Electronics Co., Ltd All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef __INFERENCE_ENGINE_IMPL_TFLite_H__
#define __INFERENCE_ENGINE_IMPL_TFLite_H__

#include <inference_engine_common.h>

#include "tensorflow/contrib/lite/string.h"
#include "tensorflow/contrib/lite/kernels/register.h"
#include "tensorflow/contrib/lite/model.h"
#include "tensorflow/contrib/lite/context.h"

#include <memory>
#include <string>
#include <vector>

/**
 * @file inference_engine_tflite_private.h
 * @brief This file contains the InferenceTFLite class, which
 *        provides TensorFlow Lite based inference functionality.
 */

#ifdef LOG_TAG
#undef LOG_TAG
#endif

#define LOG_TAG "INFERENCE_ENGINE_TFLITE"

using namespace InferenceEngineInterface::Common;

namespace InferenceEngineImpl
{
namespace TFLiteImpl
{
	class InferenceTFLite : public IInferenceEngineCommon
	{
	public:
		InferenceTFLite();
		~InferenceTFLite();

		int SetPrivateData(void *data) override;

		int SetTargetDevices(int types) override;

		int Load(std::vector<std::string> model_paths,
				 inference_model_format_e model_format) override;

		int GetInputTensorBuffers(
				std::vector<inference_engine_tensor_buffer> &buffers) override;

		int GetOutputTensorBuffers(
				std::vector<inference_engine_tensor_buffer> &buffers) override;

		int GetInputLayerProperty(
				inference_engine_layer_property &property) override;

		int GetOutputLayerProperty(
				inference_engine_layer_property &property) override;

		int SetInputLayerProperty(
				inference_engine_layer_property &property) override;

		int SetOutputLayerProperty(
				inference_engine_layer_property &property) override;

		int GetBackendCapacity(inference_engine_capacity *capacity) override;

		int Run(std::vector<inference_engine_tensor_buffer> &input_buffers,
				std::vector<inference_engine_tensor_buffer> &output_buffers) override;

	private:
		int SetInterpreterInfo();

		std::unique_ptr<tflite::Interpreter> mInterpreter; /**< TFLite interpreter instance */
		std::unique_ptr<tflite::FlatBufferModel> mFlatBuffModel; /**< Loaded flatbuffer model */

		std::vector<void *> mInputData; /**< Input tensor data buffers */

		std::vector<std::string> mInputLayer; /**< Input layer names */
		std::vector<std::string> mOutputLayer; /**< Output layer names */

		std::vector<inference_engine_tensor_info> mInputTensorInfo;
		std::vector<inference_engine_tensor_info> mOutputTensorInfo;

		std::vector<int> mInputLayerId;
		std::vector<int> mOutputLayerId;

		std::vector<TfLiteType> mInputAttrType;
		std::vector<TfLiteType> mOutputAttrType;

		std::string mConfigFile;
		std::string mWeightFile;

		int mTargetTypes;
	};

} /* TFLiteImpl */
} /* InferenceEngineImpl */

#endif /* __INFERENCE_ENGINE_IMPL_TFLite_H__ */
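
/*
 * Usage sketch (illustration only, kept in a comment so it is not compiled with
 * the header): a minimal example of how a caller might drive this backend. It
 * assumes the enum values INFERENCE_TARGET_CPU and INFERENCE_MODEL_TFLITE and
 * the tensor-buffer type declared by the common inference-engine headers; the
 * model path and variable names are hypothetical.
 *
 *     using namespace InferenceEngineImpl::TFLiteImpl;
 *
 *     InferenceTFLite engine;
 *     engine.SetTargetDevices(INFERENCE_TARGET_CPU);   // assumed target-device value
 *
 *     std::vector<std::string> models = { "/usr/share/model.tflite" };
 *     engine.Load(models, INFERENCE_MODEL_TFLITE);     // assumed model-format value
 *
 *     std::vector<inference_engine_tensor_buffer> inputs, outputs;
 *     engine.GetInputTensorBuffers(inputs);
 *     engine.GetOutputTensorBuffers(outputs);
 *
 *     // ... fill each input buffer with preprocessed tensor data ...
 *
 *     engine.Run(inputs, outputs);
 *
 *     // ... read inference results back from the output buffers ...
 */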