summaryrefslogtreecommitdiff
path: root/inference-engine/tests/helpers/test_assertions.hpp
blob: a8ae3660815a78ea8d75a49e1d8f87f633ec4428 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
// Copyright (C) 2018 Intel Corporation
//
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <cstddef>
#include <string>

#include <gtest/gtest.h>
#include "inference_engine.hpp"

// Assertion-style wrappers around the compare_* helpers defined below.
// Each macro expands to a plain call of the matching helper, so the gtest
// ASSERT_* checks execute inside that helper function. Per gtest's rules for
// assertions in sub-routines, a fatal failure there returns from the helper
// only, not from the calling test; wrap the use in ASSERT_NO_FATAL_FAILURE(...)
// at the call site if the test must stop on the first mismatch.
#define ASSERT_BLOB_EQ(a, b)               compare_blob(a, b)
#define ASSERT_DIMS_EQ(a, b)               compare_dims(a, b)
#define ASSERT_DATA_EQ(a, b)               compare_data(a, b)
#define ASSERT_PREPROCESS_CHANNEL_EQ(a, b) compare_preprocess(a, b)
#define ASSERT_PREPROCESS_INFO_EQ(a, b)    compare_preprocess_info(a, b)
#define ASSERT_OUTPUTS_INFO_EQ(a, b)       compare_outputs_info(a, b)
#define ASSERT_INPUTS_INFO_EQ(a, b)        compare_inputs_info(a, b)
#define ASSERT_STRINGEQ(a, b)              compare_cpp_strings(a, b)



inline void compare_blob(InferenceEngine::Blob::Ptr lhs, InferenceEngine::Blob::Ptr rhs) {
    ASSERT_EQ(lhs.get(), rhs.get());
    //TODO: add blob specific comparison for general case
}

/// Asserts that two dimension vectors have the same length and the same
/// extent at every position.
///
/// @param lhs first dimension vector
/// @param rhs second dimension vector
inline void compare_dims(const InferenceEngine::SizeVector & lhs, const InferenceEngine::SizeVector & rhs) {
    ASSERT_EQ(lhs.size(), rhs.size());
    // Unsigned index so the loop condition is not a signed/unsigned
    // comparison against size() (-Wsign-compare).
    for (std::size_t i = 0; i < lhs.size(); ++i) {
        ASSERT_EQ(lhs[i], rhs[i]) << "dimension mismatch at index " << i;
    }
}

/// Asserts that two Data descriptors agree on dimensions, name and precision.
///
/// A failed ASSERT_* written directly in this function returns from it.
/// The dimension check, however, runs inside compare_dims (via
/// ASSERT_DIMS_EQ), so a dims mismatch does not stop the name and precision
/// checks below.
/// @param lhs first data descriptor
/// @param rhs second data descriptor
inline void compare_data(const InferenceEngine::Data & lhs, const InferenceEngine::Data & rhs) {
    ASSERT_DIMS_EQ(lhs.getDims(), rhs.getDims());
    // Compare the strings themselves rather than their c_str(): ASSERT_STREQ
    // stops at the first '\0' and would accept names differing after it.
    ASSERT_EQ(lhs.getName(), rhs.getName());
    ASSERT_EQ(lhs.getPrecision(), rhs.getPrecision());
}

/// Asserts that two pre-processing channel descriptors are equivalent.
///
/// The scalar mean and scale are compared with gtest's ULP-based float
/// equality. The mean-data blob is compared via compare_blob, which here
/// checks only that both handles point at the same blob.
/// @param lhs first channel descriptor
/// @param rhs second channel descriptor
inline void compare_preprocess(const InferenceEngine::PreProcessChannel & lhs, const InferenceEngine::PreProcessChannel & rhs) {
    // Scalar normalisation parameters.
    ASSERT_FLOAT_EQ(lhs.meanValue, rhs.meanValue);
    ASSERT_FLOAT_EQ(lhs.stdScale, rhs.stdScale);

    // Per-channel mean image (same call ASSERT_BLOB_EQ expands to).
    compare_blob(lhs.meanData, rhs.meanData);
}

/// Asserts that two PreProcessInfo objects use the same mean variant, have
/// the same number of channels, and hold pairwise-equivalent channels.
///
/// @param lhs first pre-processing description
/// @param rhs second pre-processing description
inline void compare_preprocess_info(const InferenceEngine::PreProcessInfo & lhs, const InferenceEngine::PreProcessInfo & rhs) {
    ASSERT_EQ(lhs.getMeanVariant(), rhs.getMeanVariant());
    ASSERT_EQ(lhs.getNumberOfChannels(), rhs.getNumberOfChannels());
    // Unsigned index: avoids the signed/unsigned comparison with the
    // channel count.
    for (std::size_t i = 0; i < lhs.getNumberOfChannels(); ++i) {
        const auto & lhsChannel = lhs[i];
        const auto & rhsChannel = rhs[i];
        // Report a null channel as a test failure instead of crashing the
        // whole test binary on the dereference below.
        ASSERT_TRUE(lhsChannel != nullptr) << "lhs channel " << i << " is null";
        ASSERT_TRUE(rhsChannel != nullptr) << "rhs channel " << i << " is null";
        ASSERT_PREPROCESS_CHANNEL_EQ(*lhsChannel, *rhsChannel);
    }
}

/// Asserts that two output maps have the same size and, walking both in
/// their iteration order, the same keys mapped to equivalent Data descriptors.
///
/// @param lhs first output map
/// @param rhs second output map
inline void compare_outputs_info(const InferenceEngine::OutputsDataMap & lhs, const InferenceEngine::OutputsDataMap & rhs) {
    ASSERT_EQ(lhs.size(), rhs.size());
    // Walk both maps in lock-step; the size check above keeps j in range.
    auto j = rhs.begin();
    for (auto i = lhs.begin(); i != lhs.end(); ++i, ++j) {
        // Full string comparison (ASSERT_STREQ on c_str() stops at '\0').
        ASSERT_EQ(i->first, j->first);
        // Fail cleanly rather than dereference a null data pointer.
        ASSERT_TRUE(i->second != nullptr) << "lhs output '" << i->first << "' has null data";
        ASSERT_TRUE(j->second != nullptr) << "rhs output '" << j->first << "' has null data";
        ASSERT_DATA_EQ(*i->second, *j->second);
    }
}

/// Asserts that two input maps have the same size and, walking both in
/// their iteration order, the same keys whose InputInfo entries agree on
/// dimensions, pre-processing and input Data.
///
/// @param lhs first input map
/// @param rhs second input map
inline void compare_inputs_info (const InferenceEngine::InputsDataMap & lhs, const InferenceEngine::InputsDataMap & rhs) {
    ASSERT_EQ(lhs.size(), rhs.size());
    // Walk both maps in lock-step; the size check above keeps j in range.
    auto j = rhs.begin();
    for (auto i = lhs.begin(); i != lhs.end(); ++i, ++j) {
        // Full string comparison (ASSERT_STREQ on c_str() stops at '\0').
        ASSERT_EQ(i->first, j->first);

        // Bind the pointers (not the pointees) so the accessors are called
        // exactly as before, through the smart pointer.
        const auto & lhsInfo = i->second;
        const auto & rhsInfo = j->second;
        ASSERT_TRUE(lhsInfo != nullptr) << "lhs input '" << i->first << "' is null";
        ASSERT_TRUE(rhsInfo != nullptr) << "rhs input '" << j->first << "' is null";

        ASSERT_DIMS_EQ(lhsInfo->getDims(), rhsInfo->getDims());
        ASSERT_PREPROCESS_INFO_EQ(lhsInfo->getPreProcess(), rhsInfo->getPreProcess());

        // Guard the input data pointers before dereferencing them.
        const auto lhsData = lhsInfo->getInputData();
        const auto rhsData = rhsInfo->getInputData();
        ASSERT_TRUE(lhsData != nullptr) << "lhs input '" << i->first << "' has null input data";
        ASSERT_TRUE(rhsData != nullptr) << "rhs input '" << j->first << "' has null input data";
        ASSERT_DATA_EQ(*lhsData, *rhsData);
    }
}

inline void compare_cpp_strings(const std::string & lhs, const std::string &rhs) {
    ASSERT_STREQ(lhs.c_str(), rhs.c_str());
}