diff options
Diffstat (limited to 'mv_machine_learning/object_detection/src/ObjectDetection.cpp')
-rw-r--r-- | mv_machine_learning/object_detection/src/ObjectDetection.cpp | 342 |
1 file changed, 342 insertions, 0 deletions
diff --git a/mv_machine_learning/object_detection/src/ObjectDetection.cpp b/mv_machine_learning/object_detection/src/ObjectDetection.cpp new file mode 100644 index 00000000..ea31f322 --- /dev/null +++ b/mv_machine_learning/object_detection/src/ObjectDetection.cpp @@ -0,0 +1,342 @@ +/** + * Copyright (c) 2022 Samsung Electronics Co., Ltd All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <string.h> +#include <fstream> +#include <map> +#include <memory> + +#include "MvMlException.h" +#include "common.h" +#include "mv_object_detection_config.h" +#include "ObjectDetection.h" + +using namespace std; +using namespace mediavision::inference; +using namespace MediaVision::Common; +using namespace mediavision::common; +using namespace mediavision::machine_learning::exception; + +namespace mediavision +{ +namespace machine_learning +{ +template<typename T> +ObjectDetection<T>::ObjectDetection(ObjectDetectionTaskType task_type, shared_ptr<Config> config) + : _task_type(task_type), _config(config) +{ + _inference = make_unique<Inference>(); +} + +template<typename T> void ObjectDetection<T>::preDestroy() +{ + if (!_async_manager) + return; + + _async_manager->stop(); +} + +template<typename T> ObjectDetectionTaskType ObjectDetection<T>::getTaskType() +{ + return _task_type; +} + +template<typename T> void ObjectDetection<T>::getEngineList() +{ + for (auto idx = MV_INFERENCE_BACKEND_NONE + 1; idx < MV_INFERENCE_BACKEND_MAX; ++idx) { + 
auto backend = _inference->getSupportedInferenceBackend(idx); + // TODO. we need to describe what inference engines are supported by each Task API, + // and based on it, below inference engine types should be checked + // if a given type is supported by this Task API later. As of now, tflite only. + if (backend.second == true && backend.first.compare("tflite") == 0) + _valid_backends.push_back(backend.first); + } +} + +template<typename T> void ObjectDetection<T>::getDeviceList(const string &engine_type) +{ + // TODO. add device types available for a given engine type later. + // In default, cpu and gpu only. + _valid_devices.push_back("cpu"); + _valid_devices.push_back("gpu"); +} + +template<typename T> void ObjectDetection<T>::setEngineInfo(std::string engine_type_name, std::string device_type_name) +{ + if (engine_type_name.empty() || device_type_name.empty()) + throw InvalidParameter("Invalid engine info."); + + transform(engine_type_name.begin(), engine_type_name.end(), engine_type_name.begin(), ::toupper); + transform(device_type_name.begin(), device_type_name.end(), device_type_name.begin(), ::toupper); + + int engine_type = GetBackendType(engine_type_name); + int device_type = GetDeviceType(device_type_name); + + if (engine_type == MEDIA_VISION_ERROR_INVALID_PARAMETER || device_type == MEDIA_VISION_ERROR_INVALID_PARAMETER) + throw InvalidParameter("backend or target device type not found."); + + _config->setBackendType(engine_type); + _config->setTargetDeviceType(device_type); + + LOGI("Engine type : %s => %d, Device type : %s => %d", engine_type_name.c_str(), engine_type, + device_type_name.c_str(), device_type); +} + +template<typename T> unsigned int ObjectDetection<T>::getNumberOfEngines() +{ + if (!_valid_backends.empty()) { + return _valid_backends.size(); + } + + getEngineList(); + return _valid_backends.size(); +} + +template<typename T> const string &ObjectDetection<T>::getEngineType(unsigned int engine_index) +{ + if (!_valid_backends.empty()) { + 
if (_valid_backends.size() <= engine_index) + throw InvalidParameter("Invalid engine index."); + + return _valid_backends[engine_index]; + } + + getEngineList(); + + if (_valid_backends.size() <= engine_index) + throw InvalidParameter("Invalid engine index."); + + return _valid_backends[engine_index]; +} + +template<typename T> unsigned int ObjectDetection<T>::getNumberOfDevices(const string &engine_type) +{ + if (!_valid_devices.empty()) { + return _valid_devices.size(); + } + + getDeviceList(engine_type); + return _valid_devices.size(); +} + +template<typename T> +const string &ObjectDetection<T>::getDeviceType(const string &engine_type, unsigned int device_index) +{ + if (!_valid_devices.empty()) { + if (_valid_devices.size() <= device_index) + throw InvalidParameter("Invalid device index."); + + return _valid_devices[device_index]; + } + + getDeviceList(engine_type); + + if (_valid_devices.size() <= device_index) + throw InvalidParameter("Invalid device index."); + + return _valid_devices[device_index]; +} + +template<typename T> void ObjectDetection<T>::loadLabel() +{ + if (_config->getLabelFilePath().empty()) + return; + + ifstream readFile; + + _labels.clear(); + readFile.open(_config->getLabelFilePath().c_str()); + + if (readFile.fail()) + throw InvalidOperation("Fail to open " + _config->getLabelFilePath() + " file."); + + string line; + + while (getline(readFile, line)) + _labels.push_back(line); + + readFile.close(); +} + +template<typename T> void ObjectDetection<T>::configure() +{ + loadLabel(); + + int ret = _inference->bind(_config->getBackendType(), _config->getTargetDeviceType()); + if (ret != MEDIA_VISION_ERROR_NONE) + throw InvalidOperation("Fail to bind a backend engine."); +} + +template<typename T> void ObjectDetection<T>::prepare() +{ + int ret = _inference->configureInputMetaInfo(_config->getInputMetaMap()); + if (ret != MEDIA_VISION_ERROR_NONE) + throw InvalidOperation("Fail to configure input tensor info from meta file."); + + ret = 
_inference->configureOutputMetaInfo(_config->getOutputMetaMap()); + if (ret != MEDIA_VISION_ERROR_NONE) + throw InvalidOperation("Fail to configure output tensor info from meta file."); + + _inference->configureModelFiles("", _config->getModelFilePath(), ""); + + // Request to load model files to a backend engine. + ret = _inference->load(); + if (ret != MEDIA_VISION_ERROR_NONE) + throw InvalidOperation("Fail to load model files."); +} + +template<typename T> shared_ptr<MetaInfo> ObjectDetection<T>::getInputMetaInfo() +{ + TensorBuffer &tensor_buffer = _inference->getInputTensorBuffer(); + IETensorBuffer &tensor_info_map = tensor_buffer.getIETensorBuffer(); + + // TODO. consider using multiple tensors later. + if (tensor_info_map.size() != 1) + throw InvalidOperation("Input tensor count not invalid."); + + auto tensor_buffer_iter = tensor_info_map.begin(); + + // Get the meta information corresponding to a given input tensor name. + return _config->getInputMetaMap()[tensor_buffer_iter->first]; +} + +template<typename T> +void ObjectDetection<T>::preprocess(mv_source_h &mv_src, shared_ptr<MetaInfo> metaInfo, vector<T> &inputVector) +{ + LOGI("ENTER"); + + PreprocessConfig config = { false, + metaInfo->colorSpace, + metaInfo->dataType, + metaInfo->getChannel(), + metaInfo->getWidth(), + metaInfo->getHeight() }; + + auto normalization = static_pointer_cast<DecodingNormal>(metaInfo->decodingTypeMap.at(DecodingType::NORMAL)); + if (normalization) { + config.normalize = normalization->use; + config.mean = normalization->mean; + config.std = normalization->std; + } + + auto quantization = + static_pointer_cast<DecodingQuantization>(metaInfo->decodingTypeMap.at(DecodingType::QUANTIZATION)); + if (quantization) { + config.quantize = quantization->use; + config.scale = quantization->scale; + config.zeropoint = quantization->zeropoint; + } + + _preprocess.setConfig(config); + _preprocess.run<T>(mv_src, inputVector); + + LOGI("LEAVE"); +} + +template<typename T> void 
ObjectDetection<T>::inference(vector<vector<T> > &inputVectors) +{ + LOGI("ENTER"); + + int ret = _inference->run<T>(inputVectors); + if (ret != MEDIA_VISION_ERROR_NONE) + throw InvalidOperation("Fail to run inference"); + + LOGI("LEAVE"); +} + +template<typename T> void ObjectDetection<T>::perform(mv_source_h &mv_src) +{ + shared_ptr<MetaInfo> metaInfo = getInputMetaInfo(); + vector<T> inputVector; + + preprocess(mv_src, metaInfo, inputVector); + + vector<vector<T> > inputVectors = { inputVector }; + inference(inputVectors); +} + +template<typename T> void ObjectDetection<T>::performAsync(ObjectDetectionInput &input) +{ + if (!_async_manager) { + _async_manager = make_unique<AsyncManager<T, ObjectDetectionResult> >([this]() { + AsyncInputQueue<T> inputQueue = _async_manager->popFromInput(); + + inference(inputQueue.inputs); + + ObjectDetectionResult &resultQueue = result(); + + resultQueue.frame_number = inputQueue.frame_number; + _async_manager->pushToOutput(resultQueue); + }); + } + + shared_ptr<MetaInfo> metaInfo = getInputMetaInfo(); + vector<T> inputVector; + + preprocess(input.inference_src, metaInfo, inputVector); + + vector<vector<T> > inputVectors = { inputVector }; + _async_manager->push(inputVectors); +} + +template<typename T> ObjectDetectionResult &ObjectDetection<T>::getOutput() +{ + if (_async_manager) { + if (!_async_manager->isWorking()) + throw InvalidOperation("Object detection has been already destroyed so invalid operation."); + + _current_result = _async_manager->pop(); + } else { + // TODO. Check if inference request is completed or not here. + // If not then throw an exception. 
+ _current_result = result(); + } + + return _current_result; +} + +template<typename T> ObjectDetectionResult &ObjectDetection<T>::getOutputCache() +{ + return _current_result; +} + +template<typename T> void ObjectDetection<T>::getOutputNames(vector<string> &names) +{ + TensorBuffer &tensor_buffer_obj = _inference->getOutputTensorBuffer(); + IETensorBuffer &ie_tensor_buffer = tensor_buffer_obj.getIETensorBuffer(); + + for (IETensorBuffer::iterator it = ie_tensor_buffer.begin(); it != ie_tensor_buffer.end(); it++) + names.push_back(it->first); +} + +template<typename T> void ObjectDetection<T>::getOutputTensor(string target_name, vector<float> &tensor) +{ + TensorBuffer &tensor_buffer_obj = _inference->getOutputTensorBuffer(); + + inference_engine_tensor_buffer *tensor_buffer = tensor_buffer_obj.getTensorBuffer(target_name); + if (!tensor_buffer) + throw InvalidOperation("Fail to get tensor buffer."); + + auto raw_buffer = static_cast<float *>(tensor_buffer->buffer); + + copy(&raw_buffer[0], &raw_buffer[tensor_buffer->size / sizeof(float)], back_inserter(tensor)); +} + +template class ObjectDetection<float>; +template class ObjectDetection<unsigned char>; + +} +} |