// path: root/inference-engine/src/mkldnn_plugin/mkldnn/cpu_prim_layer.h
// blob: b3ad3c0c5ef8eebf52f0e922d7189398d7e7d631
// Copyright (C) 2018 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "inference_engine.hpp"
#include "prim_layer.h"
#include "mkldnn.hpp"
#include <memory>

using namespace InferenceEngine;
using namespace mkldnn;

namespace MKLDNNPlugin {

class CpuPrimLayer : public PrimLayer {
    friend class CpuEngine;

    mkldnn::engine eng;
    std::shared_ptr<mkldnn::primitive> prim;

public:
    explicit CpuPrimLayer(engine eng) : eng(eng) {}
};

template<typename LYR>
class Layer : public CpuPrimLayer {
    typename LYR::desc desc;
    typename LYR::primitive_desc prim_desc;

public:
    Layer(typename LYR::desc desc, engine eng) :
            CpuPrimLayer(eng),
            desc(desc),
            prim_desc(desc, eng) {}

    friend class CpuEngine;
};

class ReorderLayer : public CpuPrimLayer {
    reorder::primitive_desc prim_desc;

public:
    ReorderLayer(reorder::primitive_desc desc, engine eng) :
            CpuPrimLayer(eng),
            prim_desc(desc) {}

    friend class CpuEngine;
};
}  // namespace MKLDNNPlugin