summaryrefslogtreecommitdiff
path: root/inference-engine/src/mkldnn_plugin/mkldnn_descriptor.h
blob: d897ed03aa67f4d419fc25968b2f45cece4ffc18 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
// Copyright (C) 2018 Intel Corporation
//
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <memory>
#include <string>
#include <utility>
#include <mkldnn.hpp>
#include <mkldnn/desc_iterator.hpp>

/**
 * @brief Type-erased holder for a single MKL-DNN operation descriptor.
 *
 * Each constructor accepts one concrete mkldnn descriptor type (batch
 * normalization, convolution, LRN, pooling, ReLU, ROI pooling, softmax,
 * depthwise, inner product, ...). The matching conversion operator hands the
 * wrapped shared_ptr back. Internally the descriptor sits behind the private
 * IDesc interface, so callers can build a mkldnn::primitive_desc_iterator
 * without knowing which operation type is stored.
 *
 * NOTE(review): the constructors, conversion operators, inputNumbers(),
 * outputNumbers() and operator bool are defined out of line (in the .cpp).
 * Confirm there what happens on a type mismatch (e.g. requesting a pooling
 * desc from a wrapper built from a convolution desc).
 */
class MKLDNNDescriptor {
public:
    /// Creates an empty wrapper; the internal descriptor pointer is null.
    MKLDNNDescriptor() = default;

    explicit MKLDNNDescriptor(std::shared_ptr<mkldnn::batch_normalization_forward::desc> desc);
    operator std::shared_ptr<mkldnn::batch_normalization_forward::desc>();

    explicit MKLDNNDescriptor(std::shared_ptr<mkldnn::convolution_forward::desc> desc);
    operator std::shared_ptr<mkldnn::convolution_forward::desc>();

    explicit MKLDNNDescriptor(std::shared_ptr<mkldnn::convolution_relu_forward::desc> desc);
    operator std::shared_ptr<mkldnn::convolution_relu_forward::desc>();

    /**
     * Backward-data convolution is built from its own op descriptor plus the
     * forward convolution primitive descriptor. MKL-DNN backward primitives
     * take the forward primitive descriptor as a hint. Both halves can be
     * read back through the two conversion operators below.
     */
    MKLDNNDescriptor(std::shared_ptr<mkldnn::convolution_backward_data::desc> desc,
                     std::shared_ptr<mkldnn::convolution_forward::primitive_desc> prim);
    operator std::shared_ptr<mkldnn::convolution_backward_data::desc>();
    operator std::shared_ptr<mkldnn::convolution_forward::primitive_desc>();

    explicit MKLDNNDescriptor(std::shared_ptr<mkldnn::inner_product_forward::desc> desc);
    operator std::shared_ptr<mkldnn::inner_product_forward::desc>();

    explicit MKLDNNDescriptor(std::shared_ptr<mkldnn::lrn_forward::desc> desc);
    operator std::shared_ptr<mkldnn::lrn_forward::desc>();

    explicit MKLDNNDescriptor(std::shared_ptr<mkldnn::pooling_forward::desc> desc);
    operator std::shared_ptr<mkldnn::pooling_forward::desc>();

    explicit MKLDNNDescriptor(std::shared_ptr<mkldnn::relu_forward::desc> desc);
    operator std::shared_ptr<mkldnn::relu_forward::desc>();

    explicit MKLDNNDescriptor(std::shared_ptr<mkldnn::roi_pooling_forward::desc> desc);
    operator std::shared_ptr<mkldnn::roi_pooling_forward::desc>();

    explicit MKLDNNDescriptor(std::shared_ptr<mkldnn::softmax_forward::desc> desc);
    operator std::shared_ptr<mkldnn::softmax_forward::desc>();

    explicit MKLDNNDescriptor(std::shared_ptr<mkldnn::depthwise_forward::desc> desc);
    operator std::shared_ptr<mkldnn::depthwise_forward::desc>();

    /**
     * @brief Creates a primitive descriptor iterator for the wrapped descriptor.
     * @param engine Engine the primitive descriptors are created for.
     * @param attr   Primitive attributes; a default-constructed
     *               mkldnn::primitive_attr when omitted.
     * @return Iterator returned by value.
     * NOTE(review): presumably forwards to the IDesc overloads below; confirm
     * in the .cpp.
     */
    mkldnn::primitive_desc_iterator createPrimitiveDescriptorIterator(const mkldnn::engine &engine,
            const mkldnn::primitive_attr &attr = mkldnn::primitive_attr()) const;

    /**
     * @brief Heap-allocating variant of createPrimitiveDescriptorIterator (no attributes).
     * NOTE(review): returns a raw pointer. The IDesc implementations allocate
     * it with `new`, so the caller presumably owns it and must delete it.
     */
    mkldnn::primitive_desc_iterator * createPrimitiveDescriptorIteratorPtr(const mkldnn::engine &engine) const;

    /// Number of outputs of the wrapped operation. NOTE(review): defined in the .cpp; confirm semantics.
    size_t outputNumbers() const;
    /// Number of inputs of the wrapped operation. NOTE(review): defined in the .cpp; confirm semantics.
    size_t inputNumbers() const;

    /// Presumably true when a descriptor is held (defined in the .cpp).
    /// NOTE(review): non-const and non-explicit; left unchanged so it keeps
    /// matching its out-of-line definition and existing callers.
    operator bool();

private:
    /// Type-erasure interface: one implementation per descriptor "shape".
    class IDesc {
    public:
        virtual ~IDesc() = default;

        /// Builds an iterator using only the engine.
        virtual mkldnn::primitive_desc_iterator createPrimitiveDescriptorIterator(const mkldnn::engine &engine) const = 0;

        /// Builds an iterator using the given primitive attributes and engine.
        virtual mkldnn::primitive_desc_iterator createPrimitiveDescriptorIterator(const mkldnn::primitive_attr &attr,
                                                                                  const mkldnn::engine &engine) const = 0;

        /// Heap-allocates an iterator with `new`; ownership passes to the caller.
        virtual mkldnn::primitive_desc_iterator *createPrimitiveDescriptorIteratorPtr(const mkldnn::engine &engine) const = 0;
    };

    /**
     * Wraps a descriptor of type T that needs no hint: every iterator is
     * built from `*desc` (plus the engine and, optionally, attributes).
     */
    template <class T>
    class DescFwdImpl final : public IDesc {
        std::shared_ptr<T> desc;
    public:
        // Sink parameter: moving avoids an extra atomic refcount inc/dec.
        explicit DescFwdImpl(std::shared_ptr<T> d) : desc(std::move(d)) {}

        mkldnn::primitive_desc_iterator createPrimitiveDescriptorIterator(const mkldnn::engine &engine) const override {
            return mkldnn::primitive_desc_iterator(*desc, engine);
        }

        mkldnn::primitive_desc_iterator createPrimitiveDescriptorIterator(const mkldnn::primitive_attr &attr,
                                                                          const mkldnn::engine &engine) const override {
            return mkldnn::primitive_desc_iterator(*desc, attr, engine);
        }

        mkldnn::primitive_desc_iterator *createPrimitiveDescriptorIteratorPtr(const mkldnn::engine &engine) const override {
            return new mkldnn::primitive_desc_iterator(*desc, engine);
        }

        /// Mutable access to the wrapped descriptor pointer.
        std::shared_ptr<T>& getPtr() {
            return desc;
        }
    };

    /**
     * Wraps a descriptor of type T together with a primitive descriptor of
     * type P. `*prim` is passed as the trailing argument to every
     * primitive_desc_iterator constructor. In MKL-DNN this is the forward
     * primitive descriptor hint that backward primitives require.
     */
    template <class T, class P>
    class DescBwdImpl final : public IDesc {
        std::shared_ptr<T> desc;
        std::shared_ptr<P> prim;

    public:
        // Sink parameters: move both shared_ptrs into the members.
        DescBwdImpl(std::shared_ptr<T> d, std::shared_ptr<P> p) : desc(std::move(d)), prim(std::move(p)) {}

        mkldnn::primitive_desc_iterator createPrimitiveDescriptorIterator(const mkldnn::engine &engine) const override {
            return mkldnn::primitive_desc_iterator(*desc, engine, *prim);
        }

        mkldnn::primitive_desc_iterator createPrimitiveDescriptorIterator(const mkldnn::primitive_attr &attr,
                                                                          const mkldnn::engine &engine) const override {
            return mkldnn::primitive_desc_iterator(*desc, attr, engine, *prim);
        }

        mkldnn::primitive_desc_iterator *createPrimitiveDescriptorIteratorPtr(const mkldnn::engine &engine) const override {
            return new mkldnn::primitive_desc_iterator(*desc, engine, *prim);
        }

        /// Mutable access to the wrapped op descriptor pointer.
        std::shared_ptr<T>& getPtr() {
            return desc;
        }

        /// Mutable access to the stored primitive descriptor (hint) pointer.
        std::shared_ptr<P>& getPrimPtr() {
            return prim;
        }
    };

    /// The wrapped descriptor; null for a default-constructed MKLDNNDescriptor.
    std::shared_ptr<IDesc> desc;
};