summary | refs | log | tree | commit | diff
path: root/inference-engine/src/extension/ext_detectionoutput.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'inference-engine/src/extension/ext_detectionoutput.cpp')
-rw-r--r-- inference-engine/src/extension/ext_detectionoutput.cpp 124
1 files changed, 95 insertions, 29 deletions
diff --git a/inference-engine/src/extension/ext_detectionoutput.cpp b/inference-engine/src/extension/ext_detectionoutput.cpp
index c68856c74..e55bad9e8 100644
--- a/inference-engine/src/extension/ext_detectionoutput.cpp
+++ b/inference-engine/src/extension/ext_detectionoutput.cpp
@@ -12,6 +12,7 @@
#include <string>
#include <utility>
#include <algorithm>
+#include "ie_parallel.hpp"
namespace InferenceEngine {
namespace Extensions {
@@ -28,9 +29,9 @@ public:
explicit DetectionOutputImpl(const CNNLayer* layer) {
try {
if (layer->insData.size() != 3)
- THROW_IE_EXCEPTION << "Incorrect number of input edges.";
+ THROW_IE_EXCEPTION << "Incorrect number of input edges for layer " << layer->name;
if (layer->outData.empty())
- THROW_IE_EXCEPTION << "Incorrect number of output edges.";
+ THROW_IE_EXCEPTION << "Incorrect number of output edges for layer " << layer->name;
_num_classes = layer->GetParamAsInt("num_classes");
_background_label_id = layer->GetParamAsInt("background_label_id", 0);
@@ -167,29 +168,39 @@ public:
for (int n = 0; n < N; ++n) {
int detections_total = 0;
-#pragma omp parallel for schedule(static)
- for (int c = 0; c < _num_classes; ++c) {
- if (c == _background_label_id) {
- // Ignore background class.
- continue;
- }
+ if (!_decrease_label_id) {
+ // Caffe style
+ parallel_for(_num_classes, [&](int c) {
+ if (c != _background_label_id) { // Ignore background class
+ int *pindices = indices_data + n*_num_classes*_num_priors + c*_num_priors;
+ int *pbuffer = buffer_data + c*_num_priors;
+ int *pdetections = detections_data + n*_num_classes + c;
+
+ const float *pconf = reordered_conf_data + n*_num_classes*_num_priors + c*_num_priors;
+ const float *pboxes;
+ const float *psizes;
+ if (_share_location) {
+ pboxes = decoded_bboxes_data + n*4*_num_priors;
+ psizes = bbox_sizes_data + n*_num_priors;
+ } else {
+ pboxes = decoded_bboxes_data + n*4*_num_classes*_num_priors + c*4*_num_priors;
+ psizes = bbox_sizes_data + n*_num_classes*_num_priors + c*_num_priors;
+ }
+
+ nms_cf(pconf, pboxes, psizes, pbuffer, pindices, *pdetections, num_priors_actual[n]);
+ }
+ });
+ } else {
+ // MXNet style
+ int *pindices = indices_data + n*_num_classes*_num_priors;
+ int *pbuffer = buffer_data;
+ int *pdetections = detections_data + n*_num_classes;
- int *pindices = indices_data + n*_num_classes*_num_priors + c*_num_priors;
- int *pbuffer = buffer_data + c*_num_priors;
- int *pdetections = detections_data + n*_num_classes + c;
-
- const float *pconf = reordered_conf_data + n*_num_classes*_num_priors + c*_num_priors;
- const float *pboxes;
- const float *psizes;
- if (_share_location) {
- pboxes = decoded_bboxes_data + n*4*_num_priors;
- psizes = bbox_sizes_data + n*_num_priors;
- } else {
- pboxes = decoded_bboxes_data + n*4*_num_classes*_num_priors + c*4*_num_priors;
- psizes = bbox_sizes_data + n*_num_classes*_num_priors + c*_num_priors;
- }
+ const float *pconf = reordered_conf_data + n*_num_classes*_num_priors;
+ const float *pboxes = decoded_bboxes_data + n*4*_num_priors;
+ const float *psizes = bbox_sizes_data + n*_num_priors;
- nms(pconf, pboxes, psizes, pbuffer, pindices, *pdetections, num_priors_actual[n]);
+ nms_mx(pconf, pboxes, psizes, pbuffer, pindices, pdetections, _num_priors);
}
for (int c = 0; c < _num_classes; ++c) {
@@ -319,8 +330,11 @@ private:
void decodeBBoxes(const float *prior_data, const float *loc_data, const float *variance_data,
float *decoded_bboxes, float *decoded_bbox_sizes, int* num_priors_actual, int n);
- void nms(const float *conf_data, const float *bboxes, const float *sizes,
- int *buffer, int *indices, int &detections, int num_priors_actual);
+ void nms_cf(const float *conf_data, const float *bboxes, const float *sizes,
+ int *buffer, int *indices, int &detections, int num_priors_actual);
+
+ void nms_mx(const float *conf_data, const float *bboxes, const float *sizes,
+ int *buffer, int *indices, int *detections, int num_priors_actual);
InferenceEngine::Blob::Ptr _decoded_bboxes;
InferenceEngine::Blob::Ptr _buffer;
@@ -399,8 +413,7 @@ void DetectionOutputImpl::decodeBBoxes(const float *prior_data,
}
}
- #pragma omp parallel for schedule(static)
- for (int p = 0; p < num_priors_actual[n]; ++p) {
+ parallel_for(num_priors_actual[n], [&](int p) {
float new_xmin = 0.0f;
float new_ymin = 0.0f;
float new_xmax = 0.0f;
@@ -478,10 +491,10 @@ void DetectionOutputImpl::decodeBBoxes(const float *prior_data,
decoded_bboxes[p*4 + 3] = new_ymax;
decoded_bbox_sizes[p] = (new_xmax - new_xmin) * (new_ymax - new_ymin);
- }
+ });
}
-void DetectionOutputImpl::nms(const float* conf_data,
+void DetectionOutputImpl::nms_cf(const float* conf_data,
const float* bboxes,
const float* sizes,
int* buffer,
@@ -521,6 +534,59 @@ void DetectionOutputImpl::nms(const float* conf_data,
}
}
+void DetectionOutputImpl::nms_mx(const float* conf_data,  // MXNet-style NMS: one pass over all classes; conf_data is read as conf_data[c*_num_priors + i]
+ const float* bboxes,  // box data, only forwarded to JaccardOverlap together with prior indices
+ const float* sizes,  // box sizes, only forwarded to JaccardOverlap
+ int* buffer,  // scratch: receives the sorted top candidates
+ int* indices,  // first scratch for candidate ids, then output: kept priors of class cls start at indices + cls*_num_priors
+ int* detections,  // per-class kept-box counters, incremented here; presumably zeroed by the caller -- TODO confirm
+ int num_priors_actual) {  // number of priors scanned by the candidate loop below
+ int count = 0;  // number of candidates collected so far
+ for (int i = 0; i < num_priors_actual; ++i) {
+ float conf = -1;  // best score found for prior i; NOTE(review): assumes real scores are > -1 -- confirm
+ int id = 0;  // class with the best score; 0 means "none selected"
+ for (int c = 1; c < _num_classes; ++c) {  // class 0 is never considered (presumably background -- confirm)
+ float temp = conf_data[c*_num_priors + i];
+ if (temp > conf) {  // strict '>' keeps the lowest class index on ties
+ conf = temp;
+ id = c;
+ }
+ }
+
+ if (id > 0 && conf >= _confidence_threshold) {  // keep the prior only if a class won and its score reaches the threshold
+ indices[count++] = id*_num_priors + i;  // pack (class, prior) into one int; unpacked again with / and % below
+ }
+ }
+
+ int num_output_scores = (_top_k == -1 ? count : std::min<int>(_top_k, count));  // _top_k == -1 means "no limit"
+
+ std::partial_sort_copy(indices, indices + count,  // copy the num_output_scores first-ranked candidates into buffer...
+ buffer, buffer + num_output_scores,  // ...so that indices can be overwritten with results below
+ ConfidenceComparator(conf_data));  // ordering defined elsewhere; presumably descending confidence -- confirm
+
+ for (int i = 0; i < num_output_scores; ++i) {  // visit candidates in sorted order
+ const int idx = buffer[i];
+ const int cls = idx/_num_priors;  // class part of the packed id
+ const int prior = idx%_num_priors;  // prior part of the packed id
+
+ int &ndetection = detections[cls];  // reference: the increment below updates the caller's counter
+ int *pindices = indices + cls*_num_priors;  // list of priors already kept for this class
+
+ bool keep = true;
+ for (int k = 0; k < ndetection; ++k) {  // compare only against boxes already kept for the same class
+ const int kept_idx = pindices[k];
+ float overlap = JaccardOverlap(bboxes, sizes, prior, kept_idx);
+ if (overlap > _nms_threshold) {  // overlaps a kept box too much: suppress this candidate
+ keep = false;
+ break;
+ }
+ }
+ if (keep) {
+ pindices[ndetection++] = prior;  // append to this class's kept list and bump its counter
+ }
+ }
+}
+
REG_FACTORY_FOR(ImplFactory<DetectionOutputImpl>, DetectionOutput);
} // namespace Cpu