diff options
Diffstat (limited to 'inference-engine/thirdparty/mkl-dnn/src/cpu/jit_sse42_convolution.hpp')
-rw-r--r-- | inference-engine/thirdparty/mkl-dnn/src/cpu/jit_sse42_convolution.hpp | 16 |
1 files changed, 10 insertions, 6 deletions
diff --git a/inference-engine/thirdparty/mkl-dnn/src/cpu/jit_sse42_convolution.hpp b/inference-engine/thirdparty/mkl-dnn/src/cpu/jit_sse42_convolution.hpp index 8aca31d2f..192349588 100644 --- a/inference-engine/thirdparty/mkl-dnn/src/cpu/jit_sse42_convolution.hpp +++ b/inference-engine/thirdparty/mkl-dnn/src/cpu/jit_sse42_convolution.hpp @@ -56,7 +56,7 @@ struct _jit_sse42_convolution_fwd_t: public cpu_primitive_t { this->cdesc_().src_desc.data_type, this->cdesc_().weights_desc.data_type, this->cdesc_().dst_desc.data_type) - && utils::implication(this->with_bias(), + && IMPLICATION(this->with_bias(), data_type::f32 == this->cdesc_().bias_desc.data_type); if (!ok) return status::unimplemented; @@ -91,13 +91,18 @@ struct _jit_sse42_convolution_fwd_t: public cpu_primitive_t { const bool flat = this->IC() == 3 || this->IC() == 1; if (this->src_pd_.desc()->format == any) - CHECK(this->src_pd_.set_format(flat ? nchw : nChw8c)); + CHECK(this->src_pd_.set_format(flat + ? utils::pick(this->ndims() - 3, ncw, nchw) + : utils::pick(this->ndims() - 3, nCw8c, nChw8c))); if (this->dst_pd_.desc()->format == any) - CHECK(this->dst_pd_.set_format(nChw8c)); + CHECK(this->dst_pd_.set_format(utils::pick(this->ndims() - 3, + nCw8c, nChw8c))); if (this->weights_pd_.desc()->format == any) CHECK(this->weights_pd_.set_format(this->with_groups() - ? (flat ? gOhwi8o : gOIhw8i8o) - : (flat ? Ohwi8o : OIhw8i8o))); + ? utils::pick(2 * this->ndims() - 6 + flat, gOIw8i8o, + gOwi8o, gOIhw8i8o, gOhwi8o) + : utils::pick(2 * this->ndims() - 6 + flat, OIw8i8o, Owi8o, + OIhw8i8o, Ohwi8o))); if (this->bias_pd_.desc()->format == any) CHECK(this->bias_pd_.set_format(x)); return status::success; @@ -108,7 +113,6 @@ struct _jit_sse42_convolution_fwd_t: public cpu_primitive_t { const output_vector &outputs) : cpu_primitive_t(&conf_, inputs, outputs), conf_(*pd), dw_conv_buffer_size_(0), dw_conv_buffer_(nullptr), padded_bias_(nullptr), dw_padded_bias_(nullptr) - { kernel_ = new jit_sse42_conv_fwd_kernel_f32(conf_.jcp_, *conf_.attr()); if (conf_.jcp_.with_dw_conv) { |