diff options
Diffstat (limited to 'inference-engine/src/extension/ext_detectionoutput.cpp')
-rw-r--r-- | inference-engine/src/extension/ext_detectionoutput.cpp | 124 |
1 files changed, 95 insertions, 29 deletions
diff --git a/inference-engine/src/extension/ext_detectionoutput.cpp b/inference-engine/src/extension/ext_detectionoutput.cpp index c68856c74..e55bad9e8 100644 --- a/inference-engine/src/extension/ext_detectionoutput.cpp +++ b/inference-engine/src/extension/ext_detectionoutput.cpp @@ -12,6 +12,7 @@ #include <string> #include <utility> #include <algorithm> +#include "ie_parallel.hpp" namespace InferenceEngine { namespace Extensions { @@ -28,9 +29,9 @@ public: explicit DetectionOutputImpl(const CNNLayer* layer) { try { if (layer->insData.size() != 3) - THROW_IE_EXCEPTION << "Incorrect number of input edges."; + THROW_IE_EXCEPTION << "Incorrect number of input edges for layer " << layer->name; if (layer->outData.empty()) - THROW_IE_EXCEPTION << "Incorrect number of output edges."; + THROW_IE_EXCEPTION << "Incorrect number of output edges for layer " << layer->name; _num_classes = layer->GetParamAsInt("num_classes"); _background_label_id = layer->GetParamAsInt("background_label_id", 0); @@ -167,29 +168,39 @@ public: for (int n = 0; n < N; ++n) { int detections_total = 0; -#pragma omp parallel for schedule(static) - for (int c = 0; c < _num_classes; ++c) { - if (c == _background_label_id) { - // Ignore background class. - continue; - } + if (!_decrease_label_id) { + // Caffe style + parallel_for(_num_classes, [&](int c) { + if (c != _background_label_id) { // Ignore background class + int *pindices = indices_data + n*_num_classes*_num_priors + c*_num_priors; + int *pbuffer = buffer_data + c*_num_priors; + int *pdetections = detections_data + n*_num_classes + c; + + const float *pconf = reordered_conf_data + n*_num_classes*_num_priors + c*_num_priors; + const float *pboxes; + const float *psizes; + if (_share_location) { + pboxes = decoded_bboxes_data + n*4*_num_priors; + psizes = bbox_sizes_data + n*_num_priors; + } else { + pboxes = decoded_bboxes_data + n*4*_num_classes*_num_priors + c*4*_num_priors; + psizes = bbox_sizes_data + n*_num_classes*_num_priors + c*_num_priors; + } + + nms_cf(pconf, pboxes, psizes, pbuffer, pindices, *pdetections, num_priors_actual[n]); + } + }); + } else { + // MXNet style + int *pindices = indices_data + n*_num_classes*_num_priors; + int *pbuffer = buffer_data; + int *pdetections = detections_data + n*_num_classes; - int *pindices = indices_data + n*_num_classes*_num_priors + c*_num_priors; - int *pbuffer = buffer_data + c*_num_priors; - int *pdetections = detections_data + n*_num_classes + c; - - const float *pconf = reordered_conf_data + n*_num_classes*_num_priors + c*_num_priors; - const float *pboxes; - const float *psizes; - if (_share_location) { - pboxes = decoded_bboxes_data + n*4*_num_priors; - psizes = bbox_sizes_data + n*_num_priors; - } else { - pboxes = decoded_bboxes_data + n*4*_num_classes*_num_priors + c*4*_num_priors; - psizes = bbox_sizes_data + n*_num_classes*_num_priors + c*_num_priors; - } + const float *pconf = reordered_conf_data + n*_num_classes*_num_priors; + const float *pboxes = decoded_bboxes_data + n*4*_num_priors; + const float *psizes = bbox_sizes_data + n*_num_priors; - nms(pconf, pboxes, psizes, pbuffer, pindices, *pdetections, num_priors_actual[n]); + nms_mx(pconf, pboxes, psizes, pbuffer, pindices, pdetections, _num_priors); } for (int c = 0; c < _num_classes; ++c) { @@ -319,8 +330,11 @@ private: void decodeBBoxes(const float *prior_data, const float *loc_data, const float *variance_data, float *decoded_bboxes, float *decoded_bbox_sizes, int* num_priors_actual, int n); - void nms(const float *conf_data, const float *bboxes, const float *sizes, - int *buffer, int *indices, int &detections, int num_priors_actual); + void nms_cf(const float *conf_data, const float *bboxes, const float *sizes, + int *buffer, int *indices, int &detections, int num_priors_actual); + + void nms_mx(const float *conf_data, const float *bboxes, const float *sizes, + int *buffer, int *indices, int *detections, int num_priors_actual); InferenceEngine::Blob::Ptr _decoded_bboxes; InferenceEngine::Blob::Ptr _buffer; @@ -399,8 +413,7 @@ void DetectionOutputImpl::decodeBBoxes(const float *prior_data, } } - #pragma omp parallel for schedule(static) - for (int p = 0; p < num_priors_actual[n]; ++p) { + parallel_for(num_priors_actual[n], [&](int p) { float new_xmin = 0.0f; float new_ymin = 0.0f; float new_xmax = 0.0f; @@ -478,10 +491,10 @@ void DetectionOutputImpl::decodeBBoxes(const float *prior_data, decoded_bboxes[p*4 + 3] = new_ymax; decoded_bbox_sizes[p] = (new_xmax - new_xmin) * (new_ymax - new_ymin); - } + }); } -void DetectionOutputImpl::nms(const float* conf_data, +void DetectionOutputImpl::nms_cf(const float* conf_data, const float* bboxes, const float* sizes, int* buffer, @@ -521,6 +534,59 @@ void DetectionOutputImpl::nms(const float* conf_data, } } +void DetectionOutputImpl::nms_mx(const float* conf_data, + const float* bboxes, + const float* sizes, + int* buffer, + int* indices, + int* detections, + int num_priors_actual) { + int count = 0; + for (int i = 0; i < num_priors_actual; ++i) { + float conf = -1; + int id = 0; + for (int c = 1; c < _num_classes; ++c) { + float temp = conf_data[c*_num_priors + i]; + if (temp > conf) { + conf = temp; + id = c; + } + } + + if (id > 0 && conf >= _confidence_threshold) { + indices[count++] = id*_num_priors + i; + } + } + + int num_output_scores = (_top_k == -1 ? count : std::min<int>(_top_k, count)); + + std::partial_sort_copy(indices, indices + count, + buffer, buffer + num_output_scores, + ConfidenceComparator(conf_data)); + + for (int i = 0; i < num_output_scores; ++i) { + const int idx = buffer[i]; + const int cls = idx/_num_priors; + const int prior = idx%_num_priors; + + int &ndetection = detections[cls]; + int *pindices = indices + cls*_num_priors; + + bool keep = true; + for (int k = 0; k < ndetection; ++k) { + const int kept_idx = pindices[k]; + float overlap = JaccardOverlap(bboxes, sizes, prior, kept_idx); + if (overlap > _nms_threshold) { + keep = false; + break; + } + } + if (keep) { + pindices[ndetection++] = prior; + } + } +} + REG_FACTORY_FOR(ImplFactory<DetectionOutputImpl>, DetectionOutput); } // namespace Cpu |