diff options
-rw-r--r-- | aten/CMakeLists.txt | 14 | ||||
-rw-r--r-- | aten/cmake/FindNNPACK.cmake | 57 | ||||
-rw-r--r-- | aten/doc/Functions.h | 16 | ||||
-rw-r--r-- | aten/doc/Type.h | 4 | ||||
-rw-r--r-- | aten/src/ATen/CMakeLists.txt | 4 | ||||
-rw-r--r-- | aten/src/ATen/Config.h.in | 1 | ||||
-rw-r--r-- | aten/src/ATen/native/Convolution.cpp | 36 | ||||
-rw-r--r-- | aten/src/ATen/native/NNPACK.cpp | 580 | ||||
-rw-r--r-- | aten/src/ATen/native/native_functions.yaml | 14 | ||||
-rw-r--r-- | tools/autograd/derivatives.yaml | 7 |
10 files changed, 5 insertions, 728 deletions
diff --git a/aten/CMakeLists.txt b/aten/CMakeLists.txt index 13e4abca94..aba64bc623 100644 --- a/aten/CMakeLists.txt +++ b/aten/CMakeLists.txt @@ -474,20 +474,6 @@ else() endif() endif() -if(NO_NNPACK) - message("disabling NNPACK because NO_NNPACK is set") - set(AT_NNPACK_ENABLED 0) -else() - find_package(NNPACK) - if(NOT NNPACK_FOUND) - MESSAGE(STATUS "NNPACK not found. Compiling without NNPACK support") - set(AT_NNPACK_ENABLED 0) - ELSE() - INCLUDE_DIRECTORIES(${NNPACK_INCLUDE_DIRS}) - set(AT_NNPACK_ENABLED 1) - ENDIF() -endif() - set(cwrap_files ${CMAKE_CURRENT_SOURCE_DIR}/src/ATen/Declarations.cwrap ${CMAKE_CURRENT_SOURCE_DIR}/src/THNN/generic/THNN.h diff --git a/aten/cmake/FindNNPACK.cmake b/aten/cmake/FindNNPACK.cmake deleted file mode 100644 index e97c5aca3d..0000000000 --- a/aten/cmake/FindNNPACK.cmake +++ /dev/null @@ -1,57 +0,0 @@ -# Cribbed from https://github.com/caffe2/caffe2/blob/master/cmake/Modules/FindNNPACK.cmake -# -# - Try to find NNPACK -# -# The following variables are optionally searched for defaults -# NNPACK_ROOT_DIR: Base directory where all NNPACK components are found -# -# The following are set after configuration is done: -# NNPACK_FOUND -# NNPACK_INCLUDE_DIRS -# NNPACK_LIBRARIES -# NNPACK_LIBRARYRARY_DIRS - -include(FindPackageHandleStandardArgs) -include(CheckSymbolExists) - -set(NNPACK_ROOT_DIR "" CACHE PATH "Folder contains NNPACK") - -find_path(NNPACK_INCLUDE_DIR nnpack.h - PATHS ${NNPACK_ROOT_DIR} - PATH_SUFFIXES include) - -# TODO: deps/pthreadpool/include may also need to be registered as an include directory -# TODO: Conda searching? - -find_library(NNPACK_LIBRARY nnpack - PATHS ${NNPACK_ROOT_DIR} - PATH_SUFFIXES lib lib64) - -find_library(CPUINFO_LIBRARY cpuinfo - PATHS ${NNPACK_ROOT_DIR} - PATH_SUFFIXES lib lib64) - -find_library(PTHREADPOOL_LIBRARY pthreadpool - PATHS ${NNPACK_ROOT_DIR} - PATH_SUFFIXES lib lib64) - -find_package_handle_standard_args(NNPACK DEFAULT_MSG NNPACK_INCLUDE_DIR NNPACK_LIBRARY CPUINFO_LIBRARY PTHREADPOOL_LIBRARY) - -if(NNPACK_FOUND) - set(NNPACK_INCLUDE_DIRS ${NNPACK_INCLUDE_DIR}) - set(NNPACK_LIBRARIES ${NNPACK_LIBRARY} ${CPUINFO_LIBRARY} ${PTHREADPOOL_LIBRARY}) - - list(APPEND CMAKE_REQUIRED_LIBRARIES ${NNPACK_LIBRARIES}) - list(APPEND CMAKE_REQUIRED_INCLUDES ${NNPACK_INCLUDE_DIRS}) - check_symbol_exists(nnp_convolution_kernel_gradient "nnpack.h" NNPACK_HAS_INFERENCE) - - if(NNPACK_HAS_INFERENCE) - message(STATUS "Found NNPACK (include: ${NNPACK_INCLUDE_DIR}, library: ${NNPACK_LIBRARY})") - message(STATUS "Found CPUINFO (library: ${CPUINFO_LIBRARY})") - message(STATUS "Found PTHREADPOOL (library: ${PTHREADPOOL_LIBRARY})") - mark_as_advanced(NNPACK_ROOT_DIR NNPACK_LIBRARY_RELEASE NNPACK_LIBRARY_DEBUG - NNPACK_LIBRARY NNPACK_INCLUDE_DIR) - else() - message(STATUS "Refusing to use incomplete NNPACK (include: ${NNPACK_INCLUDE_DIR}, library: ${NNPACK_LIBRARY}); try reinstalling NNPACK without --inference-only") - endif() -endif() diff --git a/aten/doc/Functions.h b/aten/doc/Functions.h index c611849e0b..0237ba674a 100644 --- a/aten/doc/Functions.h +++ b/aten/doc/Functions.h @@ -755,10 +755,6 @@ static inline Tensor & mm_out(Tensor & result, const Tensor & self, const Tensor static inline Tensor mv(const Tensor & self, const Tensor & vec); static inline Tensor & mv_out(Tensor & result, const Tensor & self, const Tensor & vec); static inline Tensor narrow(const Tensor & self, int64_t dim, int64_t start, int64_t length); -static inline Tensor nnpack_spatial_convolution(const Tensor & input, const Tensor & weight, const Tensor & bias, int64_t kW, int64_t kH, int64_t padW, int64_t padH); -static inline std::tuple<Tensor,Tensor,Tensor> nnpack_spatial_convolution_backward(const Tensor & input, const Tensor & grad_output, const Tensor & weight, int64_t kW, int64_t kH, int64_t padW, int64_t padH, std::array<bool,3> output_mask); -static inline Tensor nnpack_spatial_convolution_backward_input(const Tensor & input, const Tensor & grad_output, const Tensor & weight, int64_t kW, int64_t kH, int64_t padW, int64_t padH); -static inline Tensor nnpack_spatial_convolution_backward_weight(const Tensor & input, IntList weight_size, const Tensor & grad_output, int64_t kW, int64_t kH, int64_t padW, int64_t padH); static inline Tensor pin_memory(const Tensor & self); static inline Tensor rand_like(const Tensor & self); static inline Tensor randn_like(const Tensor & self); @@ -3044,18 +3040,6 @@ static inline Tensor & mv_out(Tensor & result, const Tensor & self, const Tensor static inline Tensor narrow(const Tensor & self, int64_t dim, int64_t start, int64_t length) { return infer_type(self).narrow(self, dim, start, length); } -static inline Tensor nnpack_spatial_convolution(const Tensor & input, const Tensor & weight, const Tensor & bias, int64_t kW, int64_t kH, int64_t padW, int64_t padH) { - return infer_type(input).nnpack_spatial_convolution(input, weight, bias, kW, kH, padW, padH); -} -static inline std::tuple<Tensor,Tensor,Tensor> nnpack_spatial_convolution_backward(const Tensor & input, const Tensor & grad_output, const Tensor & weight, int64_t kW, int64_t kH, int64_t padW, int64_t padH, std::array<bool,3> output_mask) { - return infer_type(input).nnpack_spatial_convolution_backward(input, grad_output, weight, kW, kH, padW, padH, output_mask); -} -static inline Tensor nnpack_spatial_convolution_backward_input(const Tensor & input, const Tensor & grad_output, const Tensor & weight, int64_t kW, int64_t kH, int64_t padW, int64_t padH) { - return infer_type(input).nnpack_spatial_convolution_backward_input(input, grad_output, weight, kW, kH, padW, padH); -} -static inline Tensor nnpack_spatial_convolution_backward_weight(const Tensor & input, IntList weight_size, const Tensor & grad_output, int64_t kW, int64_t kH, int64_t padW, int64_t padH) { - return infer_type(input).nnpack_spatial_convolution_backward_weight(input, weight_size, grad_output, kW, kH, padW, padH); -} static inline Tensor pin_memory(const Tensor & self) { return infer_type(self).pin_memory(self); } diff --git a/aten/doc/Type.h b/aten/doc/Type.h index 849e3cbefb..90d7ce2903 100644 --- a/aten/doc/Type.h +++ b/aten/doc/Type.h @@ -1093,10 +1093,6 @@ struct AT_API Type { virtual Tensor mv(const Tensor & self, const Tensor & vec) const; virtual Tensor & mv_out(Tensor & result, const Tensor & self, const Tensor & vec) const; virtual Tensor narrow(const Tensor & self, int64_t dim, int64_t start, int64_t length) const; - virtual Tensor nnpack_spatial_convolution(const Tensor & input, const Tensor & weight, const Tensor & bias, int64_t kW, int64_t kH, int64_t padW, int64_t padH) const; - virtual std::tuple<Tensor,Tensor,Tensor> nnpack_spatial_convolution_backward(const Tensor & input, const Tensor & grad_output, const Tensor & weight, int64_t kW, int64_t kH, int64_t padW, int64_t padH, std::array<bool,3> output_mask) const; - virtual Tensor nnpack_spatial_convolution_backward_input(const Tensor & input, const Tensor & grad_output, const Tensor & weight, int64_t kW, int64_t kH, int64_t padW, int64_t padH) const; - virtual Tensor nnpack_spatial_convolution_backward_weight(const Tensor & input, IntList weight_size, const Tensor & grad_output, int64_t kW, int64_t kH, int64_t padW, int64_t padH) const; virtual Tensor permute(const Tensor & self, IntList dims) const; virtual Tensor pin_memory(const Tensor & self) const; virtual Tensor rand_like(const Tensor & self) const; diff --git a/aten/src/ATen/CMakeLists.txt b/aten/src/ATen/CMakeLists.txt index 27f1d98461..c625afc957 100644 --- a/aten/src/ATen/CMakeLists.txt +++ b/aten/src/ATen/CMakeLists.txt @@ -317,10 +317,6 @@ ELSE(NOT C_HAS_THREAD) SET(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DTH_HAVE_THREAD") ENDIF(NOT C_HAS_THREAD) -if (NNPACK_FOUND) - target_link_libraries(ATen ${NNPACK_LIBRARIES}) -endif(NNPACK_FOUND) - if(MKLDNN_FOUND) target_link_libraries(ATen ${MKLDNN_LIBRARIES}) endif(MKLDNN_FOUND) diff --git a/aten/src/ATen/Config.h.in b/aten/src/ATen/Config.h.in index 1ab0ec9162..76130f7cc0 100644 --- a/aten/src/ATen/Config.h.in +++ b/aten/src/ATen/Config.h.in @@ -7,7 +7,6 @@ #define AT_CUDA_ENABLED() @AT_CUDA_ENABLED@ #define AT_CUDNN_ENABLED() @AT_CUDNN_ENABLED@ #define AT_MKLDNN_ENABLED() @AT_MKLDNN_ENABLED@ -#define AT_NNPACK_ENABLED() @AT_NNPACK_ENABLED@ #define AT_MKL_ENABLED() @AT_MKL_ENABLED@ #if !AT_CUDA_ENABLED() && AT_CUDNN_ENABLED() diff --git a/aten/src/ATen/native/Convolution.cpp b/aten/src/ATen/native/Convolution.cpp index 33f7e05d26..dba3988cd7 100644 --- a/aten/src/ATen/native/Convolution.cpp +++ b/aten/src/ATen/native/Convolution.cpp @@ -7,10 +7,6 @@ #include "ATen/cudnn/cudnn-wrapper.h" #endif -#if AT_NNPACK_ENABLED() -#include "nnpack.h" -#endif - namespace at { namespace native { struct ConvParams { @@ -33,7 +29,6 @@ struct ConvParams { void view1d_as_2d(); bool use_cudnn(const at::Tensor& input) const; bool use_mkldnn(const at::Tensor& input) const; - bool use_nnpack(const at::Tensor& input) const; bool is_depthwise(const at::Tensor& input, const at::Tensor& weight) const; }; @@ -142,18 +137,6 @@ auto ConvParams::use_mkldnn(const at::Tensor& input) const -> bool { #endif return false; } -auto ConvParams::use_nnpack(const at::Tensor& input) const -> bool { -#if AT_NNPACK_ENABLED() - return input.type().backend() == kCPU && - input.type().scalarType() == kFloat && // only on CPU Float Tensors - !is_strided() && // doesn't support strides - !is_dilated() && // or dilation - !transposed && // or transposed tensors - input.ndimension() == 4 && // must be in NCHW format - input.size(0) >= 16; // ensure large enough batch size to ensure perf, tuneable -#endif - return false; -} // We currently only have depthwise support for the case where groups == // nInputPlane and nInputPlane == nOutputPlane (the latter due to the lack of @@ -461,20 +444,11 @@ at::Tensor _convolution_nogroup( input, weight, kernel_size, bias, stride, padding, dilation); } else { /* dim == 4, non-dilated */ - if (params.use_nnpack(input)) { -#if AT_NNPACK_ENABLED() - return at::nnpack_spatial_convolution( - input, weight, bias, - kernel_size[1], kernel_size[0], - params.padding[1], params.padding[0]); -#endif - } else { - /* CPU implementation has specialized MM kernels - for non-dilated case here */ - return at::thnn_conv2d( - input, weight, kernel_size, bias, - stride, padding); - } + /* CPU implementation has specialized MM kernels + for non-dilated case here */ + return at::thnn_conv2d( + input, weight, kernel_size, bias, + stride, padding); } } else if (dim == 5 && (input.type().is_cuda() || dilated)) { return at::thnn_conv_dilated3d( diff --git a/aten/src/ATen/native/NNPACK.cpp b/aten/src/ATen/native/NNPACK.cpp deleted file mode 100644 index 7aeade21fa..0000000000 --- a/aten/src/ATen/native/NNPACK.cpp +++ /dev/null @@ -1,580 +0,0 @@ -#include <ATen/Config.h> -#include <ATen/ATen.h> - -#if !AT_NNPACK_ENABLED() - -namespace at { namespace native { - -at::Tensor nnpack_spatial_convolution( - const at::Tensor& input, - const at::Tensor& weight, - const at::Tensor& bias, - int64_t kW, - int64_t kH, - int64_t padW, - int64_t padH) { - throw std::runtime_error("nnpack_spatial_convolution: ATen not compiled with NNPACK support"); -} - -at::Tensor nnpack_spatial_convolution_backward_input( - const at::Tensor& input, - const at::Tensor& gradOutput, - const at::Tensor& weight, - int64_t kW, - int64_t kH, - int64_t padW, - int64_t padH) { - throw std::runtime_error("nnpack_spatial_convolution_backward_input: ATen not compiled with NNPACK support"); -} - -at::Tensor nnpack_spatial_convolution_backward_weight( - const at::Tensor& input, - at::IntList weight_size, - const at::Tensor& gradOutput, - int64_t kW, - int64_t kH, - int64_t padW, - int64_t padH) { - throw std::runtime_error("nnpack_spatial_convolution_backward_weight: ATen not compiled with NNPACK support"); -} - -std::tuple<at::Tensor,at::Tensor,at::Tensor> nnpack_spatial_convolution_backward( - const at::Tensor& input, - const at::Tensor& gradOutput, - const at::Tensor& weight, - int64_t kW, - int64_t kH, - int64_t padW, - int64_t padH, - std::array<bool,3> output_mask) { - throw std::runtime_error("nnpack_spatial_convolution_backward: ATen not compiled with NNPACK support"); -} - -}} // at::native - -#else - -#include "nnpack.h" - -#include <stdlib.h> - -#ifdef _OPENMP -#include <omp.h> -#else -#include <thread> -#endif - -namespace at { -namespace native { - -// Stolen from Caffe2 -static pthreadpool_t nnpack_threadpool_ = nullptr; - -pthreadpool_t nnpack_threadpool() { - if (nnpack_threadpool_ == nullptr) { - enum nnp_status nnpack_status = nnp_initialize(); - if (nnpack_status != nnp_status_success) throw std::runtime_error("could not initialize NNPack"); - unsigned int threads; -#ifdef _OPENMP - threads = omp_get_num_threads(); -#else - threads = std::thread::hardware_concurrency(); -#endif - nnpack_threadpool_ = pthreadpool_create(threads); - if (nnpack_threadpool_ == nullptr) { - throw std::runtime_error("could not initialize NNPack's pthreadpool"); - } - } - return nnpack_threadpool_; -} - -// Make thread_local for safety in cases where we have multiple threads running Convs at once -static thread_local void *workspace = nullptr; -static thread_local size_t workspace_size = 0; - -// NNPack has alignment requirements -const size_t nnpack_memory_alignment_boundary = 64; - -static inline void deallocate_workspace() { - if (workspace) - std::free(workspace); - workspace = nullptr; -} - -static inline void allocate_workspace() { - if (workspace) - deallocate_workspace(); - // Won't work on Windows, but NNPACK doesn't support Windows either - posix_memalign(&workspace, nnpack_memory_alignment_boundary, workspace_size); -} - -constexpr int input_batch_size_dim = 0; -constexpr int input_channels_dim = 1; -constexpr int input_height_dim = 2; -constexpr int input_width_dim = 3; -constexpr int output_batch_size_dim = 0; -constexpr int output_channels_dim = 1; -constexpr int output_height_dim = 2; -constexpr int output_width_dim = 3; -constexpr int weight_output_channels_dim = 0; -constexpr int weight_input_channels_dim = 1; -constexpr int weight_height_dim = 2; -constexpr int weight_width_dim = 3; - -// Often written as 2 + max_dim (extra dims for batch size and channels) -constexpr int max_dim = 3; - -std::vector<int64_t> conv_output_size( - IntList input_size, IntList weight_size, - int64_t kW, int64_t kH, int64_t padW, int64_t padH -) { - auto dim = input_size.size(); - std::vector<int64_t> output_size(dim); - output_size[output_batch_size_dim] = input_size[input_batch_size_dim]; - output_size[output_channels_dim] = weight_size[weight_output_channels_dim]; - output_size[output_height_dim] = input_size[input_height_dim] + 2 * padH - (kH - 1); - output_size[output_width_dim] = input_size[input_width_dim] + 2 * padW - (kW - 1); - return output_size; -} - -Tensor nnpack_spatial_convolution( - const at::Tensor& input, - const at::Tensor& weight, - const at::Tensor& bias, - int64_t kW, - int64_t kH, - int64_t padW, - int64_t padH) { - - at::Tensor output = input.type().tensor(conv_output_size(input.sizes(), weight.sizes(), kW, kH, padW, padH)); - - // Our input Tensor must be in the form N,C,H,W - if (input.ndimension() != 4) { - throw std::runtime_error("NNPack convolutionOutput expects 4D input Tensor N,C,H,W"); - } - // Our weight Tensor must be in the form oC,iC,kH,kW - if (weight.ndimension() != 4) { - throw std::runtime_error("NNPack convolutionOutput expects 4D weight Tensor oC,iC,kH,kW"); - } - // Our output Tensor must be in the form N,oC,oH,oW - if (output.ndimension() != 4) { - throw std::runtime_error("NNPack convolutionOutput expects 4D output Tensor N,oC,oH,oW"); - } - - // Some basic shape checking, not comprehensive - if (input.size(1) != weight.size(1)) { - std::stringstream err; - err << "Mismatch between number of input channels in input Tensor (" << input.size(1) - << ") and weight Tensor (" << weight.size(1) << ") in NNPack convolutionOutput"; - throw std::runtime_error(err.str()); - } - if (weight.size(0) != output.size(1)) { - std::stringstream err; - err << "Mismatch between number of output channels in weight Tensor (" << weight.size(0) - << ") and output Tensor (" << output.size(1) << ") in NNPack convolutionOutput"; - throw std::runtime_error(err.str()); - } - if (input.size(0) != output.size(0)) { - std::stringstream err; - err << "Mismatch between batch size in input Tensor (" << input.size(0) - << ") and output Tensor (" << output.size(0) << ") in NNPack convolutionOutput"; - throw std::runtime_error(err.str()); - } - - // Setup parameters for the NNPack convolution output function call - - // For now, we use the default algorithm - auto algorithm = nnp_convolution_algorithm_auto; - - // All Tensors must be float Tensors - if (input.type().ID() != at::TypeID::CPUFloat || - weight.type().ID() != at::TypeID::CPUFloat || - output.type().ID() != at::TypeID::CPUFloat || - (bias.defined() && bias.type().ID() != at::TypeID::CPUFloat)) { - throw std::runtime_error("Mismatched Tensor types in NNPack convolutionOutput"); - } - - const size_t batch_size = input.size(0); - const size_t input_channels = input.size(1); - const size_t output_channels = weight.size(0); - const struct nnp_size input_size = { - .width = (size_t)input.size(3), - .height = (size_t)input.size(2) - }; - const struct nnp_padding input_padding = { - .top = (size_t)padH, - .right = (size_t)padW, - .bottom = (size_t)padH, - .left = (size_t)padW - }; - const struct nnp_size kernel_size = { - .width = (size_t)kW, - .height = (size_t)kH - }; - - // If we don't have a defined bias Tensor, we need to create one filled with zeroes - auto bias_ = bias.defined() ? bias : input.type().zeros({weight.size(0)}); - - // Note: we assume that the output is shaped correctly, probably should add an assert - - auto batched = [&]() -> nnp_status { - return nnp_convolution_output( - algorithm, - batch_size, - input_channels, - output_channels, - input_size, - input_padding, - kernel_size, - (float*)input.data_ptr(), - (float*)weight.data_ptr(), - (float*)bias_.data_ptr(), - (float*)output.data_ptr(), - workspace, // workspace_buffer - &workspace_size, // workspace_size - nnp_activation_identity, - nullptr, // activation parameters - nnpack_threadpool(), - nullptr // profile - ); - }; - - auto single = [&]() -> nnp_status { - const nnp_size output_subsample = { - .width = 1, - .height = 1 - }; - return nnp_convolution_inference( - algorithm, - nnp_convolution_transform_strategy_compute, - input_channels, - output_channels, - input_size, - input_padding, - kernel_size, - output_subsample, - (float*)input.data_ptr(), - (float*)weight.data_ptr(), - (float*)bias_.data_ptr(), - (float*)output.data_ptr(), - workspace, // workspace_buffer - &workspace_size, // workspace_size - nnp_activation_identity, - nullptr, // activation parameters - nnpack_threadpool(), - nullptr // profile - ); - }; - - auto size_and_allocate_ws = [&]() { - // Run a single pass to get the size of memory workspace buffer - auto status = batch_size == 1 ? single() : batched(); - if (status != nnp_status_success) { - throw std::runtime_error("NNPACK SpatialConvolution_updateOutput failed"); - } - allocate_workspace(); - }; - - // If no workspace created yet, allocate it - if (workspace == nullptr) { - size_and_allocate_ws(); - } - - // Try to run with the newly created, or existing workspace - auto status = batch_size == 1 ? single() : batched(); - - if (status == nnp_status_insufficient_buffer) { - // Need to reallocate the workspace - deallocate_workspace(); - size_and_allocate_ws(); - - // Try one more time - status = batch_size == 1 ? single() : batched(); - } - - if (status != nnp_status_success) { - throw std::runtime_error("NNPACK SpatialConvolution_updateOutput failed"); - } - - return output; -} - -Tensor nnpack_spatial_convolution_backward_input( - const at::Tensor& input, - const at::Tensor& gradOutput, - const at::Tensor& weight, - int64_t kW, - int64_t kH, - int64_t padW, - int64_t padH) { - - at::Tensor gradInput = input.type().tensor(input.sizes()); - - // Our input and gradInput Tensors must be in the form N,C,H,W - if (input.ndimension() != 4) { - throw std::runtime_error("NNPack convolution updateGradInput expects 4D input Tensor N,C,H,W"); - } - if (gradInput.ndimension() != 4) { - throw std::runtime_error("NNPack convolution updateGradInput expects 4D gradInput Tensor N,C,H,W"); - } - // Our weight Tensor must be in the form oC,iC,kH,kW - if (weight.ndimension() != 4) { - throw std::runtime_error("NNPack convolution updateGradInput expects 4D weight Tensor oC,iC,kH,kW"); - } - // Our gradOutput Tensor must be in the form N,oC,oH,oW - if (gradOutput.ndimension() != 4) { - throw std::runtime_error("NNPack convolution updateGradInput expects 4D gradOutput Tensor N,oC,oH,oW"); - } - - // Some basic shape checking, not comprehensive - if (!input.sizes().equals(gradInput.sizes())) { - std::stringstream err; - err << "Mismatch between input size (" << input.sizes() << ") and gradInput size (" - << gradInput.sizes() << ") in NNPack convolution updateGradInput"; - throw std::runtime_error(err.str()); - } - if (input.size(1) != weight.size(1)) { - std::stringstream err; - err << "Mismatch between number of input channels in input Tensor (" << input.size(1) - << ") and weight Tensor (" << weight.size(1) << ") in NNPack convolution updateGradInput"; - throw std::runtime_error(err.str()); - } - if (weight.size(0) != gradOutput.size(1)) { - std::stringstream err; - err << "Mismatch between number of output channels in weight Tensor (" << weight.size(0) - << ") and gradOutput Tensor (" << gradOutput.size(1) << ") in NNPack convolution updateGradInput"; - throw std::runtime_error(err.str()); - } - if (input.size(0) != gradOutput.size(0)) { - std::stringstream err; - err << "Mismatch between batch size in input Tensor (" << input.size(0) - << ") and gradOutput Tensor (" << gradOutput.size(0) << ") in NNPack convolution updateGradInput"; - throw std::runtime_error(err.str()); - } - - // Setup parameters for the NNPACK convolution input gradient call - - // Use the default algorithm - auto algorithm = nnp_convolution_algorithm_auto; - - const size_t batch_size = input.size(0); - const size_t input_channels = input.size(1); - const size_t output_channels = weight.size(0); - const struct nnp_size input_size = { - .width = (size_t)input.size(3), - .height = (size_t)input.size(2) - }; - const struct nnp_padding input_padding = { - .top = (size_t)padH, - .right = (size_t)padW, - .bottom = (size_t)padH, - .left = (size_t)padW - }; - const struct nnp_size kernel_size = { - .width = (size_t)kW, - .height = (size_t)kH - }; - - auto run = [&]() -> nnp_status { - return nnp_convolution_input_gradient( - algorithm, - batch_size, - input_channels, - output_channels, - input_size, - input_padding, - kernel_size, - (float*)gradOutput.data_ptr(), - (float*)weight.data_ptr(), - (float*)gradInput.data_ptr(), - workspace, // workspace_buffer - &workspace_size, // workspace_size - nnp_activation_identity, - nullptr, // activation_parameters - nnpack_threadpool(), - nullptr // profile - ); - }; - - auto size_and_allocate_ws = [&]() { - // Run a single pass to get the size of memory workspace buffer - auto status = run(); - if (status != nnp_status_success) { - throw std::runtime_error("NNPACK SpatialConvolution_updateGradInput failed"); - } - allocate_workspace(); - }; - - // If no workspace created yet, allocate it - if (workspace == nullptr) { - size_and_allocate_ws(); - } - - // Try to run with the newly created, or existing workspace - auto status = run(); - - if (status == nnp_status_insufficient_buffer) { - // Need to reallocate the workspace - deallocate_workspace(); - size_and_allocate_ws(); - - // Try one more time - status = run(); - } - - if (status != nnp_status_success) { - throw std::runtime_error("NNPACK SpatialConvolution_updateGradInput failed"); - } - - return gradInput; -} - -Tensor nnpack_spatial_convolution_backward_weight( - const at::Tensor& input, - IntList weight_size, - const at::Tensor& gradOutput, - int64_t kW, - int64_t kH, - int64_t padW, - int64_t padH) { - - at::Tensor gradWeight = input.type().tensor(weight_size); - - // Our input and gradInput Tensors must be in the form N,C,H,W - if (input.ndimension() != 4) { - throw std::runtime_error("NNPack convolutionOutput expects 4D input Tensor N,C,H,W"); - } - // Our gradWeight Tensor must be in the form oC,iC,kH,kW - if (gradWeight.ndimension() != 4) { - throw std::runtime_error("NNPack convolutionOutput expects 4D gradWeight Tensor oC,iC,kH,kW"); - } - // Our weight Tensor must be in the form N,oC,oH,oW - if (gradOutput.ndimension() != 4) { - throw std::runtime_error("NNPack convolutionOutput expects 4D gradOutput Tensor N,oC,oH,oW"); - } - - // Some basic shape checking, not comprehensive - if (input.size(1) != gradWeight.size(1)) { - std::stringstream err; - err << "Mismatch between number of input channels in input Tensor (" << input.size(1) - << ") and gradWeight Tensor (" << gradWeight.size(1) << ") in NNPack convolution accGradWeight"; - throw std::runtime_error(err.str()); - } - if (gradWeight.size(0) != gradOutput.size(1)) { - std::stringstream err; - err << "Mismatch between number of output channels in gradWeight Tensor (" << gradWeight.size(0) - << ") and gradOutput Tensor (" << gradOutput.size(1) << ") in NNPack convolution accGradWeight"; - throw std::runtime_error(err.str()); - } - if (input.size(0) != gradOutput.size(0)) { - std::stringstream err; - err << "Mismatch between batch size in input Tensor (" << input.size(0) - << ") and gradOutput Tensor (" << gradOutput.size(0) << ") in NNPack convolution accGradWeight"; - throw std::runtime_error(err.str()); - } - - // Setup parameters for the NNPACK convolution kernel gradient call - - // Use the default algorithm - auto algorithm = nnp_convolution_algorithm_auto; - - const size_t batch_size = input.size(0); - const size_t input_channels = input.size(1); - const size_t output_channels = gradWeight.size(0); - const struct nnp_size input_size = { - .width = (size_t)input.size(3), - .height = (size_t)input.size(2) - }; - const struct nnp_padding input_padding = { - .top = (size_t)padH, - .right = (size_t)padW, - .bottom = (size_t)padH, - .left = (size_t)padW - }; - const struct nnp_size kernel_size = { - .width = (size_t)kW, - .height = (size_t)kH - }; - - auto run= [&]() -> nnp_status { - return nnp_convolution_kernel_gradient( - algorithm, - batch_size, - input_channels, - output_channels, - input_size, - input_padding, - kernel_size, - (float*)input.data_ptr(), - (float*)gradOutput.data_ptr(), - (float*)gradWeight.data_ptr(), - workspace, // workspace_buffer - &workspace_size, // workspace_size - nnp_activation_identity, - nullptr, // activation_parameters - nnpack_threadpool(), - nullptr // profile - ); - }; - - auto size_and_allocate_ws = [&]() { - // Run a single pass to get the size of memory workspace buffer - auto status = run(); - if (status != nnp_status_success) { - throw std::runtime_error("NNPACK SpatialConvolution_accGradWeight failed"); - } - allocate_workspace(); - }; - - // If no workspace created yet, allocate it - if (workspace == nullptr) { - size_and_allocate_ws(); - } - - // Try to run with the newly created, or existing workspace - auto status = run(); - - if (status == nnp_status_insufficient_buffer) { - // Need to reallocate the workspace - deallocate_workspace(); - size_and_allocate_ws(); - - // Try one more time - status = run(); - } - - if (status != nnp_status_success) { - throw std::runtime_error("NNPACK SpatialConvolution_accGradWeight failed"); - } - - return gradWeight; -} - -std::tuple<Tensor,Tensor,Tensor> nnpack_spatial_convolution_backward( - const at::Tensor& input, - const at::Tensor& grad_output, - const at::Tensor& weight, - int64_t kW, - int64_t kH, - int64_t padW, - int64_t padH, - std::array<bool,3> output_mask) { - - Tensor grad_input, grad_weight, grad_bias; - if (output_mask[0]) { - grad_input = at::nnpack_spatial_convolution_backward_input(input, grad_output, weight, kW, kH, padW, padH); - } - if (output_mask[1]) { - grad_weight = at::nnpack_spatial_convolution_backward_weight(input, weight.sizes(), grad_output, kW, kH, padW, padH); - } - if (output_mask[2]) { - grad_bias = grad_output.contiguous().view({grad_output.size(0), grad_output.size(1), -1}).sum(0).sum(1); - } - - return std::tuple<Tensor,Tensor,Tensor>{grad_input, grad_weight, grad_bias}; - -} - -}} // at::native - -#endif // AT_NNPACK_ENABLED diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index d5859cc0bf..05777a0086 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -440,20 +440,6 @@ - func: narrow(Tensor self, int64_t dim, int64_t start, int64_t length) -> Tensor -# TODO: Why does kW come before kH? Hella confusing, because it -# doesn't match the input layout. -- func: nnpack_spatial_convolution(Tensor input, Tensor weight, Tensor? bias, int64_t kW, int64_t kH, int64_t padW, int64_t padH) -> Tensor - variants: function - -- func: nnpack_spatial_convolution_backward(Tensor input, Tensor grad_output, Tensor weight, int64_t kW, int64_t kH, int64_t padW, int64_t padH, std::array<bool,3> output_mask) -> (Tensor, Tensor, Tensor) - variants: function - -- func: nnpack_spatial_convolution_backward_input(Tensor input, Tensor grad_output, Tensor weight, int64_t kW, int64_t kH, int64_t padW, int64_t padH) -> Tensor - variants: function - -- func: nnpack_spatial_convolution_backward_weight(Tensor input, IntList weight_size, Tensor grad_output, int64_t kW, int64_t kH, int64_t padW, int64_t padH) -> Tensor - variants: function - - func: ones(Type dtype, IntList size) -> Tensor variants: function diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml index 1cb4afb111..068508b16a 100644 --- a/tools/autograd/derivatives.yaml +++ b/tools/autograd/derivatives.yaml @@ -1117,13 +1117,6 @@ save_var: not_implemented("cudnn_batch_norm_backward save_var") input, weight, grad_output: batchnorm_double_backward(input, weight, grads[0], grads[1], grads[2], grad_output, running_mean, running_var, true, epsilon, save_mean, save_var, grad_input_mask) -# nnpack - -- name: nnpack_spatial_convolution(Tensor input, Tensor weight, Tensor bias, int64_t kW, int64_t kH, int64_t padW, int64_t padH) - input: nnpack_spatial_convolution_backward_input(input, grad, weight, kW, kH, padW, padH) - weight: nnpack_spatial_convolution_backward_weight(input, weight.sizes(), grad, kW, kH, padW, padH) - bias: grad.contiguous().view({grad.size(0), grad.size(1), -1}).sum(0).sum(1) - - name: _cudnn_rnn(Tensor input, TensorList weight, int64_t weight_stride0, Tensor weight_buf, Tensor hx, Tensor cx, int64_t mode, int64_t hidden_size, int64_t num_layers, bool batch_first, double dropout, bool train, bool bidirectional, IntList batch_sizes, Tensor dropout_state) input, hx, cx, weight: "_cudnn_rnn_backward(input, weight, weight_stride0, result4, hx, cx, result0, grads[0], grads[1], grads[2], mode, hidden_size, num_layers, batch_first, dropout, train, bidirectional, batch_sizes, dropout_state, retain_variables ? result3.clone() : result3, grad_input_mask)" |