author     Eli Uriegas <1700823+seemethere@users.noreply.github.com>  2021-12-09 08:59:45 -0800
committer  GitHub <noreply@github.com>  2021-12-09 08:59:45 -0800
commit     302ee7bfb604ebef384602c56e3853efed262030 (patch)
tree       cdb68f233ee6e7a7b4ede1762194fe79e86e60a5
parent     0c91a7063d0b88c08b3e10be89587954d86d5e33 (diff)
[release/1.10] Fix adaptive_max_pool2d for channels-last on CUDA (#67697) (#69618)
Co-authored-by: Xiao Wang <24860335+xwang233@users.noreply.github.com>
-rw-r--r--  aten/src/ATen/native/cuda/AdaptiveMaxPooling2d.cu  38
-rw-r--r--  test/test_nn.py                                      1
2 files changed, 27 insertions, 12 deletions
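For context, a minimal sketch (not part of this patch) of the case the fix targets: calling adaptive_max_pool2d on a channels-last (NHWC) CUDA tensor and checking it against a contiguous CPU reference, which is roughly what the re-enabled test below exercises. The shapes, output size, and reference comparison are illustrative assumptions, not taken from the commit.

# Illustrative repro sketch for the channels-last (NHWC) CUDA path this patch fixes.
# Assumes a CUDA device is available; shapes and output sizes are arbitrary.
import torch
import torch.nn.functional as F

x = torch.randn(2, 3, 8, 8, device="cuda").to(memory_format=torch.channels_last)
x.requires_grad_(True)

# Forward/backward on the channels-last CUDA tensor; before this fix the kernel
# wrote through non-contiguous output/indices pointers.
out, idx = F.adaptive_max_pool2d(x, output_size=(4, 4), return_indices=True)
out.sum().backward()

# Reference: the same computation on a contiguous CPU copy.
x_ref = x.detach().cpu().contiguous().requires_grad_(True)
out_ref, idx_ref = F.adaptive_max_pool2d(x_ref, output_size=(4, 4), return_indices=True)
out_ref.sum().backward()

torch.testing.assert_close(out.cpu(), out_ref)
torch.testing.assert_close(x.grad.cpu(), x_ref.grad)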
diff --git a/aten/src/ATen/native/cuda/AdaptiveMaxPooling2d.cu b/aten/src/ATen/native/cuda/AdaptiveMaxPooling2d.cu
index 57c92445b5..0f78016897 100644
--- a/aten/src/ATen/native/cuda/AdaptiveMaxPooling2d.cu
+++ b/aten/src/ATen/native/cuda/AdaptiveMaxPooling2d.cu
@@ -211,6 +211,9 @@ const Tensor& indices) {
int64_t osizeH = output_size[0];
int64_t osizeW = output_size[1];
+ const at::Tensor output_c = output.is_contiguous() ? output : at::empty(output.sizes(), output.options());
+ const at::Tensor indices_c = indices.is_contiguous() ? indices : at::empty(indices.sizes(), indices.options());
+
if (input.ndimension() == 3) {
int64_t sizeD = input.size(0);
int64_t isizeH = input.size(1);
@@ -223,8 +226,8 @@ const Tensor& indices) {
AT_DISPATCH_FLOATING_TYPES_AND2(
kHalf, kBFloat16, input.scalar_type(), "adaptive_max_pool2d_cuda", [&] {
scalar_t* input_data = input.data_ptr<scalar_t>();
- scalar_t* output_data = output.data_ptr<scalar_t>();
- int64_t* indices_data = indices.data_ptr<int64_t>();
+ scalar_t* output_data = output_c.data_ptr<scalar_t>();
+ int64_t* indices_data = indices_c.data_ptr<int64_t>();
// cuda blocks & threads:
int blocksH = (int)(16L / sizeD);
@@ -268,8 +271,8 @@ const Tensor& indices) {
"adaptive_max_pool2d_cuda",
[&] {
scalar_t* input_data = input_.data_ptr<scalar_t>();
- scalar_t* output_data = output.data_ptr<scalar_t>();
- int64_t* indices_data = indices.data_ptr<int64_t>();
+ scalar_t* output_data = output_c.data_ptr<scalar_t>();
+ int64_t* indices_data = indices_c.data_ptr<int64_t>();
// cuda blocks & threads:
int blocksH = (int)(16L / sizeD);
@@ -296,6 +299,13 @@ const Tensor& indices) {
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
}
+
+ if (!output.is_contiguous()) {
+ output.copy_(output_c);
+ }
+ if (!indices.is_contiguous()) {
+ indices.copy_(indices_c);
+ }
}
TORCH_IMPL_FUNC(adaptive_max_pool2d_backward_out_cuda)
@@ -322,7 +332,9 @@ TORCH_IMPL_FUNC(adaptive_max_pool2d_backward_out_cuda)
bool atomic =
true; // suboptimal, but without atomic it doesn't pass the tests
- Tensor gradOutput_ = gradOutput.contiguous();
+ const at::Tensor gradOutput_ = gradOutput.contiguous();
+ const at::Tensor indices_ = indices.contiguous();
+ const at::Tensor gradInput_c = gradInput.is_contiguous() ? gradInput : at::empty(gradInput.sizes(), gradInput.options());
if (input.ndimension() == 3) {
int64_t sizeD = input.size(0);
@@ -334,7 +346,7 @@ TORCH_IMPL_FUNC(adaptive_max_pool2d_backward_out_cuda)
// bool atomic = (isizeH%osizeH != 0) || (isizeW%osizeW != 0);
- gradInput.zero_();
+ gradInput_c.zero_();
AT_DISPATCH_FLOATING_TYPES_AND2(
kHalf,
@@ -342,9 +354,9 @@ TORCH_IMPL_FUNC(adaptive_max_pool2d_backward_out_cuda)
input.scalar_type(),
"adaptive_max_pool2d_backward_cuda",
[&] {
- scalar_t* gradInput_data = gradInput.data_ptr<scalar_t>();
+ scalar_t* gradInput_data = gradInput_c.data_ptr<scalar_t>();
scalar_t* gradOutput_data = gradOutput_.data_ptr<scalar_t>();
- int64_t* indices_data = indices.data_ptr<int64_t>();
+ int64_t* indices_data = indices_.data_ptr<int64_t>();
// cuda blocks & threads:
int blocksH = (int)(16L / sizeD);
@@ -393,7 +405,7 @@ TORCH_IMPL_FUNC(adaptive_max_pool2d_backward_out_cuda)
int64_t osizeH = gradOutput_.size(2);
int64_t osizeW = gradOutput_.size(3);
- gradInput.zero_();
+ gradInput_c.zero_();
// bool atomic = (isizeH%osizeH != 0) || (isizeW%osizeW != 0);
@@ -403,9 +415,9 @@ TORCH_IMPL_FUNC(adaptive_max_pool2d_backward_out_cuda)
input.scalar_type(),
"adaptive_max_pool2d_backward_cuda",
[&] {
- scalar_t* gradInput_data = gradInput.data_ptr<scalar_t>();
+ scalar_t* gradInput_data = gradInput_c.data_ptr<scalar_t>();
scalar_t* gradOutput_data = gradOutput_.data_ptr<scalar_t>();
- int64_t* indices_data = indices.data_ptr<int64_t>();
+ int64_t* indices_data = indices_.data_ptr<int64_t>();
// cuda blocks & threads:
int blocksH = (int)(16L / sizeD);
@@ -446,6 +458,10 @@ TORCH_IMPL_FUNC(adaptive_max_pool2d_backward_out_cuda)
}
});
}
+
+ if (!gradInput.is_contiguous()) {
+ gradInput.copy_(gradInput_c);
+ }
}
} // at::native
} // at
diff --git a/test/test_nn.py b/test/test_nn.py
index b6dd466faa..147e84d4fd 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -14622,7 +14622,6 @@ class TestNNDeviceType(NNTestCase):
self.assertEqual(a_cuda.grad, a_cpu.grad)
- @onlyCPU
@dtypes(torch.float, torch.double)
def test_adaptive_pooling_max_nhwc(self, device, dtype):
def helper(n, c, h, w, output_height, output_width, contig):