summaryrefslogtreecommitdiff
path: root/inference-engine/src/inference_engine/ade_util.cpp
blob: 041c5655e90dbd43e228099c1534172a50672905 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
// Copyright (C) 2018 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "ade_util.hpp"

#include <cassert>
#include <cstddef>
#include <unordered_map>
#include <utility>
#include <vector>

#include <ie_icnn_network.hpp>
#include <ie_util_internal.hpp>
#include <ie_layers.h>

#include <ade/util/algorithm.hpp>
#include <ade/graph.hpp>
#include <ade/typed_graph.hpp>

namespace InferenceEngine {
namespace {
using VisitedLayersMap = std::unordered_map<CNNLayer::Ptr, ade::NodeHandle>;
using TGraph = ade::TypedGraph<CNNLayerMetadata>;

void translateVisitLayer(VisitedLayersMap& visited,
                TGraph& gr,
                const ade::NodeHandle& prevNode,
                const CNNLayer::Ptr& layer) {
    assert(nullptr != layer);;
    assert(!ade::util::contains(visited, layer));
    auto node = gr.createNode();
    gr.metadata(node).set(CNNLayerMetadata{layer});
    if (nullptr != prevNode) {
        gr.link(prevNode, node);
    }
    visited.insert({layer, node});
    for (auto&& data : layer->outData) {
        for (auto&& layerIt : data->inputTo) {
            auto nextLayer = layerIt.second;
            auto it = visited.find(nextLayer);
            if (visited.end() == it) {
                translateVisitLayer(visited, gr, node, nextLayer);
            } else {
                gr.link(node, it->second);
            }
        }
    }
}
}  // namespace

void translateNetworkToAde(ade::Graph& gr, ICNNNetwork& network) {
    TGraph tgr(gr);
    VisitedLayersMap visited;
    for (auto& data : getRootDataObjects(network)) {
        assert(nullptr != data);
        for (auto& layerIt : data->getInputTo()) {
            auto layer = layerIt.second;
            assert(nullptr != layer);
            if (!ade::util::contains(visited, layer)) {
                translateVisitLayer(visited, tgr, nullptr, layer);
            }
        }
    }
}

const char* CNNLayerMetadata::name() {
    return "CNNLayerMetadata";
}

}  // namespace InferenceEngine