summaryrefslogtreecommitdiff
path: root/inference-engine/thirdparty/mkl-dnn/src/cpu/jit_sse42_convolution.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'inference-engine/thirdparty/mkl-dnn/src/cpu/jit_sse42_convolution.hpp')
-rw-r--r--inference-engine/thirdparty/mkl-dnn/src/cpu/jit_sse42_convolution.hpp16
1 files changed, 10 insertions, 6 deletions
diff --git a/inference-engine/thirdparty/mkl-dnn/src/cpu/jit_sse42_convolution.hpp b/inference-engine/thirdparty/mkl-dnn/src/cpu/jit_sse42_convolution.hpp
index 8aca31d2f..192349588 100644
--- a/inference-engine/thirdparty/mkl-dnn/src/cpu/jit_sse42_convolution.hpp
+++ b/inference-engine/thirdparty/mkl-dnn/src/cpu/jit_sse42_convolution.hpp
@@ -56,7 +56,7 @@ struct _jit_sse42_convolution_fwd_t: public cpu_primitive_t {
this->cdesc_().src_desc.data_type,
this->cdesc_().weights_desc.data_type,
this->cdesc_().dst_desc.data_type)
- && utils::implication(this->with_bias(),
+ && IMPLICATION(this->with_bias(),
data_type::f32 == this->cdesc_().bias_desc.data_type);
if (!ok) return status::unimplemented;
@@ -91,13 +91,18 @@ struct _jit_sse42_convolution_fwd_t: public cpu_primitive_t {
const bool flat = this->IC() == 3 || this->IC() == 1;
if (this->src_pd_.desc()->format == any)
- CHECK(this->src_pd_.set_format(flat ? nchw : nChw8c));
+ CHECK(this->src_pd_.set_format(flat
+ ? utils::pick(this->ndims() - 3, ncw, nchw)
+ : utils::pick(this->ndims() - 3, nCw8c, nChw8c)));
if (this->dst_pd_.desc()->format == any)
- CHECK(this->dst_pd_.set_format(nChw8c));
+ CHECK(this->dst_pd_.set_format(utils::pick(this->ndims() - 3,
+ nCw8c, nChw8c)));
if (this->weights_pd_.desc()->format == any)
CHECK(this->weights_pd_.set_format(this->with_groups()
- ? (flat ? gOhwi8o : gOIhw8i8o)
- : (flat ? Ohwi8o : OIhw8i8o)));
+ ? utils::pick(2 * this->ndims() - 6 + flat, gOIw8i8o,
+ gOwi8o, gOIhw8i8o, gOhwi8o)
+ : utils::pick(2 * this->ndims() - 6 + flat, OIw8i8o, Owi8o,
+ OIhw8i8o, Ohwi8o)));
if (this->bias_pd_.desc()->format == any)
CHECK(this->bias_pd_.set_format(x));
return status::success;
@@ -108,7 +113,6 @@ struct _jit_sse42_convolution_fwd_t: public cpu_primitive_t {
const output_vector &outputs)
: cpu_primitive_t(&conf_, inputs, outputs), conf_(*pd),
dw_conv_buffer_size_(0), dw_conv_buffer_(nullptr), padded_bias_(nullptr), dw_padded_bias_(nullptr)
-
{
kernel_ = new jit_sse42_conv_fwd_kernel_f32(conf_.jcp_, *conf_.attr());
if (conf_.jcp_.with_dw_conv) {