summaryrefslogtreecommitdiff
path: root/tests/nnapi/src/TestGenerated_common.cpp
blob: 2d17efe51313f13c73e9e4a4d47a280896bdf5b6 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
/*
 * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * generated/all_generated_tests.cpp, which is included at the bottom of
 * this file, has also been modified.
 *
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// Top level driver for models and examples generated by test_generator.py

#include "NeuralNetworksWrapper.h"
#include "TestHarness.h"

#include <gtest/gtest.h>
#include <cassert>
#include <cmath>
#include <fstream>
#include <iostream>
#include <map>

// Uncomment the following line to generate DOT graphs.
//
// #define GRAPH GRAPH

namespace generated_tests {
using namespace nnfw::rt::wrapper;

// Writes every T-typed operand held in `test` to `os`, one line per operand:
//     "    aliased_output<idx>: [v0, v1, ...],"
// Used by execute() to dump the computed outputs.
template <typename T>
static void print(std::ostream& os, const MixedTyped& test) {
    for_each<T>(test, [&os](int idx, const std::vector<T>& values) {
        os << "    aliased_output" << idx << ": [";
        for (size_t i = 0; i < values.size(); ++i) {
            // Unary '+' promotes uint8_t (a character type) to int so the
            // value is printed as a number rather than as a raw character.
            os << (i == 0 ? "" : ", ") << +values[i];
        }
        os << "],\n";
    });
}

// Dumps the float, int32_t and uint8_t operands of `test` to `os`, in that
// order.
static void printAll(std::ostream& os, const MixedTyped& test) {
    print<float>(os, test);
    print<int32_t>(os, test);
    print<uint8_t>(os, test);
}

// Test driver for those generated from ml/nn/runtime/test/spec.
//
// Builds a model through `createModel`, finishes and compiles it. Then, for
// every example (`.first` = inputs, `.second` = golden outputs), it does the
// following:
//   1. binds the example's inputs and output buffers to a fresh Execution,
//   2. runs the execution,
//   3. if `dumpFile` is non-empty, appends the computed outputs to it,
//   4. filters both golden and computed outputs with `isIgnored` (presumably
//      dropping don't-care output indices) and compares them with compare(),
//      passing fpRange = 1e-5 as the floating-point tolerance.
// A non-empty `dumpFile` is truncated before the first example is written.
static void execute(std::function<void(Model*)> createModel,
                    std::function<bool(int)> isIgnored,
                    std::vector<MixedTypedExampleType>& examples,
                    std::string dumpFile = "") {
    Model model;
    createModel(&model);
    // Fail here with a precise message instead of at the first compute().
    ASSERT_EQ(Result::NO_ERROR, model.finish());

    const bool dumpToFile = !dumpFile.empty();
    std::ofstream dumpStream;
    if (dumpToFile) {
        dumpStream.open(dumpFile, std::ofstream::trunc);
        ASSERT_TRUE(dumpStream.is_open()) << "cannot open dump file " << dumpFile;
    }

    Compilation compilation(&model);
    ASSERT_EQ(Result::NO_ERROR, compilation.finish());

    const float fpRange = 1e-5f;
    int exampleNo = 0;
    for (auto& example : examples) {
        SCOPED_TRACE(exampleNo);
        // TODO: We leave it as a copy here.
        // Should verify if the input gets modified by the test later.
        MixedTyped inputs = example.first;
        const MixedTyped& golden = example.second;

        Execution execution(&compilation);

        // A fatal assertion inside a lambda only returns from the lambda.
        // Each for_all() is therefore wrapped in ASSERT_NO_FATAL_FAILURE so
        // the test stops as soon as an operand cannot be bound.
        // Zero-sized operands are bound with a null buffer.
        auto bindInput = [&execution](int idx, const void* p, size_t size) {
            const void* buffer = size == 0 ? nullptr : p;
            ASSERT_EQ(Result::NO_ERROR, execution.setInput(idx, buffer, size));
        };
        ASSERT_NO_FATAL_FAILURE(for_all(inputs, bindInput));

        // Output buffers: resize_accordingly presumably sizes `test` to match
        // `golden` for every typed output.
        MixedTyped test;
        resize_accordingly(golden, test);
        auto bindOutput = [&execution](int idx, void* p, size_t size) {
            void* buffer = size == 0 ? nullptr : p;
            ASSERT_EQ(Result::NO_ERROR, execution.setOutput(idx, buffer, size));
        };
        ASSERT_NO_FATAL_FAILURE(for_all(test, bindOutput));

        ASSERT_EQ(Result::NO_ERROR, execution.compute());

        // Dump all outputs for the slicing tool.
        if (dumpToFile) {
            dumpStream << "output" << exampleNo << " = {\n";
            printAll(dumpStream, test);
            // all outputs are done
            dumpStream << "}\n";
        }

        // Filter out don't cares; floats only need to be "close enough".
        MixedTyped filteredGolden = filter(golden, isIgnored);
        MixedTyped filteredTest = filter(test, isIgnored);
        compare(filteredGolden, filteredTest, fpRange);
        exampleNo++;
    }
}

}  // namespace generated_tests

using namespace nnfw::rt::wrapper;

// Alias for one generated example: `.first` holds the MixedTyped inputs and
// `.second` the MixedTyped golden outputs expected for them.
using MixedTypedExample = generated_tests::MixedTypedExampleType;

// Empty gtest fixture. The generated test cases (presumably TEST_F(GeneratedTests, ...)
// in the included generated file; confirm there) need no per-test setup, so
// SetUp() is an explicit no-op. It is marked `override` so the compiler
// verifies it really overrides ::testing::Test::SetUp.
class GeneratedTests : public ::testing::Test {
   protected:
    void SetUp() override {}
};