summaryrefslogtreecommitdiff
path: root/inference-engine/samples/validation_app/Processor.cpp
blob: 734271187e6f0cb083372d25e4e03c7cbb155f8d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
// Copyright (C) 2018 Intel Corporation
//
// SPDX-License-Identifier: Apache-2.0
//

#include <string>
#include <algorithm>

#include <samples/common.hpp>

#include "Processor.hpp"

using namespace InferenceEngine;

/**
 * @brief Loads the IR model, configures target device / batch size / output precision,
 *        loads the network to @p plugin and creates an inference request.
 *
 * @param flags_m  Path to the model .xml file; weights are read from the file with the
 *                 same name and a ".bin" extension.
 * @param flags_d  Target device name; when empty, no target device is set on the network.
 * @param flags_i  Images path (stored in imagesPath).
 * @param flags_b  Batch size; 0 means "take the batch size from the IR", any other value
 *                 is applied to the network via setBatchSize.
 * @param plugin   Inference Engine plugin the network is loaded to.
 * @param dumper   CSV dumper handed to the dumper member.
 * @param approach Approach name (stored as-is).
 * @param preprocessingOptions Preprocessing settings (stored as-is).
 *
 * @throws InferenceEngine exception (THROW_IE_EXCEPTION) if the model fails to parse, or the
 *         network does not have exactly one input, or does not have exactly one output.
 * @throws std::logic_error if the single output's data pointer is null.
 */
Processor::Processor(const std::string& flags_m, const std::string& flags_d, const std::string& flags_i, int flags_b,
        InferencePlugin plugin, CsvDumper& dumper, const std::string& approach, PreprocessingOptions preprocessingOptions)
    : targetDevice(flags_d), modelFileName(flags_m), imagesPath(flags_i), batch(flags_b),
      plugin(plugin), dumper(dumper), approach(approach), preprocessingOptions(preprocessingOptions) {

    // --------------------Load network (Generated xml/bin files)-------------------------------------------
    slog::info << "Loading network files" << slog::endl;

    loadDuration = getDurationOf([&]() {
        /** Read network model **/
        networkReader.ReadNetwork(modelFileName);
        if (!networkReader.isParseSuccess()) THROW_IE_EXCEPTION << "cannot load a failed Model";

        /** Weights are expected next to the .xml file, with a .bin extension **/
        std::string binFileName = fileNameNoExt(modelFileName) + ".bin";
        networkReader.ReadWeights(binFileName.c_str());
    });
    // -----------------------------------------------------------------------------------------------------

    // -----------------------------Prepare input blobs-----------------------------------------------------
    slog::info << "Preparing input blobs" << slog::endl;

    /** Taking information about all topology inputs **/
    inputInfo = InputsDataMap(networkReader.getNetwork().getInputsInfo());

    // TODO Check if it's necessary
    if (!targetDevice.empty()) {
        networkReader.getNetwork().setTargetDevice(getDeviceFromStr(targetDevice));
    }

    if (batch == 0) {
        // Zero means "take batch value from the IR"
        batch = networkReader.getNetwork().getBatchSize();
    } else {
        // Not zero means "use the specified value"
        networkReader.getNetwork().setBatchSize(batch);
    }

    if (inputInfo.size() != 1) {
        THROW_IE_EXCEPTION << "This app accepts networks having only one input";
    }

    // Exactly one input is guaranteed by the check above, so begin() is dereferenceable.
    inputDims = inputInfo.begin()->second->getDims();
    slog::info << "Batch size is " << std::to_string(networkReader.getNetwork().getBatchSize()) << slog::endl;

    // -----------------------------Prepare output----------------------------------------------------------
    outInfo = networkReader.getNetwork().getOutputsInfo();

    // Validate the output count BEFORE touching outInfo.begin(): on an empty map
    // begin() == end(), and dereferencing it is undefined behavior.
    if (outInfo.size() != 1) {
        THROW_IE_EXCEPTION << "This app accepts networks having only one output";
    }

    DataPtr outData = outInfo.begin()->second;
    if (!outData) {
        throw std::logic_error("output data pointer is not valid");
    }

    // Set the precision of output data provided by the user, should be called before load of the network to the plugin
    outData->setPrecision(Precision::FP32);
    outputDims = outData->dims;

    // Load model to plugin and create an inference request
    ExecutableNetwork executable_network = plugin.LoadNetwork(networkReader.getNetwork(), {});
    inferRequest = executable_network.CreateInferRequest();
}

/**
 * @brief Runs one synchronous inference on the prepared request and records its timing.
 *
 * @param progress     Console progress indicator; advanced by @p filesWatched.
 * @param filesWatched Amount passed to ConsoleProgress::addProgress.
 * @param im           Metrics accumulator: updated with the min/max duration,
 *                     total time and run count.
 * @return Duration of this inference as measured by getDurationOf
 *         (units are whatever that helper reports).
 */
double Processor::Infer(ConsoleProgress& progress, int filesWatched, InferenceMetrics& im) {
    const double time = getDurationOf([&]() {
        inferRequest.Infer();
    });

    // Track the slowest (max) and fastest (min) run seen so far.
    // NOTE(review): assumes InferenceMetrics starts minDuration high and maxDuration low — confirm in Processor.hpp.
    im.maxDuration = std::max(im.maxDuration, time);
    im.minDuration = std::min(im.minDuration, time);
    im.totalTime += time;
    im.nRuns++;

    progress.addProgress(filesWatched);

    return time;
}