summaryrefslogtreecommitdiff
path: root/inference-engine/thirdparty/mkl-dnn/src/common/convolution_pd.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'inference-engine/thirdparty/mkl-dnn/src/common/convolution_pd.hpp')
-rw-r--r--inference-engine/thirdparty/mkl-dnn/src/common/convolution_pd.hpp21
1 files changed, 20 insertions, 1 deletions
diff --git a/inference-engine/thirdparty/mkl-dnn/src/common/convolution_pd.hpp b/inference-engine/thirdparty/mkl-dnn/src/common/convolution_pd.hpp
index b89dd4052..c6e45b9f8 100644
--- a/inference-engine/thirdparty/mkl-dnn/src/common/convolution_pd.hpp
+++ b/inference-engine/thirdparty/mkl-dnn/src/common/convolution_pd.hpp
@@ -131,6 +131,12 @@ struct _convolution_fwd_pd_t: public primitive_desc_t {
inline int ndims() const { return cdesc_().src_desc.ndims; }
+ bool has_zero_dim_memory() const {
+ return false
+ || memory_desc_wrapper(cdesc_().src_desc).has_zero_dim()
+ || memory_desc_wrapper(cdesc_().dst_desc).has_zero_dim();
+ }
+
protected:
base_desc_t desc_;
const _convolution_fwd_pd_t *hint_fwd_pd_;
@@ -184,7 +190,7 @@ struct convolution_bwd_data_pd_t: public primitive_desc_t {
virtual const memory_pd_t *output_pd(int index = 0) const override
{ return index == 0 ? diff_src_pd() : nullptr; }
- virtual int n_inputs() const override { return 2; }
+ virtual int n_inputs() const override { return 2 + with_bias(); }
virtual int n_outputs() const override { return 1; }
virtual status_t query(query_t what, int idx, void *result) const override
@@ -243,6 +249,13 @@ struct convolution_bwd_data_pd_t: public primitive_desc_t {
{ return desc_.weights_desc.ndims == desc_.diff_src_desc.ndims + 1; }
inline int ndims() const { return desc_.diff_src_desc.ndims; }
+ virtual bool support_bias() const { return false; }
+
+ bool has_zero_dim_memory() const {
+ return false
+ || memory_desc_wrapper(desc_.diff_src_desc).has_zero_dim()
+ || memory_desc_wrapper(desc_.diff_dst_desc).has_zero_dim();
+ }
protected:
convolution_desc_t desc_;
@@ -346,6 +359,12 @@ struct convolution_bwd_weights_pd_t: public primitive_desc_t {
inline int ndims() const { return desc_.src_desc.ndims; }
+ bool has_zero_dim_memory() const {
+ return false
+ || memory_desc_wrapper(desc_.src_desc).has_zero_dim()
+ || memory_desc_wrapper(desc_.diff_dst_desc).has_zero_dim();
+ }
+
protected:
convolution_desc_t desc_;
const convolution_fwd_pd_t *hint_fwd_pd_;