1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
|
// Copyright (C) 2018 Intel Corporation
//
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <memory>
#include <string>
#include <utility>

#include <mkldnn.hpp>
#include <mkldnn/desc_iterator.hpp>
/// Type-erased holder for one MKL-DNN operation descriptor.
///
/// Each supported descriptor type has a constructor that wraps it and a matching
/// conversion operator that gives it back. The wrapped object lives behind the
/// private IDesc interface. Callers can therefore create primitive-descriptor
/// iterators without knowing which concrete operation type is stored.
class MKLDNNDescriptor {
public:
    /// Constructs an empty holder (the internal pointer is null).
    MKLDNNDescriptor() = default;

    // Wrap / unwrap pairs, one per supported operation descriptor. The conversion
    // operators are defined out of line. They presumably yield an empty pointer when
    // the wrapped descriptor has a different type (TODO confirm in the .cpp).
    explicit MKLDNNDescriptor(std::shared_ptr<mkldnn::batch_normalization_forward::desc> desc);
    operator std::shared_ptr<mkldnn::batch_normalization_forward::desc>();

    explicit MKLDNNDescriptor(std::shared_ptr<mkldnn::convolution_forward::desc> desc);
    operator std::shared_ptr<mkldnn::convolution_forward::desc>();

    explicit MKLDNNDescriptor(std::shared_ptr<mkldnn::convolution_relu_forward::desc> desc);
    operator std::shared_ptr<mkldnn::convolution_relu_forward::desc>();

    /// Wraps a backward-data convolution descriptor together with the forward
    /// convolution primitive descriptor it is paired with.
    MKLDNNDescriptor(std::shared_ptr<mkldnn::convolution_backward_data::desc> desc,
                     std::shared_ptr<mkldnn::convolution_forward::primitive_desc> prim);
    operator std::shared_ptr<mkldnn::convolution_backward_data::desc>();
    operator std::shared_ptr<mkldnn::convolution_forward::primitive_desc>();

    explicit MKLDNNDescriptor(std::shared_ptr<mkldnn::inner_product_forward::desc> desc);
    operator std::shared_ptr<mkldnn::inner_product_forward::desc>();

    explicit MKLDNNDescriptor(std::shared_ptr<mkldnn::lrn_forward::desc> desc);
    operator std::shared_ptr<mkldnn::lrn_forward::desc>();

    explicit MKLDNNDescriptor(std::shared_ptr<mkldnn::pooling_forward::desc> desc);
    operator std::shared_ptr<mkldnn::pooling_forward::desc>();

    explicit MKLDNNDescriptor(std::shared_ptr<mkldnn::relu_forward::desc> desc);
    operator std::shared_ptr<mkldnn::relu_forward::desc>();

    explicit MKLDNNDescriptor(std::shared_ptr<mkldnn::roi_pooling_forward::desc> desc);
    operator std::shared_ptr<mkldnn::roi_pooling_forward::desc>();

    explicit MKLDNNDescriptor(std::shared_ptr<mkldnn::softmax_forward::desc> desc);
    operator std::shared_ptr<mkldnn::softmax_forward::desc>();

    explicit MKLDNNDescriptor(std::shared_ptr<mkldnn::depthwise_forward::desc> desc);
    operator std::shared_ptr<mkldnn::depthwise_forward::desc>();

    /// Creates an iterator over the primitive descriptors available for the
    /// wrapped operation.
    /// @param engine engine the primitives would run on.
    /// @param attr   primitive attributes; defaults to a default-constructed
    ///               mkldnn::primitive_attr.
    mkldnn::primitive_desc_iterator createPrimitiveDescriptorIterator(const mkldnn::engine &engine,
            const mkldnn::primitive_attr &attr = mkldnn::primitive_attr()) const;

    /// Heap-allocating variant of createPrimitiveDescriptorIterator (no attributes).
    /// The private implementations below allocate the iterator with `new`. The
    /// caller presumably owns the result and must delete it (TODO confirm against
    /// the out-of-line definition).
    mkldnn::primitive_desc_iterator * createPrimitiveDescriptorIteratorPtr(const mkldnn::engine &engine) const;

    // Input/output counts of the wrapped operation (defined out of line; semantics
    // to be confirmed in the .cpp).
    size_t outputNumbers() const;
    size_t inputNumbers() const;

    // Presumably true when a descriptor is wrapped (TODO confirm in the .cpp).
    // NOTE(review): non-const and implicit. Changing it to
    // `explicit operator bool() const` would also require updating the out-of-line
    // definition and any callers that rely on the implicit conversion.
    operator bool();

private:
    /// Internal polymorphic interface over the wrapped descriptor.
    class IDesc {
    public:
        virtual ~IDesc() = default;
        virtual mkldnn::primitive_desc_iterator createPrimitiveDescriptorIterator(const mkldnn::engine &engine) const = 0;
        virtual mkldnn::primitive_desc_iterator createPrimitiveDescriptorIterator(const mkldnn::primitive_attr &attr,
                const mkldnn::engine &engine) const = 0;
        /// Returns a `new`-allocated iterator; the caller owns it.
        virtual mkldnn::primitive_desc_iterator *createPrimitiveDescriptorIteratorPtr(const mkldnn::engine &engine) const = 0;
    };

    /// Wraps a descriptor whose iterators are built from the descriptor alone.
    template <class T>
    class DescFwdImpl: public IDesc {
        std::shared_ptr<T> desc;
    public:
        // Take by value and move: callers passing an rvalue pay no refcount traffic,
        // and lvalue callers pay exactly one copy.
        explicit DescFwdImpl(std::shared_ptr<T> d) : desc(std::move(d)) {}

        mkldnn::primitive_desc_iterator createPrimitiveDescriptorIterator(const mkldnn::engine &engine) const override {
            return mkldnn::primitive_desc_iterator(*desc, engine);
        }

        mkldnn::primitive_desc_iterator createPrimitiveDescriptorIterator(const mkldnn::primitive_attr &attr,
                const mkldnn::engine &engine) const override {
            return mkldnn::primitive_desc_iterator(*desc, attr, engine);
        }

        mkldnn::primitive_desc_iterator *createPrimitiveDescriptorIteratorPtr(const mkldnn::engine &engine) const override {
            return new mkldnn::primitive_desc_iterator(*desc, engine);
        }

        /// Mutable access to the stored descriptor pointer.
        std::shared_ptr<T>& getPtr() {
            return desc;
        }
    };

    /// Wraps a descriptor together with a primitive descriptor `P`. `P` is passed as
    /// an extra argument whenever an iterator is created (presumably the forward
    /// hint MKL-DNN needs for backward primitives; TODO confirm).
    template <class T, class P>
    class DescBwdImpl: public IDesc {
        std::shared_ptr<T> desc;
        std::shared_ptr<P> prim;
    public:
        DescBwdImpl(std::shared_ptr<T> d, std::shared_ptr<P> p) : desc(std::move(d)), prim(std::move(p)) {}

        mkldnn::primitive_desc_iterator createPrimitiveDescriptorIterator(const mkldnn::engine &engine) const override {
            return mkldnn::primitive_desc_iterator(*desc, engine, *prim);
        }

        mkldnn::primitive_desc_iterator createPrimitiveDescriptorIterator(const mkldnn::primitive_attr &attr,
                const mkldnn::engine &engine) const override {
            return mkldnn::primitive_desc_iterator(*desc, attr, engine, *prim);
        }

        mkldnn::primitive_desc_iterator *createPrimitiveDescriptorIteratorPtr(const mkldnn::engine &engine) const override {
            return new mkldnn::primitive_desc_iterator(*desc, engine, *prim);
        }

        /// Mutable access to the stored descriptor pointer.
        std::shared_ptr<T>& getPtr() {
            return desc;
        }

        /// Mutable access to the stored primitive-descriptor pointer.
        std::shared_ptr<P>& getPrimPtr() {
            return prim;
        }
    };

    /// The wrapped descriptor; null for a default-constructed MKLDNNDescriptor.
    std::shared_ptr<IDesc> desc;
};
|