summaryrefslogtreecommitdiff
path: root/mv_machine_learning/object_detection/src/ObjectDetection.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mv_machine_learning/object_detection/src/ObjectDetection.cpp')
-rw-r--r--mv_machine_learning/object_detection/src/ObjectDetection.cpp342
1 files changed, 342 insertions, 0 deletions
diff --git a/mv_machine_learning/object_detection/src/ObjectDetection.cpp b/mv_machine_learning/object_detection/src/ObjectDetection.cpp
new file mode 100644
index 00000000..ea31f322
--- /dev/null
+++ b/mv_machine_learning/object_detection/src/ObjectDetection.cpp
@@ -0,0 +1,342 @@
+/**
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <string.h>
+#include <fstream>
+#include <map>
+#include <memory>
+
+#include "MvMlException.h"
+#include "common.h"
+#include "mv_object_detection_config.h"
+#include "ObjectDetection.h"
+
+using namespace std;
+using namespace mediavision::inference;
+using namespace MediaVision::Common;
+using namespace mediavision::common;
+using namespace mediavision::machine_learning::exception;
+
+namespace mediavision
+{
+namespace machine_learning
+{
+template<typename T>
+ObjectDetection<T>::ObjectDetection(ObjectDetectionTaskType task_type, shared_ptr<Config> config)
+ : _task_type(task_type), _config(config)
+{
+ _inference = make_unique<Inference>();
+}
+
+template<typename T> void ObjectDetection<T>::preDestroy()
+{
+ if (!_async_manager)
+ return;
+
+ _async_manager->stop();
+}
+
+template<typename T> ObjectDetectionTaskType ObjectDetection<T>::getTaskType()
+{
+ return _task_type;
+}
+
+template<typename T> void ObjectDetection<T>::getEngineList()
+{
+ for (auto idx = MV_INFERENCE_BACKEND_NONE + 1; idx < MV_INFERENCE_BACKEND_MAX; ++idx) {
+ auto backend = _inference->getSupportedInferenceBackend(idx);
+ // TODO. we need to describe what inference engines are supported by each Task API,
+ // and based on it, below inference engine types should be checked
+ // if a given type is supported by this Task API later. As of now, tflite only.
+ if (backend.second == true && backend.first.compare("tflite") == 0)
+ _valid_backends.push_back(backend.first);
+ }
+}
+
+template<typename T> void ObjectDetection<T>::getDeviceList(const string &engine_type)
+{
+ // TODO. add device types available for a given engine type later.
+ // In default, cpu and gpu only.
+ _valid_devices.push_back("cpu");
+ _valid_devices.push_back("gpu");
+}
+
+template<typename T> void ObjectDetection<T>::setEngineInfo(std::string engine_type_name, std::string device_type_name)
+{
+ if (engine_type_name.empty() || device_type_name.empty())
+ throw InvalidParameter("Invalid engine info.");
+
+ transform(engine_type_name.begin(), engine_type_name.end(), engine_type_name.begin(), ::toupper);
+ transform(device_type_name.begin(), device_type_name.end(), device_type_name.begin(), ::toupper);
+
+ int engine_type = GetBackendType(engine_type_name);
+ int device_type = GetDeviceType(device_type_name);
+
+ if (engine_type == MEDIA_VISION_ERROR_INVALID_PARAMETER || device_type == MEDIA_VISION_ERROR_INVALID_PARAMETER)
+ throw InvalidParameter("backend or target device type not found.");
+
+ _config->setBackendType(engine_type);
+ _config->setTargetDeviceType(device_type);
+
+ LOGI("Engine type : %s => %d, Device type : %s => %d", engine_type_name.c_str(), engine_type,
+ device_type_name.c_str(), device_type);
+}
+
+template<typename T> unsigned int ObjectDetection<T>::getNumberOfEngines()
+{
+ if (!_valid_backends.empty()) {
+ return _valid_backends.size();
+ }
+
+ getEngineList();
+ return _valid_backends.size();
+}
+
+template<typename T> const string &ObjectDetection<T>::getEngineType(unsigned int engine_index)
+{
+ if (!_valid_backends.empty()) {
+ if (_valid_backends.size() <= engine_index)
+ throw InvalidParameter("Invalid engine index.");
+
+ return _valid_backends[engine_index];
+ }
+
+ getEngineList();
+
+ if (_valid_backends.size() <= engine_index)
+ throw InvalidParameter("Invalid engine index.");
+
+ return _valid_backends[engine_index];
+}
+
+template<typename T> unsigned int ObjectDetection<T>::getNumberOfDevices(const string &engine_type)
+{
+ if (!_valid_devices.empty()) {
+ return _valid_devices.size();
+ }
+
+ getDeviceList(engine_type);
+ return _valid_devices.size();
+}
+
+template<typename T>
+const string &ObjectDetection<T>::getDeviceType(const string &engine_type, unsigned int device_index)
+{
+ if (!_valid_devices.empty()) {
+ if (_valid_devices.size() <= device_index)
+ throw InvalidParameter("Invalid device index.");
+
+ return _valid_devices[device_index];
+ }
+
+ getDeviceList(engine_type);
+
+ if (_valid_devices.size() <= device_index)
+ throw InvalidParameter("Invalid device index.");
+
+ return _valid_devices[device_index];
+}
+
+template<typename T> void ObjectDetection<T>::loadLabel()
+{
+ if (_config->getLabelFilePath().empty())
+ return;
+
+ ifstream readFile;
+
+ _labels.clear();
+ readFile.open(_config->getLabelFilePath().c_str());
+
+ if (readFile.fail())
+ throw InvalidOperation("Fail to open " + _config->getLabelFilePath() + " file.");
+
+ string line;
+
+ while (getline(readFile, line))
+ _labels.push_back(line);
+
+ readFile.close();
+}
+
+template<typename T> void ObjectDetection<T>::configure()
+{
+ loadLabel();
+
+ int ret = _inference->bind(_config->getBackendType(), _config->getTargetDeviceType());
+ if (ret != MEDIA_VISION_ERROR_NONE)
+ throw InvalidOperation("Fail to bind a backend engine.");
+}
+
+template<typename T> void ObjectDetection<T>::prepare()
+{
+ int ret = _inference->configureInputMetaInfo(_config->getInputMetaMap());
+ if (ret != MEDIA_VISION_ERROR_NONE)
+ throw InvalidOperation("Fail to configure input tensor info from meta file.");
+
+ ret = _inference->configureOutputMetaInfo(_config->getOutputMetaMap());
+ if (ret != MEDIA_VISION_ERROR_NONE)
+ throw InvalidOperation("Fail to configure output tensor info from meta file.");
+
+ _inference->configureModelFiles("", _config->getModelFilePath(), "");
+
+ // Request to load model files to a backend engine.
+ ret = _inference->load();
+ if (ret != MEDIA_VISION_ERROR_NONE)
+ throw InvalidOperation("Fail to load model files.");
+}
+
+template<typename T> shared_ptr<MetaInfo> ObjectDetection<T>::getInputMetaInfo()
+{
+ TensorBuffer &tensor_buffer = _inference->getInputTensorBuffer();
+ IETensorBuffer &tensor_info_map = tensor_buffer.getIETensorBuffer();
+
+ // TODO. consider using multiple tensors later.
+ if (tensor_info_map.size() != 1)
+ throw InvalidOperation("Input tensor count not invalid.");
+
+ auto tensor_buffer_iter = tensor_info_map.begin();
+
+ // Get the meta information corresponding to a given input tensor name.
+ return _config->getInputMetaMap()[tensor_buffer_iter->first];
+}
+
+template<typename T>
+void ObjectDetection<T>::preprocess(mv_source_h &mv_src, shared_ptr<MetaInfo> metaInfo, vector<T> &inputVector)
+{
+ LOGI("ENTER");
+
+ PreprocessConfig config = { false,
+ metaInfo->colorSpace,
+ metaInfo->dataType,
+ metaInfo->getChannel(),
+ metaInfo->getWidth(),
+ metaInfo->getHeight() };
+
+ auto normalization = static_pointer_cast<DecodingNormal>(metaInfo->decodingTypeMap.at(DecodingType::NORMAL));
+ if (normalization) {
+ config.normalize = normalization->use;
+ config.mean = normalization->mean;
+ config.std = normalization->std;
+ }
+
+ auto quantization =
+ static_pointer_cast<DecodingQuantization>(metaInfo->decodingTypeMap.at(DecodingType::QUANTIZATION));
+ if (quantization) {
+ config.quantize = quantization->use;
+ config.scale = quantization->scale;
+ config.zeropoint = quantization->zeropoint;
+ }
+
+ _preprocess.setConfig(config);
+ _preprocess.run<T>(mv_src, inputVector);
+
+ LOGI("LEAVE");
+}
+
+template<typename T> void ObjectDetection<T>::inference(vector<vector<T> > &inputVectors)
+{
+ LOGI("ENTER");
+
+ int ret = _inference->run<T>(inputVectors);
+ if (ret != MEDIA_VISION_ERROR_NONE)
+ throw InvalidOperation("Fail to run inference");
+
+ LOGI("LEAVE");
+}
+
+template<typename T> void ObjectDetection<T>::perform(mv_source_h &mv_src)
+{
+ shared_ptr<MetaInfo> metaInfo = getInputMetaInfo();
+ vector<T> inputVector;
+
+ preprocess(mv_src, metaInfo, inputVector);
+
+ vector<vector<T> > inputVectors = { inputVector };
+ inference(inputVectors);
+}
+
+template<typename T> void ObjectDetection<T>::performAsync(ObjectDetectionInput &input)
+{
+ if (!_async_manager) {
+ _async_manager = make_unique<AsyncManager<T, ObjectDetectionResult> >([this]() {
+ AsyncInputQueue<T> inputQueue = _async_manager->popFromInput();
+
+ inference(inputQueue.inputs);
+
+ ObjectDetectionResult &resultQueue = result();
+
+ resultQueue.frame_number = inputQueue.frame_number;
+ _async_manager->pushToOutput(resultQueue);
+ });
+ }
+
+ shared_ptr<MetaInfo> metaInfo = getInputMetaInfo();
+ vector<T> inputVector;
+
+ preprocess(input.inference_src, metaInfo, inputVector);
+
+ vector<vector<T> > inputVectors = { inputVector };
+ _async_manager->push(inputVectors);
+}
+
+template<typename T> ObjectDetectionResult &ObjectDetection<T>::getOutput()
+{
+ if (_async_manager) {
+ if (!_async_manager->isWorking())
+ throw InvalidOperation("Object detection has been already destroyed so invalid operation.");
+
+ _current_result = _async_manager->pop();
+ } else {
+ // TODO. Check if inference request is completed or not here.
+ // If not then throw an exception.
+ _current_result = result();
+ }
+
+ return _current_result;
+}
+
+template<typename T> ObjectDetectionResult &ObjectDetection<T>::getOutputCache()
+{
+ return _current_result;
+}
+
+template<typename T> void ObjectDetection<T>::getOutputNames(vector<string> &names)
+{
+ TensorBuffer &tensor_buffer_obj = _inference->getOutputTensorBuffer();
+ IETensorBuffer &ie_tensor_buffer = tensor_buffer_obj.getIETensorBuffer();
+
+ for (IETensorBuffer::iterator it = ie_tensor_buffer.begin(); it != ie_tensor_buffer.end(); it++)
+ names.push_back(it->first);
+}
+
+template<typename T> void ObjectDetection<T>::getOutputTensor(string target_name, vector<float> &tensor)
+{
+ TensorBuffer &tensor_buffer_obj = _inference->getOutputTensorBuffer();
+
+ inference_engine_tensor_buffer *tensor_buffer = tensor_buffer_obj.getTensorBuffer(target_name);
+ if (!tensor_buffer)
+ throw InvalidOperation("Fail to get tensor buffer.");
+
+ auto raw_buffer = static_cast<float *>(tensor_buffer->buffer);
+
+ copy(&raw_buffer[0], &raw_buffer[tensor_buffer->size / sizeof(float)], back_inserter(tensor));
+}
+
+template class ObjectDetection<float>;
+template class ObjectDetection<unsigned char>;
+
+}
+}