author    Ma Mingfei <mingfei.ma@intel.com>  2018-03-30 06:25:07 +0800
committer Soumith Chintala <soumith@gmail.com>  2018-03-29 15:25:07 -0700
commit    f8270c0225e19403038aec2d8af2697a2b5326ec (patch)
tree      7dbc688871d3943a2ec3f2d896e8d4e0637d990b /aten
parent    e4c0bb1809fd9bf9161392bfff7d06092adc224d (diff)
Enable MKLDNN convolution forward and backward (#6062)
* Enable MKLDNN convolution forward and backward
* minor change
* fix mkldnn build error when building ATen standalone
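A usage sketch (not part of the patch): with ATen configured so that AT_MKLDNN_ENABLED() is true, the new native op can be driven directly from C++ on CPU float tensors in NCHW layout. The tensor-factory calls below assume the ATen API of this era; sizes are illustrative.

  #include <ATen/ATen.h>

  int main() {
    // CPU float tensors in NCHW layout -- the only case use_mkldnn() accepts.
    auto input  = at::CPU(at::kFloat).randn({2, 3, 224, 224});
    auto weight = at::CPU(at::kFloat).randn({64, 3, 7, 7});
    auto bias   = at::CPU(at::kFloat).randn({64});
    // padding 3, stride 2, dilation 1 -> output sized {2, 64, 112, 112}.
    auto output = at::mkldnn_convolution(input, weight, bias,
                                         {3, 3}, {2, 2}, {1, 1});
    return 0;
  }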
Diffstat (limited to 'aten')
 aten/CMakeLists.txt                         |  14
 aten/cmake/FindMKLDNN.cmake (new)           |  32
 aten/src/ATen/CMakeLists.txt                |  12
 aten/src/ATen/Config.h.in                   |   1
 aten/src/ATen/mkldnn/Runtime.cpp (new)      |   5
 aten/src/ATen/mkldnn/Runtime.h (new)        |  49
 aten/src/ATen/native/Convolution.cpp        |  27
 aten/src/ATen/native/mkldnn/Conv.cpp (new)  | 441
 aten/src/ATen/native/native_functions.yaml  |  12
9 files changed, 591 insertions(+), 2 deletions(-)
diff --git a/aten/CMakeLists.txt b/aten/CMakeLists.txt
index 16e7cc6790..13e4abca94 100644
--- a/aten/CMakeLists.txt
+++ b/aten/CMakeLists.txt
@@ -460,6 +460,20 @@ ELSE()
set(AT_CUDNN_ENABLED 1)
ENDIF()
+if(NO_MKLDNN)
+ message("disabling MKLDNN because NO_MKLDNN is set")
+ set(AT_MKLDNN_ENABLED 0)
+else()
+ find_package(MKLDNN)
+ if(NOT MKLDNN_FOUND)
+ message(STATUS "MKLDNN not found. Compiling without MKLDNN support")
+ set(AT_MKLDNN_ENABLED 0)
+ else()
+    include_directories(${MKLDNN_INCLUDE_DIRS})
+ set(AT_MKLDNN_ENABLED 1)
+ endif()
+endif()
+
if(NO_NNPACK)
message("disabling NNPACK because NO_NNPACK is set")
set(AT_NNPACK_ENABLED 0)
diff --git a/aten/cmake/FindMKLDNN.cmake b/aten/cmake/FindMKLDNN.cmake
new file mode 100644
index 0000000000..0862d5a3ac
--- /dev/null
+++ b/aten/cmake/FindMKLDNN.cmake
@@ -0,0 +1,32 @@
+# - Try to find MKLDNN
+#
+# The following variables are optionally searched for defaults:
+# MKLDNN_ROOT_DIR: Base directory where all MKLDNN components are found
+# MKLDNN_LIB_DIR: Directory searched for the MKLDNN library
+#
+# The following are set after configuration is done:
+# MKLDNN_FOUND
+# MKLDNN_INCLUDE_DIRS
+# MKLDNN_LIBRARIES
+
+include(FindPackageHandleStandardArgs)
+
+set(MKLDNN_ROOT_DIR "" CACHE PATH "Folder containing Intel MKLDNN")
+
+find_path(MKLDNN_INCLUDE_DIR mkldnn.h
+ HINTS ${MKLDNN_ROOT_DIR}
+ PATH_SUFFIXES include)
+
+find_library(MKLDNN_LIBRARY mkldnn
+ HINTS ${MKLDNN_LIB_DIR} ${MKLDNN_ROOT_DIR}
+ PATH_SUFFIXES lib lib64)
+
+find_package_handle_standard_args(
+ MKLDNN DEFAULT_MSG MKLDNN_INCLUDE_DIR MKLDNN_LIBRARY)
+
+if(MKLDNN_FOUND)
+ set(MKLDNN_INCLUDE_DIRS ${MKLDNN_INCLUDE_DIR})
+ set(MKLDNN_LIBRARIES ${MKLDNN_LIBRARY})
+ message(STATUS "Found MKLDNN (include: ${MKLDNN_INCLUDE_DIR}, library: ${MKLDNN_LIBRARY})")
+ mark_as_advanced(MKLDNN_ROOT_DIR MKLDNN_LIBRARY MKLDNN_INCLUDE_DIR)
+endif()
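A hypothetical consumer sketch for this find module (the my_target name and the /opt/intel/mkldnn path are illustrative): point MKLDNN_ROOT_DIR at an install tree containing include/mkldnn.h and lib/libmkldnn, then use the variables it sets.

  # Make FindMKLDNN.cmake visible, locate the library, and link it.
  list(APPEND CMAKE_MODULE_PATH "${CMAKE_SOURCE_DIR}/aten/cmake")
  set(MKLDNN_ROOT_DIR "/opt/intel/mkldnn" CACHE PATH "Folder containing Intel MKLDNN")
  find_package(MKLDNN)
  if(MKLDNN_FOUND)
    include_directories(${MKLDNN_INCLUDE_DIRS})
    target_link_libraries(my_target ${MKLDNN_LIBRARIES})
  endif()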
diff --git a/aten/src/ATen/CMakeLists.txt b/aten/src/ATen/CMakeLists.txt
index daf986c6bd..27f1d98461 100644
--- a/aten/src/ATen/CMakeLists.txt
+++ b/aten/src/ATen/CMakeLists.txt
@@ -149,6 +149,7 @@ FILE(GLOB native_cpp RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "native/*.cpp")
FILE(GLOB native_cudnn_cpp RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "native/cudnn/*.cpp")
FILE(GLOB native_cuda_cu RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "native/cuda/*.cu")
FILE(GLOB native_mkl_cpp RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "native/mkl/*.cpp")
+FILE(GLOB native_mkldnn_cpp RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "native/mkldnn/*.cpp")
FILE(GLOB_RECURSE cuda_h
RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}"
@@ -156,6 +157,7 @@ FILE(GLOB_RECURSE cuda_h
FILE(GLOB cudnn_cpp RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "cudnn/*.cpp")
FILE(GLOB mkl_cpp RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "mkl/*.cpp")
+FILE(GLOB mkldnn_cpp RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "mkldnn/*.cpp")
FILE(GLOB all_python RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.py")
@@ -201,7 +203,7 @@ ADD_CUSTOM_TARGET(aten_files_are_generated
)
-SET(all_cpp ${base_cpp} ${native_cpp} ${native_cudnn_cpp} ${native_mkl_cpp} ${generated_cpp} ${ATen_CPU_SRCS} ${cpu_kernel_cpp})
+SET(all_cpp ${base_cpp} ${native_cpp} ${native_cudnn_cpp} ${native_mkl_cpp} ${native_mkldnn_cpp} ${generated_cpp} ${ATen_CPU_SRCS} ${cpu_kernel_cpp})
INCLUDE_DIRECTORIES(${ATen_CPU_INCLUDE})
IF(NOT NO_CUDA)
@@ -218,6 +220,10 @@ IF(NOT NO_CUDA)
ENDIF()
endif()
+IF(AT_MKLDNN_ENABLED)
+ SET(all_cpp ${all_cpp} ${mkldnn_cpp})
+ENDIF()
+
filter_list(generated_h generated_cpp "\\.h$")
INCLUDE_DIRECTORIES(${CMAKE_CURRENT_SOURCE_DIR}/..)
@@ -315,6 +321,9 @@ if (NNPACK_FOUND)
target_link_libraries(ATen ${NNPACK_LIBRARIES})
endif(NNPACK_FOUND)
+if(MKLDNN_FOUND)
+ target_link_libraries(ATen ${MKLDNN_LIBRARIES})
+endif(MKLDNN_FOUND)
# ---[ Configure cpuinfo
IF(NOT TARGET cpuinfo)
@@ -326,7 +335,6 @@ IF(NOT TARGET cpuinfo)
ENDIF()
TARGET_LINK_LIBRARIES(ATen cpuinfo)
-
IF(CUDA_FOUND)
TARGET_LINK_LIBRARIES(ATen
${CUDA_LIBRARIES}
diff --git a/aten/src/ATen/Config.h.in b/aten/src/ATen/Config.h.in
index d62468fd70..1ab0ec9162 100644
--- a/aten/src/ATen/Config.h.in
+++ b/aten/src/ATen/Config.h.in
@@ -6,6 +6,7 @@
#define AT_CUDA_ENABLED() @AT_CUDA_ENABLED@
#define AT_CUDNN_ENABLED() @AT_CUDNN_ENABLED@
+#define AT_MKLDNN_ENABLED() @AT_MKLDNN_ENABLED@
#define AT_NNPACK_ENABLED() @AT_NNPACK_ENABLED@
#define AT_MKL_ENABLED() @AT_MKL_ENABLED@
diff --git a/aten/src/ATen/mkldnn/Runtime.cpp b/aten/src/ATen/mkldnn/Runtime.cpp
new file mode 100644
index 0000000000..54f999ed14
--- /dev/null
+++ b/aten/src/ATen/mkldnn/Runtime.cpp
@@ -0,0 +1,5 @@
+#include "Runtime.h"
+
+namespace at { namespace native {
+
+}} // namespace at::native
diff --git a/aten/src/ATen/mkldnn/Runtime.h b/aten/src/ATen/mkldnn/Runtime.h
new file mode 100644
index 0000000000..c58ef2c56f
--- /dev/null
+++ b/aten/src/ATen/mkldnn/Runtime.h
@@ -0,0 +1,49 @@
+#pragma once
+
+#include <mkldnn.hpp>
+
+using namespace mkldnn;
+
+namespace at { namespace native {
+
+// CpuEngine singleton
+struct CpuEngine {
+ static CpuEngine& Instance() {
+ static CpuEngine myInstance;
+ return myInstance;
+ }
+ engine& get_engine() {
+ return _cpu_engine;
+ }
+ CpuEngine(CpuEngine const&) = delete;
+ CpuEngine& operator=(CpuEngine const&) = delete;
+
+protected:
+ CpuEngine():_cpu_engine(mkldnn::engine::cpu, 0) {}
+ ~CpuEngine() {}
+
+private:
+ engine _cpu_engine;
+};
+
+// Stream singleton
+struct Stream {
+ static Stream& Instance() {
+ static Stream myInstance;
+ return myInstance;
+  }
+ stream& get_stream() {
+ return _cpu_stream;
+ }
+ Stream(Stream const&) = delete;
+ Stream& operator=(Stream const&) = delete;
+
+protected:
+ Stream():_cpu_stream(mkldnn::stream::kind::eager) {}
+ ~Stream() {}
+
+private:
+ stream _cpu_stream;
+};
+
+}} // namespace at::native
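Both singletons exist so that every primitive in the process shares one CPU engine and one eager stream, which is how Conv.cpp below submits its primitive nets. A minimal sketch against the MKL-DNN 0.x API used here (buffer sizes are illustrative):

  #include <ATen/mkldnn/Runtime.h>
  #include <vector>

  // Reorder an NCHW float buffer into MKL-DNN's blocked nChw8c layout.
  void reorder_nchw_to_blocked(float* src, float* dst) {
    auto& eng = at::native::CpuEngine::Instance().get_engine();
    mkldnn::memory::dims tz = {1, 8, 4, 4};
    // Wrap the raw buffers as MKL-DNN memories with explicit formats.
    auto src_mem = mkldnn::memory(
        {{{tz}, mkldnn::memory::data_type::f32, mkldnn::memory::format::nchw}, eng}, src);
    auto dst_mem = mkldnn::memory(
        {{{tz}, mkldnn::memory::data_type::f32, mkldnn::memory::format::nChw8c}, eng}, dst);
    // Queue the primitive and run it on the shared eager stream.
    std::vector<mkldnn::primitive> net;
    net.push_back(mkldnn::reorder(src_mem, dst_mem));
    at::native::Stream::Instance().get_stream().submit(net);
  }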
diff --git a/aten/src/ATen/native/Convolution.cpp b/aten/src/ATen/native/Convolution.cpp
index b65759249a..33f7e05d26 100644
--- a/aten/src/ATen/native/Convolution.cpp
+++ b/aten/src/ATen/native/Convolution.cpp
@@ -32,6 +32,7 @@ struct ConvParams {
bool is_padding_neg() const;
void view1d_as_2d();
bool use_cudnn(const at::Tensor& input) const;
+ bool use_mkldnn(const at::Tensor& input) const;
bool use_nnpack(const at::Tensor& input) const;
bool is_depthwise(const at::Tensor& input, const at::Tensor& weight) const;
};
@@ -130,6 +131,17 @@ auto ConvParams::use_cudnn(const at::Tensor& input) const -> bool {
return false;
}
+auto ConvParams::use_mkldnn(const at::Tensor& input) const -> bool {
+#if AT_MKLDNN_ENABLED()
+ return input.type().backend() == kCPU &&
+ input.type().scalarType() == kFloat && // only on CPU Float Tensors
+ !is_dilated() && // doesn't support dilation
+ !transposed && // or transposed tensors
+ input.ndimension() == 4 && // must be in NCHW format
+ groups == 1;
+#endif
+ return false;
+}
auto ConvParams::use_nnpack(const at::Tensor& input) const -> bool {
#if AT_NNPACK_ENABLED()
return input.type().backend() == kCPU &&
@@ -371,6 +383,21 @@ at::Tensor _convolution(
params.padding, params.stride, params.dilation, params.groups, params.benchmark, params.deterministic);
}
#endif
+ } else if (params.use_mkldnn(input)) {
+#if AT_MKLDNN_ENABLED()
+ if (input.type() != weight.type()){
+ std::stringstream ss;
+ ss << "Input type (" << input.toString() << ") and weight type (" << weight.toString() << ") should be the same";
+ throw std::runtime_error(ss.str());
+ }
+ if (bias.defined() && input.type() != bias.type()){
+ std::stringstream ss;
+ ss << "Input type (" << input.toString() << ") and bias type (" << bias.toString() << ") should be the same";
+ throw std::runtime_error(ss.str());
+ }
+
+ output = at::mkldnn_convolution(input, weight, bias, params.padding, params.stride, params.dilation);
+#endif
} else {
if (params.groups == 1) {
output = at::_convolution_nogroup(
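To make the new gating concrete, a sketch through the conv2d wrapper defined in this same file: the first call meets every use_mkldnn() condition, while the second is dilated and falls through to the existing path. Factory calls again assume the ATen API of this era.

  #include <ATen/ATen.h>

  int main() {
    auto input  = at::CPU(at::kFloat).randn({1, 3, 32, 32});
    auto weight = at::CPU(at::kFloat).randn({8, 3, 3, 3});
    auto bias   = at::CPU(at::kFloat).randn({8});
    // 4-D float CPU tensor, unit dilation, groups == 1: MKLDNN-eligible.
    auto fast = at::conv2d(input, weight, bias, {1, 1}, {1, 1}, {1, 1});
    // Dilation 2: use_mkldnn() returns false, so dispatch falls through.
    auto slow = at::conv2d(input, weight, bias, {1, 1}, {2, 2}, {2, 2});
    return 0;
  }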
diff --git a/aten/src/ATen/native/mkldnn/Conv.cpp b/aten/src/ATen/native/mkldnn/Conv.cpp
new file mode 100644
index 0000000000..25cddef9ae
--- /dev/null
+++ b/aten/src/ATen/native/mkldnn/Conv.cpp
@@ -0,0 +1,441 @@
+#include <ATen/ATen.h>
+#include <ATen/NativeFunctions.h>
+#include <ATen/Config.h>
+
+#if !AT_MKLDNN_ENABLED()
+
+namespace at { namespace native {
+
+at::Tensor mkldnn_convolution(
+ const at::Tensor& input, const at::Tensor& weight, const at::Tensor& bias,
+ IntList padding, IntList stride, IntList dilation) {
+ throw std::runtime_error("mkldnn_convolution_forward: ATen not compiled with MKLDNN support");
+}
+
+at::Tensor mkldnn_convolution_backward_input(
+ IntList input_size, const at::Tensor& grad_output, const at::Tensor& weight,
+ IntList padding, IntList stride, IntList dilation, bool bias_defined) {
+ throw std::runtime_error("mkldnn_convolution_backward_input: ATen not compiled with MKLDNN support");
+}
+
+std::tuple<at::Tensor,at::Tensor> mkldnn_convolution_backward_weights(
+ IntList weight_size, const at::Tensor& grad_output, const at::Tensor& input,
+ IntList padding, IntList stride, IntList dilation, bool bias_defined) {
+ throw std::runtime_error("mkldnn_convolution_backward_weights: ATen not compiled with MKLDNN support");
+}
+
+std::tuple<at::Tensor,at::Tensor,at::Tensor> mkldnn_convolution_backward(
+ const at::Tensor& input, const at::Tensor& grad_output_t, const at::Tensor& weight,
+ IntList padding, IntList stride, IntList dilation, std::array<bool,3> output_mask) {
+ throw std::runtime_error("mkldnn_convolution_backward: ATen not compiled with MKLDNN support");
+}
+
+}} // namespace at::native
+
+#else // AT_MKLDNN_ENABLED
+
+#include <ATen/mkldnn/Runtime.h>
+
+using namespace mkldnn;
+
+namespace at { namespace native {
+
+constexpr int input_batch_size_dim = 0; // also grad_input
+constexpr int input_channels_dim = 1;
+constexpr int output_batch_size_dim = 0; // also grad_output
+constexpr int output_channels_dim = 1;
+constexpr int weight_output_channels_dim = 0;
+constexpr int weight_input_channels_dim = 1;
+
+// Often written as 2 + max_dim (extra dims for batch size and channels)
+constexpr int max_dim = 3;
+
+std::vector<int64_t> conv_output_size(
+ IntList input_size, IntList weight_size,
+ IntList padding, IntList stride, IntList dilation)
+{
+ auto dim = input_size.size();
+ std::vector<int64_t> output_size(dim);
+ output_size[0] = input_size[input_batch_size_dim];
+ output_size[1] = weight_size[weight_output_channels_dim];
+ for (size_t d = 2; d < dim; ++d) {
+ auto kernel = dilation[d - 2] * (weight_size[d] - 1) + 1;
+ output_size[d] = (input_size[d] + (2 * padding[d - 2])
+ - kernel) / stride[d - 2] + 1;
+ }
+ return output_size;
+}
+
+at::Tensor mkldnn_convolution(
+ const at::Tensor& input, const at::Tensor& weight, const at::Tensor& bias,
+ IntList padding, IntList stride, IntList dilation)
+{
+ auto output = input.type().tensor(conv_output_size(
+ input.sizes(), weight.sizes(), padding, stride, dilation));
+
+ auto cpu_engine = CpuEngine::Instance().get_engine();
+
+ int32_t n = input.size(0);
+ int32_t ic = input.size(1);
+ int32_t ih = input.size(2);
+ int32_t iw = input.size(3);
+
+ int32_t oc = output.size(1);
+ int32_t oh = output.size(2);
+ int32_t ow = output.size(3);
+
+ int32_t kh = weight.size(2);
+ int32_t kw = weight.size(3);
+
+ int32_t sh = stride[0];
+ int32_t sw = stride[1];
+ int32_t ph = padding[0];
+ int32_t pw = padding[1];
+
+ auto data_t = memory::data_type::f32;
+ auto format_any = memory::format::any;
+ auto format_nchw = memory::format::nchw;
+ auto format_oihw = memory::format::oihw;
+ auto format_x = memory::format::x;
+
+ memory::dims input_tz = {n, ic, ih, iw};
+ memory::dims weight_tz = {oc, ic, kh, kw};
+ memory::dims bias_tz = {oc};
+ memory::dims output_tz = {n, oc, oh, ow};
+ memory::dims _stride = {sh, sw};
+ memory::dims _padding = {ph, pw};
+
+ auto input_md = memory::desc({input_tz}, data_t, format_any);
+ auto weight_md = memory::desc({weight_tz}, data_t, format_any);
+ auto bias_md = memory::desc({bias_tz}, data_t, format_any);
+ auto output_md = memory::desc({output_tz}, data_t, format_any);
+
+ std::shared_ptr<convolution_forward::desc> conv_forward_desc;
+ if (bias.defined()) {
+ conv_forward_desc.reset(new convolution_forward::desc(prop_kind::forward,
+ convolution_direct, input_md, weight_md, bias_md, output_md,
+ _stride, _padding, _padding, padding_kind::zero));
+ } else {
+ conv_forward_desc.reset(new convolution_forward::desc(prop_kind::forward,
+ convolution_direct, input_md, weight_md, output_md,
+ _stride, _padding, _padding, padding_kind::zero));
+ }
+
+ std::shared_ptr<convolution_forward::primitive_desc> conv_forward_pd;
+ conv_forward_pd.reset(new convolution_forward::primitive_desc(
+ *conv_forward_desc, cpu_engine));
+
+ auto input_usr_memory = memory({{{input_tz}, data_t, format_nchw}, cpu_engine},
+ input.data_ptr());
+ auto weight_usr_memory = memory({{{weight_tz}, data_t, format_oihw}, cpu_engine},
+ weight.data_ptr());
+ auto output_usr_memory = memory({{{output_tz}, data_t, format_nchw}, cpu_engine},
+ output.data_ptr());
+
+ std::vector<primitive> net;
+
+ auto input_pd = conv_forward_pd->src_primitive_desc();
+ auto input_memory = input_usr_memory;
+ if (input_usr_memory.get_primitive_desc() != memory::primitive_desc(input_pd)) {
+ input_memory = memory(input_pd);
+ net.push_back(reorder(input_usr_memory, input_memory));
+ }
+
+ auto weight_pd = conv_forward_pd->weights_primitive_desc();
+ auto weight_memory = weight_usr_memory;
+ if (weight_usr_memory.get_primitive_desc() != memory::primitive_desc(weight_pd)) {
+ weight_memory = memory(weight_pd);
+ net.push_back(reorder(weight_usr_memory, weight_memory));
+ }
+
+ auto output_pd = conv_forward_pd->dst_primitive_desc();
+ auto output_memory = output_usr_memory;
+ if (output_usr_memory.get_primitive_desc() != memory::primitive_desc(output_pd)) {
+ output_memory = memory(output_pd);
+ }
+
+ std::shared_ptr<convolution_forward> conv_forward;
+ std::shared_ptr<memory> bias_usr_memory;
+ if (bias.defined()) {
+ bias_usr_memory.reset(new memory({{{bias_tz}, data_t, format_x}, cpu_engine},
+ bias.data_ptr()));
+ conv_forward.reset(new convolution_forward(*conv_forward_pd, input_memory,
+ weight_memory, *bias_usr_memory, output_memory));
+ } else {
+ conv_forward.reset(new convolution_forward(*conv_forward_pd, input_memory,
+ weight_memory, output_memory));
+ }
+ net.push_back(*conv_forward);
+
+ if (output_memory != output_usr_memory) {
+ net.push_back(reorder(output_memory, output_usr_memory));
+ }
+
+ Stream::Instance().get_stream().submit(net);
+
+ return output;
+}
+
+Tensor mkldnn_convolution_backward_input(
+ IntList input_size, const at::Tensor& grad_output, const at::Tensor& weight,
+ IntList padding, IntList stride, IntList dilation, bool bias_defined)
+{
+ auto grad_input = grad_output.type().tensor(input_size);
+
+ auto cpu_engine = CpuEngine::Instance().get_engine();
+
+ int32_t n = grad_input.size(0);
+ int32_t ic = grad_input.size(1);
+ int32_t ih = grad_input.size(2);
+ int32_t iw = grad_input.size(3);
+
+ int32_t oc = grad_output.size(1);
+ int32_t oh = grad_output.size(2);
+ int32_t ow = grad_output.size(3);
+
+ int32_t kh = weight.size(2);
+ int32_t kw = weight.size(3);
+
+ int32_t sh = stride[0];
+ int32_t sw = stride[1];
+ int32_t ph = padding[0];
+ int32_t pw = padding[1];
+
+ auto data_t = memory::data_type::f32;
+ auto format_any = memory::format::any;
+ auto format_nchw = memory::format::nchw;
+ auto format_oihw = memory::format::oihw;
+
+ memory::dims input_tz = {n, ic, ih, iw};
+ memory::dims weight_tz = {oc, ic, kh, kw};
+ memory::dims bias_tz = {oc};
+ memory::dims output_tz = {n, oc, oh, ow};
+ memory::dims _stride = {sh, sw};
+ memory::dims _padding = {ph, pw};
+
+ auto input_md = memory::desc({input_tz}, data_t, format_any);
+ auto weight_md = memory::desc({weight_tz}, data_t, format_any);
+ auto bias_md = memory::desc({bias_tz}, data_t, format_any);
+ auto output_md = memory::desc({output_tz}, data_t, format_any);
+
+ // need to re-create conv_forward_pd to feed conv_backward_data_pd
+ std::shared_ptr<convolution_forward::desc> conv_forward_desc;
+ if (bias_defined) {
+ conv_forward_desc.reset(new convolution_forward::desc(prop_kind::forward,
+ convolution_direct, input_md, weight_md, bias_md, output_md,
+ _stride, _padding, _padding, padding_kind::zero));
+ } else {
+ conv_forward_desc.reset(new convolution_forward::desc(prop_kind::forward,
+ convolution_direct, input_md, weight_md, output_md,
+ _stride, _padding, _padding, padding_kind::zero));
+ }
+
+ std::shared_ptr<convolution_forward::primitive_desc> conv_forward_pd;
+ conv_forward_pd.reset(new convolution_forward::primitive_desc(
+ *conv_forward_desc, cpu_engine));
+
+ std::shared_ptr<convolution_backward_data::desc> conv_backward_data_desc;
+ conv_backward_data_desc.reset(new convolution_backward_data::desc(
+ convolution_direct, input_md, weight_md, output_md,
+ _stride, _padding, _padding, padding_kind::zero));
+
+ std::shared_ptr<convolution_backward_data::primitive_desc> conv_backward_data_pd;
+ conv_backward_data_pd.reset(new convolution_backward_data::primitive_desc(
+ *conv_backward_data_desc, cpu_engine, *conv_forward_pd));
+
+ auto grad_output_usr_memory = memory({{{output_tz}, data_t, format_nchw}, cpu_engine},
+ grad_output.data_ptr());
+ auto weight_usr_memory = memory({{{weight_tz}, data_t, format_oihw}, cpu_engine},
+ weight.data_ptr());
+ auto grad_input_usr_memory = memory({{{input_tz}, data_t, format_nchw}, cpu_engine},
+ grad_input.data_ptr());
+
+ std::vector<primitive> net;
+
+ auto grad_output_pd = conv_backward_data_pd->diff_dst_primitive_desc();
+ auto grad_output_memory = grad_output_usr_memory;
+ if (grad_output_usr_memory.get_primitive_desc() != memory::primitive_desc(grad_output_pd)) {
+ grad_output_memory = memory(grad_output_pd);
+ net.push_back(reorder(grad_output_usr_memory, grad_output_memory));
+ }
+
+ auto weight_pd = conv_backward_data_pd->weights_primitive_desc();
+ auto weight_memory = weight_usr_memory;
+ if (weight_usr_memory.get_primitive_desc() != memory::primitive_desc(weight_pd)) {
+ weight_memory = memory(weight_pd);
+ net.push_back(reorder(weight_usr_memory, weight_memory));
+ }
+
+ auto grad_input_pd = conv_backward_data_pd->diff_src_primitive_desc();
+ auto grad_input_memory = grad_input_usr_memory;
+  if (grad_input_usr_memory.get_primitive_desc() != memory::primitive_desc(grad_input_pd)) {
+ grad_input_memory = memory(grad_input_pd);
+ }
+
+ std::shared_ptr<convolution_backward_data> conv_backward_data;
+ conv_backward_data.reset(new convolution_backward_data(*conv_backward_data_pd,
+ grad_output_memory, weight_memory, grad_input_memory));
+ net.push_back(*conv_backward_data);
+
+ if (grad_input_memory != grad_input_usr_memory) {
+ net.push_back(reorder(grad_input_memory, grad_input_usr_memory));
+ }
+
+ Stream::Instance().get_stream().submit(net);
+
+ return grad_input;
+}
+
+std::tuple<at::Tensor, at::Tensor> mkldnn_convolution_backward_weights(
+ IntList weight_size, const at::Tensor& grad_output, const at::Tensor& input,
+ IntList padding, IntList stride, IntList dilation, bool bias_defined)
+{
+ auto grad_weight = grad_output.type().tensor(weight_size);
+
+ Tensor grad_bias;
+ if (bias_defined) {
+ grad_bias = grad_output.type().tensor({grad_output.size(1)});
+ }
+
+ auto cpu_engine = CpuEngine::Instance().get_engine();
+
+ int32_t n = input.size(0);
+ int32_t ic = input.size(1);
+ int32_t ih = input.size(2);
+ int32_t iw = input.size(3);
+
+ int32_t oc = grad_output.size(1);
+ int32_t oh = grad_output.size(2);
+ int32_t ow = grad_output.size(3);
+
+ int32_t kh = grad_weight.size(2);
+ int32_t kw = grad_weight.size(3);
+
+ int32_t sh = stride[0];
+ int32_t sw = stride[1];
+ int32_t ph = padding[0];
+ int32_t pw = padding[1];
+
+ auto data_t = memory::data_type::f32;
+ auto format_any = memory::format::any;
+ auto format_nchw = memory::format::nchw;
+ auto format_oihw = memory::format::oihw;
+ auto format_x = memory::format::x;
+
+ memory::dims input_tz = {n, ic, ih, iw};
+ memory::dims weight_tz = {oc, ic, kh, kw};
+ memory::dims bias_tz = {oc};
+ memory::dims output_tz = {n, oc, oh, ow};
+ memory::dims _stride = {sh, sw};
+ memory::dims _padding = {ph, pw};
+
+ memory::desc input_md({input_tz}, data_t, format_any);
+ memory::desc weight_md({weight_tz}, data_t, format_any);
+ memory::desc bias_md({bias_tz}, data_t, format_any);
+ memory::desc output_md({output_tz}, data_t, format_any);
+
+ // need to re-create conv_forward_pd to feed conv_backward_weight_pd
+ std::shared_ptr<convolution_forward::desc> conv_forward_desc;
+ if (bias_defined) {
+ conv_forward_desc.reset(new convolution_forward::desc(prop_kind::forward,
+ convolution_direct, input_md, weight_md, bias_md, output_md,
+ _stride, _padding, _padding, padding_kind::zero));
+ } else {
+ conv_forward_desc.reset(new convolution_forward::desc(prop_kind::forward,
+ convolution_direct, input_md, weight_md, output_md,
+ _stride, _padding, _padding, padding_kind::zero));
+ }
+
+ std::shared_ptr<convolution_forward::primitive_desc> conv_forward_pd;
+ conv_forward_pd.reset(new convolution_forward::primitive_desc(
+ *conv_forward_desc, cpu_engine));
+
+ std::shared_ptr<convolution_backward_weights::desc> conv_backward_weight_desc;
+ if (bias_defined) {
+ conv_backward_weight_desc.reset(new convolution_backward_weights::desc(
+ convolution_direct, input_md, weight_md, bias_md, output_md,
+ _stride, _padding, _padding, padding_kind::zero));
+ } else {
+ conv_backward_weight_desc.reset(new convolution_backward_weights::desc(
+ convolution_direct, input_md, weight_md, output_md,
+ _stride, _padding, _padding, padding_kind::zero));
+ }
+
+ std::shared_ptr<convolution_backward_weights::primitive_desc> conv_backward_weight_pd;
+ conv_backward_weight_pd.reset(new convolution_backward_weights::primitive_desc(
+ *conv_backward_weight_desc, cpu_engine, *conv_forward_pd));
+
+ auto input_usr_memory = memory({{{input_tz}, data_t, format_nchw}, cpu_engine},
+ input.data_ptr());
+ auto grad_output_usr_memory = memory({{{output_tz}, data_t, format_nchw}, cpu_engine},
+ grad_output.data_ptr());
+ auto grad_weight_usr_memory = memory({{{weight_tz}, data_t, format_oihw}, cpu_engine},
+ grad_weight.data_ptr());
+ std::shared_ptr<memory> grad_bias_memory;
+
+ std::vector<primitive> net;
+
+ auto input_pd = conv_backward_weight_pd->src_primitive_desc();
+ auto input_memory = input_usr_memory;
+ if (input_usr_memory.get_primitive_desc() != memory::primitive_desc(input_pd)) {
+ input_memory = memory(input_pd);
+ net.push_back(reorder(input_usr_memory, input_memory));
+ }
+
+ auto grad_output_pd = conv_backward_weight_pd->diff_dst_primitive_desc();
+ auto grad_output_memory = grad_output_usr_memory;
+ if (grad_output_usr_memory.get_primitive_desc() != memory::primitive_desc(grad_output_pd)) {
+ grad_output_memory = memory(grad_output_pd);
+ net.push_back(reorder(grad_output_usr_memory, grad_output_memory));
+ }
+
+ auto grad_weight_pd = conv_backward_weight_pd->diff_weights_primitive_desc();
+ auto grad_weight_memory = grad_weight_usr_memory;
+ if (grad_weight_usr_memory.get_primitive_desc() != memory::primitive_desc(grad_weight_pd)) {
+ grad_weight_memory = memory(grad_weight_pd);
+ }
+
+ std::shared_ptr<convolution_backward_weights> conv_backward_weight;
+ if (bias_defined) {
+ grad_bias_memory.reset(new memory({{{bias_tz}, data_t, format_x}, cpu_engine},
+ grad_bias.data_ptr()));
+ conv_backward_weight.reset(new convolution_backward_weights(*conv_backward_weight_pd,
+ input_memory, grad_output_memory, grad_weight_memory, *grad_bias_memory));
+ } else {
+ conv_backward_weight.reset(new convolution_backward_weights(*conv_backward_weight_pd,
+ input_memory, grad_output_memory, grad_weight_memory));
+ }
+
+ net.push_back(*conv_backward_weight);
+
+ if (grad_weight_memory != grad_weight_usr_memory) {
+ net.push_back(reorder(grad_weight_memory, grad_weight_usr_memory));
+ }
+
+ Stream::Instance().get_stream().submit(net);
+
+ return std::tuple<at::Tensor, at::Tensor>{grad_weight, grad_bias};
+}
+
+std::tuple<at::Tensor,at::Tensor,at::Tensor> mkldnn_convolution_backward(
+ const at::Tensor& input, const at::Tensor& grad_output_t, const at::Tensor& weight,
+ IntList padding, IntList stride, IntList dilation, std::array<bool,3> output_mask)
+{
+ Tensor grad_output = grad_output_t.contiguous();
+
+ Tensor grad_input, grad_weight, grad_bias;
+ if (output_mask[0]) {
+ grad_input = at::mkldnn_convolution_backward_input(
+ input.sizes(), grad_output, weight, padding, stride, dilation, output_mask[2]);
+ }
+ if (output_mask[1] || output_mask[2]) {
+ std::tie(grad_weight, grad_bias) = at::mkldnn_convolution_backward_weights(
+ weight.sizes(), grad_output, input, padding, stride, dilation, output_mask[2]);
+ }
+
+ return std::tuple<Tensor, Tensor, Tensor>{grad_input, grad_weight, grad_bias};
+}
+
+}} // namespace at::native
+
+#endif
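As a concrete check of conv_output_size() above, take a 224x224 input, a 7x7 kernel, stride 2, padding 3, dilation 1:

  kernel_extent = dilation * (kernel - 1) + 1 = 1 * (7 - 1) + 1 = 7
  output_dim    = (input + 2 * padding - kernel_extent) / stride + 1
                = (224 + 2 * 3 - 7) / 2 + 1 = 112

so a {2, 3, 224, 224} input convolved with a {64, 3, 7, 7} weight produces a {2, 64, 112, 112} output, matching the forward sketch near the top of this page.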
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 7e665909d7..93fa94eac3 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -711,3 +711,15 @@
dispatch:
CPU: _s_poisson_cpu
CUDA: _s_poisson_cuda
+
+- func: mkldnn_convolution(Tensor self, Tensor weight, Tensor? bias, IntList padding, IntList stride, IntList dilation) -> Tensor
+ variants: function
+
+- func: mkldnn_convolution_backward_input(IntList self_size, Tensor grad_output, Tensor weight, IntList padding, IntList stride, IntList dilation, bool bias_defined) -> Tensor
+ variants: function
+
+- func: mkldnn_convolution_backward_weights(IntList weight_size, Tensor grad_output, Tensor self, IntList padding, IntList stride, IntList dilation, bool bias_defined) -> (Tensor, Tensor)
+ variants: function
+
+- func: mkldnn_convolution_backward(Tensor self, Tensor grad_output, Tensor weight, IntList padding, IntList stride, IntList dilation, std::array<bool,3> output_mask) -> (Tensor, Tensor, Tensor)
+ variants: function
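Finally, a sketch of the combined backward entry point registered above, with sizes matching the forward example and grad_output taken as all-ones for brevity. output_mask selects which of {grad_input, grad_weight, grad_bias} are computed, and its third element is also threaded through as bias_defined for the weight pass.

  #include <ATen/ATen.h>
  #include <tuple>

  int main() {
    auto input       = at::CPU(at::kFloat).randn({2, 3, 224, 224});
    auto weight      = at::CPU(at::kFloat).randn({64, 3, 7, 7});
    auto grad_output = at::CPU(at::kFloat).ones({2, 64, 112, 112});
    at::Tensor grad_input, grad_weight, grad_bias;
    // Request all three gradients.
    std::tie(grad_input, grad_weight, grad_bias) =
        at::mkldnn_convolution_backward(input, grad_output, weight,
                                        {3, 3}, {2, 2}, {1, 1},
                                        {{true, true, true}});
    return 0;
  }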