summaryrefslogtreecommitdiff
path: root/inference-engine/samples/validation_app/ClassificationProcessor.cpp
blob: 78e6adcbec359b7959885ff16609adf3c0490ecb (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
// Copyright (C) 2018 Intel Corporation
//
// SPDX-License-Identifier: Apache-2.0
//

#include <string>
#include <vector>
#include <memory>

#include "ClassificationProcessor.hpp"
#include "Processor.hpp"

using InferenceEngine::details::InferenceEngineException;

/**
 * Builds a classification validation processor.
 *
 * The model, dataset, images, batch, plugin and dumper arguments are handed to the Processor
 * base unchanged, together with the "Classification network" description and the supplied
 * preprocessing options.
 *
 * @param flags_l        Labels file path. If empty, the path is derived from the model file:
 *                       its name without extension plus ".labels".
 * @param zeroBackground Stored for Process(), which adds one to every expected label before
 *                       comparing it against the network's predictions.
 */
ClassificationProcessor::ClassificationProcessor(const std::string& flags_m,
                                                 const std::string& flags_d,
                                                 const std::string& flags_i,
                                                 int flags_b,
                                                 InferencePlugin plugin,
                                                 CsvDumper& dumper,
                                                 const std::string& flags_l,
                                                 PreprocessingOptions preprocessingOptions,
                                                 bool zeroBackground)
    : Processor(flags_m, flags_d, flags_i, flags_b, plugin, dumper,
                "Classification network", preprocessingOptions),
      zeroBackground(zeroBackground) {
    // An explicit labels path wins; otherwise look for "<model name>.labels" next to the model.
    labelFileName = flags_l.empty() ? fileNameNoExt(modelFileName) + ".labels" : flags_l;
}

/**
 * Convenience constructor that applies the default classification preprocessing,
 * PreprocessingOptions(false, ResizeCropPolicy::ResizeThenCrop, 256, 256).
 *
 * Every other argument is forwarded as-is to the constructor that takes explicit
 * PreprocessingOptions.
 */
ClassificationProcessor::ClassificationProcessor(const std::string& flags_m,
                                                 const std::string& flags_d,
                                                 const std::string& flags_i,
                                                 int flags_b,
                                                 InferencePlugin plugin,
                                                 CsvDumper& dumper,
                                                 const std::string& flags_l,
                                                 bool zeroBackground)
    : ClassificationProcessor(flags_m, flags_d, flags_i, flags_b,
                              plugin, dumper, flags_l,
                              PreprocessingOptions(false, ResizeCropPolicy::ResizeThenCrop, 256, 256),
                              zeroBackground) {}

/**
 * Runs classification over the validation set and accumulates accuracy metrics.
 *
 * The (expected label, file path) pairs come from
 * ClassificationSetGenerator::getValidationMap(imagesPath). They are decoded into the first
 * input blob in batches of up to `batch` images, and each batch is run through Infer().
 * For every successfully decoded image one CSV row is written to `dumper`:
 *   - the quoted file name,
 *   - whether the top-1 prediction matched,
 *   - TOP_COUNT pairs of (class id, score).
 *
 * Files that ImageDecoder::insertIntoBlob rejects with an InferenceEngineException are logged
 * and skipped. They still count toward the number of files reported to Infer().
 *
 * @return ClassificationInferenceMetrics holding top-1 hits, top-TOP_COUNT hits and the number
 *         of classified images.
 */
std::shared_ptr<Processor::InferenceMetrics> ClassificationProcessor::Process() {
    slog::info << "Collecting labels" << slog::endl;
    ClassificationSetGenerator generator;
    // NOTE(review): reading the labels file is disabled, so labelFileName is not used here.
    // try {
    //     generator.readLabels(labelFileName);
    // } catch (InferenceEngine::details::InferenceEngineException& ex) {
    //     slog::warn << "Can't read labels file " << labelFileName << slog::endl;
    // }

    auto validationMap = generator.getValidationMap(imagesPath);
    ImageDecoder decoder;

    // ----------------------------Do inference-------------------------------------------------------------
    slog::info << "Starting inference" << slog::endl;

    std::vector<int> expected(batch);
    std::vector<std::string> files(batch);

    ConsoleProgress progress(validationMap.size());

    ClassificationInferenceMetrics im;

    const std::string firstInputName = this->inputInfo.begin()->first;
    const std::string firstOutputName = this->outInfo.begin()->first;
    auto firstInputBlob = inferRequest.GetBlob(firstInputName);
    auto firstOutputBlob = inferRequest.GetBlob(firstOutputName);

    // Number of output values that belong to one image of the batch (loop-invariant).
    const size_t outputStride = firstOutputBlob->size() / batch;

    auto iter = validationMap.begin();
    while (iter != validationMap.end()) {
        int b = 0;
        int filesWatched = 0;
        for (; b < batch && iter != validationMap.end(); b++, iter++, filesWatched++) {
            expected[b] = iter->first;
            try {
                decoder.insertIntoBlob(iter->second, b, *firstInputBlob, preprocessingOptions);
                files[b] = iter->second;
            } catch (const InferenceEngineException&) {
                slog::warn << "Can't read file " << iter->second << slog::endl;
                // Could be some non-image file in directory: undo the increment so the next
                // file reuses batch slot b.
                b--;
                continue;
            }
        }

        // No file of this (final) batch could be decoded. The input blob then only holds stale
        // data from the previous batch, so running and timing inference on it would skew the
        // metrics. iter is already at end() here, so there is nothing left to process.
        if (b == 0) {
            break;
        }

        Infer(progress, filesWatched, im);

        std::vector<unsigned> results;
        auto firstOutputData = firstOutputBlob->buffer().as<PrecisionTrait<Precision::FP32>::value_type*>();
        InferenceEngine::TopResults(TOP_COUNT, *firstOutputBlob, results);

        for (int i = 0; i < b; i++) {
            // With zeroBackground the dataset label is shifted by one; presumably the network
            // reserves class 0 for background (TODO confirm). The explicit unsigned conversion
            // matches the implicit one the int/unsigned comparison used to perform.
            const unsigned expectedClass =
                static_cast<unsigned>(zeroBackground ? expected[i] + 1 : expected[i]);

            const bool top1Scored = (results[0 + TOP_COUNT * i] == expectedClass);
            dumper << "\"" + files[i] + "\"" << top1Scored;
            if (top1Scored) im.top1Result++;
            for (int j = 0; j < TOP_COUNT; j++) {
                const unsigned classId = results[j + TOP_COUNT * i];
                if (classId == expectedClass) {
                    im.topCountResult++;
                }
                dumper << classId << firstOutputData[classId + i * outputStride];
            }
            dumper.endLine();
            im.total++;
        }
    }
    progress.finish();

    // make_shared allocates once and keeps the derived type's deleter.
    return std::make_shared<ClassificationInferenceMetrics>(im);
}

/**
 * Prints the generic Processor report followed by top-1 and top-TOP_COUNT accuracy.
 *
 * @param im Metrics returned by ClassificationProcessor::Process(). Must actually be a
 *           ClassificationInferenceMetrics; the reference dynamic_cast throws std::bad_cast
 *           otherwise.
 */
void ClassificationProcessor::Report(const Processor::InferenceMetrics& im) {
    Processor::Report(im);
    if (im.nRuns > 0) {
        const ClassificationInferenceMetrics& cim = dynamic_cast<const ClassificationInferenceMetrics&>(im);

        // Guard against 0/0: if no image could be classified, the percentages would be NaN.
        if (cim.total == 0) {
            cout << "No images were classified, accuracy is not available" << "\n";
            return;
        }

        cout << "Top1 accuracy: " << OUTPUT_FLOATING(100.0 * cim.top1Result / cim.total) << "% (" << cim.top1Result << " of "
             << cim.total << " images were detected correctly, top class is correct)" << "\n";
        // The top-N hit count is gathered over TOP_COUNT predictions in Process(), so derive the
        // label from the same constant instead of hard-coding "5".
        cout << "Top" << TOP_COUNT << " accuracy: " << OUTPUT_FLOATING(100.0 * cim.topCountResult / cim.total) << "% ("
             << cim.topCountResult << " of " << cim.total << " images were detected correctly, top " << TOP_COUNT
             << " classes contain required class)" << "\n";
    }
}