author     Tongzhou Wang <SsnL@users.noreply.github.com>    2018-04-12 21:11:44 -0400
committer  GitHub <noreply@github.com>    2018-04-12 21:11:44 -0400
commit     56563a0a79625aa965a3d197408510a7964cb862 (patch)
tree       ff8e07854bd16c6f5be8a20341856bc2e9d3af4c /aten/src
parent     be86500244ed57fc83ccbdd1b815ead0173b1b80 (diff)
Use THC allocation for CUFFT workspace (#6568)
* use THC allocation for CUFFT
* use auto& instead
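The commit below follows the standard cuFFT pattern for caller-managed workspaces: disable cuFFT's internal allocation, let plan creation report the required workspace size, allocate that buffer yourself, and attach it with cufftSetWorkArea. As a rough, standalone sketch of that pattern (not the PyTorch code itself), the snippet below uses plain cudaMalloc where the commit uses THC's caching allocator, and the transform size, batch count, and 1-D C2C plan are illustrative choices only.

// sketch.cu - minimal caller-allocated cuFFT workspace example (assumed setup, not PyTorch's)
#include <cufft.h>
#include <cuda_runtime.h>
#include <cstdio>
#include <cstdlib>

#define CUDA_CHECK(expr)  do { cudaError_t e_ = (expr); \
  if (e_ != cudaSuccess) { fprintf(stderr, "CUDA error %d at %s:%d\n", e_, __FILE__, __LINE__); exit(1); } } while (0)
#define CUFFT_CHECK(expr) do { cufftResult r_ = (expr); \
  if (r_ != CUFFT_SUCCESS) { fprintf(stderr, "cuFFT error %d at %s:%d\n", r_, __FILE__, __LINE__); exit(1); } } while (0)

int main() {
  const int n = 1024, batch = 4;   // illustrative transform size and batch count

  cufftHandle plan;
  CUFFT_CHECK(cufftCreate(&plan));

  // Run the transform on a caller-chosen stream (the commit uses the current THC stream).
  cudaStream_t stream;
  CUDA_CHECK(cudaStreamCreate(&stream));
  CUFFT_CHECK(cufftSetStream(plan, stream));

  // Disable cuFFT's own workspace allocation so we can supply the buffer ourselves.
  CUFFT_CHECK(cufftSetAutoAllocation(plan, /* autoAllocate */ 0));

  // Plan creation reports how many workspace bytes this transform needs.
  size_t ws_size = 0;
  CUFFT_CHECK(cufftMakePlan1d(plan, n, CUFFT_C2C, batch, &ws_size));

  // Allocate the workspace (cudaMalloc here; THC's caching allocator in the commit)
  // and hand it to the plan.
  void* ws = nullptr;
  CUDA_CHECK(cudaMalloc(&ws, ws_size));
  CUFFT_CHECK(cufftSetWorkArea(plan, ws));

  // Execute an in-place complex-to-complex forward FFT on zero-initialized device data.
  cufftComplex* data = nullptr;
  CUDA_CHECK(cudaMalloc(&data, sizeof(cufftComplex) * n * batch));
  CUDA_CHECK(cudaMemsetAsync(data, 0, sizeof(cufftComplex) * n * batch, stream));
  CUFFT_CHECK(cufftExecC2C(plan, data, data, CUFFT_FORWARD));
  CUDA_CHECK(cudaStreamSynchronize(stream));

  // The workspace must outlive execution; clean up only after synchronizing.
  CUFFT_CHECK(cufftDestroy(plan));
  CUDA_CHECK(cudaFree(data));
  CUDA_CHECK(cudaFree(ws));
  CUDA_CHECK(cudaStreamDestroy(stream));
  return 0;
}

Routing the workspace through the framework's allocator (as the actual diff does) lets the buffer come from the same caching pool as ordinary tensors, instead of a separate cudaMalloc made internally by cuFFT.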
Diffstat (limited to 'aten/src')
-rw-r--r--  aten/src/ATen/native/cuda/SpectralOps.cu  22
1 file changed, 15 insertions, 7 deletions
diff --git a/aten/src/ATen/native/cuda/SpectralOps.cu b/aten/src/ATen/native/cuda/SpectralOps.cu
index a2811a68b1..61e159e183 100644
--- a/aten/src/ATen/native/cuda/SpectralOps.cu
+++ b/aten/src/ATen/native/cuda/SpectralOps.cu
@@ -309,9 +309,6 @@ Tensor _fft_cufft(const Tensor& self, int64_t signal_ndim,
inembed.begin()); // begin of output
}
}
-
- CufftHandle plan;
- size_t ws = 0;
cudaDataType itype, otype, exec_type;
if (input.type().scalarType() == ScalarType::Float) {
itype = complex_input ? CUDA_C_32F : CUDA_R_32F;
@@ -335,6 +332,17 @@ Tensor _fft_cufft(const Tensor& self, int64_t signal_ndim,
// set output
auto output = input.type().tensor(output_sizes);
+ // create plan
+ CufftHandle plan;
+ size_t ws_size = 0;
+ auto& ctx = at::globalContext();
+
+ // set to current stream
+ CUFFT_CHECK(cufftSetStream(plan.get(), ctx.getCurrentCUDAStream()));
+
+ // disable auto allocation of workspace to use THC allocator
+ CUFFT_CHECK(cufftSetAutoAllocation(plan.get(), /* autoAllocate */ 0));
+
// make plan
if (simple_layout) {
// If with unit-stride, we tell cuFFT by setting inembed == onembed == NULL.
@@ -345,7 +353,7 @@ Tensor _fft_cufft(const Tensor& self, int64_t signal_ndim,
CUFFT_CHECK(cufftXtMakePlanMany(plan.get(), signal_ndim, signal_sizes.data(),
/* inembed */ nullptr, /* base_istride */ 1, /* idist */ 1, itype,
/* onembed */ nullptr, /* base_ostride */ 1, /* odist */ 1, otype,
- batch, &ws, exec_type));
+ batch, &ws_size, exec_type));
} else {
// set idist (stride at batch dim)
long long int idist = complex_input ? input.stride(0) >> 1 : input.stride(0);
@@ -366,11 +374,11 @@ Tensor _fft_cufft(const Tensor& self, int64_t signal_ndim,
CUFFT_CHECK(cufftXtMakePlanMany(plan.get(), signal_ndim, signal_sizes.data(),
inembed.data(), base_istride, idist, itype,
onembed.data(), base_ostride, odist, otype,
- batch, &ws, exec_type));
+ batch, &ws_size, exec_type));
}
- // set to current stream
- CUFFT_CHECK(cufftSetStream(plan.get(), at::globalContext().getCurrentCUDAStream()));
+ auto ws = ctx.getType(at::Backend::CUDA, at::ScalarType::Byte).tensor({ static_cast<int64_t>(ws_size) });
+ CUFFT_CHECK(cufftSetWorkArea(plan.get(), ws.data_ptr()));
// run
CUFFT_CHECK(cufftXtExec(plan.get(), input.data_ptr(), output.data_ptr(),