summaryrefslogtreecommitdiff
path: root/inference-engine/samples/validation_app/YOLOObjectDetectionProcessor.hpp
blob: fe9dad91b8f0d30911aa5aec7faad092e5a953d7 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
// Copyright (C) 2018 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <iostream>
#include <limits>
#include <map>
#include <memory>
#include <string>
#include <list>
#include <vector>
#include <algorithm>

using namespace std;

class YOLOObjectDetectionProcessor : public ObjectDetectionProcessor {
private:
    /**
     * \brief This function analyses the YOLO net output for a single class
     * @param net_out - The output data
     * @param class_num - The class number
     * @return a list of found boxes
     */
    std::vector<DetectedObject> yoloNetParseOutput(const float *net_out, int class_num) {
        float threshold = 0.2f;         // The confidence threshold
        int C = 20;                     // classes
        int B = 2;                      // bounding boxes
        int S = 7;                      // cell size

        std::vector<DetectedObject> boxes;
        std::vector<DetectedObject> boxes_result;
        int SS = S * S;                 // number of grid cells 7*7 = 49
        // First 980 values corresponds to probabilities for each of the 20 classes for each grid cell.
        // These probabilities are conditioned on objects being present in each grid cell.
        int prob_size = SS * C;         // class probabilities 49 * 20 = 980
        // The next 98 values are confidence scores for 2 bounding boxes predicted by each grid cells.
        int conf_size = SS * B;         // 49*2 = 98 confidences for each grid cell

        const float *probs = &net_out[0];
        const float *confs = &net_out[prob_size];
        const float *cords = &net_out[prob_size + conf_size];     // 98*4 = 392 coords x, y, w, h

        for (int grid = 0; grid < SS; grid++) {
            int row = grid / S;
            int col = grid % S;
            for (int b = 0; b < B; b++) {
                int index = grid * B + b;
                int p_index = SS * C + grid * B + b;
                float scale = net_out[p_index];
                int box_index = SS * (C + B) + (grid * B + b) * 4;
                int objectType = class_num;

                float conf = confs[(grid * B + b)];
                float xc = (cords[(grid * B + b) * 4 + 0] + col) / S;
                float yc = (cords[(grid * B + b) * 4 + 1] + row) / S;
                float w = pow(cords[(grid * B + b) * 4 + 2], 2);
                float h = pow(cords[(grid * B + b) * 4 + 3], 2);
                int class_index = grid * C;
                float prob = probs[grid * C + class_num] * conf;

                DetectedObject bx(objectType, xc - w / 2, yc - h / 2, xc + w / 2,
                        yc + h / 2, prob);

                if (prob >= threshold) {
                    boxes.push_back(bx);
                }
            }
        }

        // Sorting the higher probabilities to the top
        sort(boxes.begin(), boxes.end(),
                [](const DetectedObject & a, const DetectedObject & b) -> bool {
                    return a.prob > b.prob;
                });

        // Filtering out overlapping boxes
        std::vector<bool> overlapped(boxes.size(), false);
        for (int i = 0; i < boxes.size(); i++) {
            if (overlapped[i])
                continue;

            DetectedObject box_i = boxes[i];
            for (int j = i + 1; j < boxes.size(); j++) {
                DetectedObject box_j = boxes[j];
                if (DetectedObject::ioU(box_i, box_j) >= 0.4) {
                    overlapped[j] = true;
                }
            }
        }

        for (int i = 0; i < boxes.size(); i++) {
            if (boxes[i].prob > 0.0f) {
                boxes_result.push_back(boxes[i]);
            }
        }
        return boxes_result;
    }

protected:
    std::map<std::string, std::list<DetectedObject>> processResult(std::vector<std::string> files) {
        std::map<std::string, std::list<DetectedObject>> detectedObjects;

        std::string firstOutputName = this->outInfo.begin()->first;
        const auto detectionOutArray = inferRequest.GetBlob(firstOutputName);
        const float *box = detectionOutArray->buffer().as<float*>();

        std::string file = *files.begin();
        for (int c = 0; c < 20; c++) {
            std::vector<DetectedObject> result = yoloNetParseOutput(box, c);
            detectedObjects[file].insert(detectedObjects[file].end(), result.begin(), result.end());
        }

        return detectedObjects;
    }

public:
    YOLOObjectDetectionProcessor(const std::string& flags_m, const std::string& flags_d, const std::string& flags_i, const std::string& subdir, int flags_b,
            double threshold,
            InferencePlugin plugin, CsvDumper& dumper,
            const std::string& flags_a, const std::string& classes_list_file) :

                ObjectDetectionProcessor(flags_m, flags_d, flags_i, subdir, flags_b, threshold,
                        plugin, dumper, flags_a, classes_list_file, PreprocessingOptions(true, ResizeCropPolicy::Resize), false) { }
};