diff options
Diffstat (limited to 'inference-engine/thirdparty/mkl-dnn/src/common/convolution_pd.hpp')
-rw-r--r-- | inference-engine/thirdparty/mkl-dnn/src/common/convolution_pd.hpp | 21 |
1 files changed, 20 insertions, 1 deletions
diff --git a/inference-engine/thirdparty/mkl-dnn/src/common/convolution_pd.hpp b/inference-engine/thirdparty/mkl-dnn/src/common/convolution_pd.hpp index b89dd4052..c6e45b9f8 100644 --- a/inference-engine/thirdparty/mkl-dnn/src/common/convolution_pd.hpp +++ b/inference-engine/thirdparty/mkl-dnn/src/common/convolution_pd.hpp @@ -131,6 +131,12 @@ struct _convolution_fwd_pd_t: public primitive_desc_t { inline int ndims() const { return cdesc_().src_desc.ndims; } + bool has_zero_dim_memory() const { + return false + || memory_desc_wrapper(cdesc_().src_desc).has_zero_dim() + || memory_desc_wrapper(cdesc_().dst_desc).has_zero_dim(); + } + protected: base_desc_t desc_; const _convolution_fwd_pd_t *hint_fwd_pd_; @@ -184,7 +190,7 @@ struct convolution_bwd_data_pd_t: public primitive_desc_t { virtual const memory_pd_t *output_pd(int index = 0) const override { return index == 0 ? diff_src_pd() : nullptr; } - virtual int n_inputs() const override { return 2; } + virtual int n_inputs() const override { return 2 + with_bias(); } virtual int n_outputs() const override { return 1; } virtual status_t query(query_t what, int idx, void *result) const override @@ -243,6 +249,13 @@ struct convolution_bwd_data_pd_t: public primitive_desc_t { { return desc_.weights_desc.ndims == desc_.diff_src_desc.ndims + 1; } inline int ndims() const { return desc_.diff_src_desc.ndims; } + virtual bool support_bias() const { return false; } + + bool has_zero_dim_memory() const { + return false + || memory_desc_wrapper(desc_.diff_src_desc).has_zero_dim() + || memory_desc_wrapper(desc_.diff_dst_desc).has_zero_dim(); + } protected: convolution_desc_t desc_; @@ -346,6 +359,12 @@ struct convolution_bwd_weights_pd_t: public primitive_desc_t { inline int ndims() const { return desc_.src_desc.ndims; } + bool has_zero_dim_memory() const { + return false + || memory_desc_wrapper(desc_.src_desc).has_zero_dim() + || memory_desc_wrapper(desc_.diff_dst_desc).has_zero_dim(); + } + protected: convolution_desc_t desc_; const convolution_fwd_pd_t *hint_fwd_pd_; |