diff options
author | Tongzhou Wang <SsnL@users.noreply.github.com> | 2018-04-12 21:11:44 -0400 |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-04-12 21:11:44 -0400 |
commit | 56563a0a79625aa965a3d197408510a7964cb862 (patch) | |
tree | ff8e07854bd16c6f5be8a20341856bc2e9d3af4c /aten/src | |
parent | be86500244ed57fc83ccbdd1b815ead0173b1b80 (diff) | |
download | pytorch-56563a0a79625aa965a3d197408510a7964cb862.tar.gz pytorch-56563a0a79625aa965a3d197408510a7964cb862.tar.bz2 pytorch-56563a0a79625aa965a3d197408510a7964cb862.zip |
Use THC allocation for CUFFT workspace (#6568)
* use THC allocation for CUFFT
* use auto& instead
Diffstat (limited to 'aten/src')
-rw-r--r-- | aten/src/ATen/native/cuda/SpectralOps.cu | 22 |
1 files changed, 15 insertions, 7 deletions
```diff
diff --git a/aten/src/ATen/native/cuda/SpectralOps.cu b/aten/src/ATen/native/cuda/SpectralOps.cu
index a2811a68b1..61e159e183 100644
--- a/aten/src/ATen/native/cuda/SpectralOps.cu
+++ b/aten/src/ATen/native/cuda/SpectralOps.cu
@@ -309,9 +309,6 @@ Tensor _fft_cufft(const Tensor& self, int64_t signal_ndim,
                 inembed.begin()); // begin of output
     }
   }
-
-  CufftHandle plan;
-  size_t ws = 0;
   cudaDataType itype, otype, exec_type;
   if (input.type().scalarType() == ScalarType::Float) {
     itype = complex_input ? CUDA_C_32F : CUDA_R_32F;
@@ -335,6 +332,17 @@ Tensor _fft_cufft(const Tensor& self, int64_t signal_ndim,
   // set output
   auto output = input.type().tensor(output_sizes);
 
+  // create plan
+  CufftHandle plan;
+  size_t ws_size = 0;
+  auto& ctx = at::globalContext();
+
+  // set to current stream
+  CUFFT_CHECK(cufftSetStream(plan.get(), ctx.getCurrentCUDAStream()));
+
+  // disable auto allocation of workspace to use THC allocator
+  CUFFT_CHECK(cufftSetAutoAllocation(plan.get(), /* autoAllocate */ 0));
+
   // make plan
   if (simple_layout) {
     // If with unit-stride, we tell cuFFT by setting inembed == onembed == NULL.
@@ -345,7 +353,7 @@ Tensor _fft_cufft(const Tensor& self, int64_t signal_ndim,
     CUFFT_CHECK(cufftXtMakePlanMany(plan.get(), signal_ndim, signal_sizes.data(),
       /* inembed */ nullptr, /* base_istride */ 1, /* idist */ 1, itype,
       /* onembed */ nullptr, /* base_ostride */ 1, /* odist */ 1, otype,
-      batch, &ws, exec_type));
+      batch, &ws_size, exec_type));
   } else {
     // set idist (stride at batch dim)
     long long int idist = complex_input ? input.stride(0) >> 1 : input.stride(0);
@@ -366,11 +374,11 @@ Tensor _fft_cufft(const Tensor& self, int64_t signal_ndim,
     CUFFT_CHECK(cufftXtMakePlanMany(plan.get(), signal_ndim, signal_sizes.data(),
       inembed.data(), base_istride, idist, itype,
       onembed.data(), base_ostride, odist, otype,
-      batch, &ws, exec_type));
+      batch, &ws_size, exec_type));
   }
 
-  // set to current stream
-  CUFFT_CHECK(cufftSetStream(plan.get(), at::globalContext().getCurrentCUDAStream()));
+  auto ws = ctx.getType(at::Backend::CUDA, at::ScalarType::Byte).tensor({ static_cast<int64_t>(ws_size) });
+  CUFFT_CHECK(cufftSetWorkArea(plan.get(), ws.data_ptr()));
 
   // run
   CUFFT_CHECK(cufftXtExec(plan.get(), input.data_ptr(), output.data_ptr(),
```