summaryrefslogtreecommitdiff
path: root/inference-engine/include/ie_preprocess.hpp
blob: bc5d7bd21ec4187a2ecbb4644d3f91dc359535aa (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
// Copyright (C) 2018 Intel Corporation
//
// SPDX-License-Identifier: Apache-2.0
//

/**
 * @brief This header file provides structures to store info about pre-processing of network inputs (scale, mean image, ...)
 * @file ie_preprocess.hpp
 */
#pragma once

#include "ie_blob.h"
#include <vector>
#include <memory>

namespace InferenceEngine {

/**
 * @brief This structure stores info about pre-processing of network inputs (scale, mean image, ...)
 */
struct PreProcessChannel {
    /** @brief Shared-ownership handle to a PreProcessChannel */
    using Ptr = std::shared_ptr<PreProcessChannel>;

    /** @brief Per-channel scale factor (defaults to 1) */
    float stdScale{1.0f};

    /** @brief Per-channel mean value (defaults to 0) */
    float meanValue{0.0f};

    /** @brief Per-channel mean data blob (empty unless set, e.g. by PreProcessInfo::setMeanImageForChannel) */
    Blob::Ptr meanData{};
};

/**
 * @brief Defines available types of mean
 */
enum MeanVariant {
    /** @brief A mean value is given for every input pixel */
    MEAN_IMAGE = 0,
    /** @brief A mean value is given for every input channel */
    MEAN_VALUE = 1,
    /** @brief No mean value is given */
    NONE = 2
};

/**
 * @enum ResizeAlgorithm
 * @brief Represents the list of supported resize algorithms.
 */
enum ResizeAlgorithm {
    /** @brief Input is not resized */
    NO_RESIZE = 0,
    /** @brief Bilinear resize */
    RESIZE_BILINEAR = 1,
    /** @brief Area-based resize */
    RESIZE_AREA = 2
};

/**
 * @brief This class stores pre-process information for the input
 */
class PreProcessInfo {
    // Channel data
    std::vector<PreProcessChannel::Ptr> _channelsInfo;
    MeanVariant _variant = NONE;

    // Resize Algorithm to be applied for input before inference if needed.
    ResizeAlgorithm _resizeAlg = NO_RESIZE;

public:
    /**
     * @brief Overloaded [] operator to safely get the channel by an index. 
     * Throws an exception if channels are empty.
     * @param index Index of the channel to get
     * @return The pre-process channel instance
     */
    PreProcessChannel::Ptr &operator[](size_t index) {
        if (_channelsInfo.empty()) {
            THROW_IE_EXCEPTION << "accessing pre-process when nothing was set.";
        }
        if (index >= _channelsInfo.size()) {
            THROW_IE_EXCEPTION << "pre process index " << index << " is out of bounds.";
        }
        return _channelsInfo[index];
    }

    /**
     * @brief operator [] to safely get the channel preprocessing information by index.
     * Throws exception if channels are empty or index is out of border
     *
     * @param index Index of the channel to get
     * @return The const preprocess channel instance
     */
    const PreProcessChannel::Ptr &operator[](size_t index) const {
        if (_channelsInfo.empty()) {
            THROW_IE_EXCEPTION << "accessing pre-process when nothing was set.";
        }
        if (index >= _channelsInfo.size()) {
            THROW_IE_EXCEPTION << "pre process index " << index << " is out of bounds.";
        }
        return _channelsInfo[index];
    }

    /**
     * @brief Returns a number of channels to preprocess
     * @return The number of channels
     */
    size_t getNumberOfChannels() const {
        return _channelsInfo.size();
    }

    /**
     * @brief Initializes with given number of channels
     * @param numberOfChannels Number of channels to initialize
     */
    void init(const size_t numberOfChannels) {
        _channelsInfo.resize(numberOfChannels);
        for (auto &channelInfo : _channelsInfo) {
            channelInfo = std::make_shared<PreProcessChannel>();
        }
    }

    /**
     * @brief Sets mean image values if operation is applicable.
     * Also sets the mean type to MEAN_IMAGE for all channels
     * @param meanImage Blob with a mean image
     */
    void setMeanImage(const Blob::Ptr &meanImage) {
        if (meanImage.get() == nullptr) {
            THROW_IE_EXCEPTION << "Failed to set invalid mean image: nullptr";
        } else if (meanImage.get()->dims().size() != 3) {
            THROW_IE_EXCEPTION << "Failed to set invalid mean image: number of dimensions != 3";
        } else if (meanImage.get()->dims()[2] != getNumberOfChannels()) {
            THROW_IE_EXCEPTION << "Failed to set invalid mean image: number of channels != "
                               << getNumberOfChannels();
        } else if (meanImage.get()->layout() != Layout::CHW) {
            THROW_IE_EXCEPTION << "Mean image layout should be CHW";
        }
        _variant = MEAN_IMAGE;
    }

    /**
     * @brief Sets mean image values if operation is applicable.
     * Also sets the mean type to MEAN_IMAGE for a particular channel
     * @param meanImage Blob with a mean image
     * @param channel Index of a particular channel
     */
    void setMeanImageForChannel(const Blob::Ptr &meanImage, const size_t channel) {
        if (meanImage.get() == nullptr) {
            THROW_IE_EXCEPTION << "Failed to set invalid mean image for channel: nullptr";
        } else if (meanImage.get()->dims().size() != 2) {
            THROW_IE_EXCEPTION << "Failed to set invalid mean image for channel: number of dimensions != 2";
        } else if (channel >= _channelsInfo.size()) {
            THROW_IE_EXCEPTION << "Channel " << channel << " exceed number of PreProcess channels: "
                               << _channelsInfo.size();
        }
        _variant = MEAN_IMAGE;
        _channelsInfo[channel]->meanData = meanImage;
    }

    /**
     * @brief Sets a type of mean operation
     * @param variant Type of mean operation to set
     */
    void setVariant(const MeanVariant &variant) {
        _variant = variant;
    }

    /**
     * @brief Gets a type of mean operation
     * @return The type of mean operation
     */
    MeanVariant getMeanVariant() const {
        return _variant;
    }

    /**
     * @brief Sets resize algorithm to be used during pre-processing.
     * @param alg Resize algorithm.
     */
    void setResizeAlgorithm(const ResizeAlgorithm &alg) {
        _resizeAlg = alg;
    }

    /**
     * @brief Gets preconfigured resize algorithm.
     * @return Resize algorithm.
     */
    ResizeAlgorithm getResizeAlgorithm() const {
        return _resizeAlg;
    }
};
}  // namespace InferenceEngine