| field | value | date |
|---|---|---|
| author | Syed Tousif Ahmed <syed.ahmed.emails@gmail.com> | 2019-02-01 12:38:15 -0800 |
| committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2019-02-01 13:07:36 -0800 |
| commit | 6d373c02ef3b37375c04cd058e418e4124000513 | |
| tree | 308c92884a5c3f21bb9cffadde5de0c557492723 /aten | |
| parent | 638dbe4b4646ea547e483433002d5ba71441da02 | |
Revert "Fixes selection of cuDNN algorithm (#15881)" (#16484)
Summary:
There is a regression in cudnnGet*_v7 that causes a slowdown in resnet50 training. I am opening a bug with the cuDNN team about this.

This reverts commit 38374468832e307ca741901870914857a836dd5d.
ezyang :crying_cat_face:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16484
Differential Revision: D13924755
Pulled By: soumith
fbshipit-source-id: 8c719345fc443f1289539bfae630eea9224ba4a5
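For readers skimming the revert: the regression described above is in the cuDNN 7 heuristic query (`cudnnGet*_v7`), which ranks candidate algorithms and their math types; this revert returns the non-benchmark path to the older preference-based query. The sketch below is not part of this commit — the `pick_fwd_algo` wrapper is a hypothetical name, setup of the handle and descriptors is assumed, and status checks are omitted — it only contrasts the two forward-path calls involved:

```cpp
// Standalone sketch (illustrative, not from this patch): the two cuDNN 7.x
// ways of picking a forward algorithm without benchmarking. The reverted code
// used the _v7 heuristic (top branch), which is where the reported regression
// lives; this revert returns to the legacy preference-based query (bottom).
#include <cudnn.h>

cudnnConvolutionFwdAlgo_t pick_fwd_algo(
    cudnnHandle_t handle,
    cudnnTensorDescriptor_t xDesc,
    cudnnFilterDescriptor_t wDesc,
    cudnnConvolutionDescriptor_t convDesc,
    cudnnTensorDescriptor_t yDesc,
    bool use_v7_heuristic) {
  cudnnConvolutionFwdAlgo_t algo;
  if (use_v7_heuristic) {
    // cudnnGet*_v7 returns ranked (algo, mathType, time, memory) candidates;
    // callers take the best entry. This is the API the commit message says
    // regressed for resnet50-style convolutions.
    int returned = 0;
    cudnnConvolutionFwdAlgoPerf_t perf;
    cudnnGetConvolutionForwardAlgorithm_v7(
        handle, xDesc, wDesc, convDesc, yDesc,
        /*requestedAlgoCount=*/1, &returned, &perf);
    algo = perf.algo;
  } else {
    // Legacy query restored by this revert: a single algorithm for a
    // preference, with no per-candidate perf metadata attached.
    cudnnGetConvolutionForwardAlgorithm(
        handle, xDesc, wDesc, convDesc, yDesc,
        CUDNN_CONVOLUTION_FWD_PREFER_FASTEST,
        /*memoryLimitInBytes=*/0, &algo);
  }
  return algo;
}
```

Because the legacy query reports only an algorithm, the reverted code no longer needs to carry `cudnnConvolutionFwdAlgoPerf_t` around; this is why the diff below swaps the `BenchmarkCache` and `chooseAlgorithm` machinery from `*AlgoPerf_t` back to plain `*Algo_t`.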
Diffstat (limited to 'aten')
| mode | file | lines changed |
|---|---|---|
| -rw-r--r-- | aten/src/ATen/cudnn/Descriptors.h | 3 |
| -rw-r--r-- | aten/src/ATen/native/cudnn/Conv.cpp | 328 |

2 files changed, 143 insertions, 188 deletions
```diff
diff --git a/aten/src/ATen/cudnn/Descriptors.h b/aten/src/ATen/cudnn/Descriptors.h
index be00b9dd2b..6c3970e13d 100644
--- a/aten/src/ATen/cudnn/Descriptors.h
+++ b/aten/src/ATen/cudnn/Descriptors.h
@@ -155,6 +155,9 @@ struct AT_CUDA_API ConvolutionDescriptor
     AT_CUDNN_CHECK(cudnnSetConvolutionNdDescriptor(mut_desc(), dim, pad, stride, upscale,
                                           CUDNN_CROSS_CORRELATION, mathType));
     AT_CUDNN_CHECK(cudnnSetConvolutionGroupCount(mut_desc(), groups));
+    AT_CUDNN_CHECK(cudnnSetConvolutionMathType(mut_desc(), CUDNN_DEFAULT_MATH));
+    if(dataType == CUDNN_DATA_HALF)
+      AT_CUDNN_CHECK(cudnnSetConvolutionMathType(mut_desc(), CUDNN_TENSOR_OP_MATH));
   }
 };
 
diff --git a/aten/src/ATen/native/cudnn/Conv.cpp b/aten/src/ATen/native/cudnn/Conv.cpp
index 2fb265f3b5..f0923a6930 100644
--- a/aten/src/ATen/native/cudnn/Conv.cpp
+++ b/aten/src/ATen/native/cudnn/Conv.cpp
@@ -94,25 +94,6 @@ std::tuple<at::Tensor,at::Tensor,at::Tensor> cudnn_convolution_transpose_backward(
 #include <stdint.h>
 #include <unordered_map>
 
-// Note [chooseAlgorithm doesn't respect mathType]
-// You might be wondering, why are we calling cudnnSetConvolutionMathType after
-// calling chooseAlgorithm...
-// Turns out, the mathType returned by the chooseAlgorithm can be different
-// from what we set in the descriptor and hence, we have to explicitly update it
-// after the chooseAlgorithm has found the best pair of algorithm+mathType.
-// Otherwise, even though we'll be calling cudnnConvolutionForward with the
-// fastest algorithm, under the hood, cudnn will run it with the slower kernel
-// since it sees fastest algorithm combination with a sub optimal mathType.
-
-// Note [cudnnSetConvolutionMathType cannot be called in descriptor]
-// When cudnnSetConvolutionMathType is called before cudnnGetConvolutionForwardAlgorithm_v7,
-// cudnnGet finds an algorithm based on the mathType set by cudnnSetConvolutionMathType.
-// That is, if we call cudnnSetConvolutionMathType in the setter of the descriptor
-// (to have some default values, e.g. CUDNN_TENSOR_OP when fp16), cudnnGet*_v7 returns
-// algo1 with CUDNN_TENSOR_OP math type, instead of not caring about what was set by
-// cudnnSetConvolutionMathType before it (and returning algo1 with CUDNN_DEFAULT_MATH
-// which is performant). A bug has been filed internally at NVIDIA.
-
 namespace at { namespace native {
 
 // TODO: Go through all the checking code again and make sure
@@ -359,9 +340,9 @@ struct BenchmarkCache {
   }
 };
 
-BenchmarkCache<cudnnConvolutionFwdAlgoPerf_t> fwd_algos;
-BenchmarkCache<cudnnConvolutionBwdDataAlgoPerf_t> bwd_data_algos;
-BenchmarkCache<cudnnConvolutionBwdFilterAlgoPerf_t> bwd_filter_algos;
+BenchmarkCache<cudnnConvolutionFwdAlgo_t> fwd_algos;
+BenchmarkCache<cudnnConvolutionBwdDataAlgo_t> bwd_data_algos;
+BenchmarkCache<cudnnConvolutionBwdFilterAlgo_t> bwd_filter_algos;
 
 // TODO: Stop manually allocating CUDA memory; allocate an ATen byte
 // tensor instead.
@@ -382,7 +363,7 @@ struct Workspace {
   void* data;
 };
 
-template<typename perf_t>
+template<typename algo_t>
 struct algorithm_search {
 };
 
@@ -471,14 +452,14 @@ perf_t getBestAlgorithm(perf_t *perfResults, bool deterministic, int n_algo) {
 }
 
 template<>
-struct algorithm_search<cudnnConvolutionFwdAlgoPerf_t> {
+struct algorithm_search<cudnnConvolutionFwdAlgo_t> {
   using perf_t = cudnnConvolutionFwdAlgoPerf_t;
   using algo_t = cudnnConvolutionFwdAlgo_t;
 
   static constexpr auto DEFAULT_ALGO = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM;
-  static BenchmarkCache<perf_t>& cache() { return fwd_algos; }
+  static BenchmarkCache<algo_t>& cache() { return fwd_algos; }
 
-  static perf_t findAlgorithm(const cudnnDataType_t dataType, const ConvolutionArgs& args, bool benchmark) {
+  static perf_t findAlgorithm(const ConvolutionArgs& args) {
     static const algo_t algos[] = {
         CUDNN_CONVOLUTION_FWD_ALGO_GEMM,
         CUDNN_CONVOLUTION_FWD_ALGO_FFT,
@@ -494,32 +475,36 @@ struct algorithm_search<cudnnConvolutionFwdAlgoPerf_t> {
                   "Missing cuDNN convolution forward algorithms");
     int perf_count;
     std::unique_ptr<perf_t[]> perf_results(new perf_t[num_algos]);
-    if (!benchmark) {
-      AT_CUDNN_CHECK(cudnnGetConvolutionForwardAlgorithm_v7(
-          args.handle,
-          args.idesc.desc(),
-          args.wdesc.desc(),
-          args.cdesc.desc(),
-          args.odesc.desc(),
-          num_algos,
-          &perf_count,
-          perf_results.get()));
-    } else {
-      size_t max_ws_size = getMaxWorkspaceSize(args, algos, num_algos);
-      Workspace ws(max_ws_size);
-      AT_CUDNN_CHECK(cudnnFindConvolutionForwardAlgorithmEx(
-          args.handle,
-          args.idesc.desc(), args.input.data_ptr(),
-          args.wdesc.desc(), args.weight.data_ptr(),
-          args.cdesc.desc(),
-          args.odesc.desc(), args.output.data_ptr(),
-          num_algos,
-          &perf_count,
-          perf_results.get(),
-          ws.data,
-          ws.size));
-    }
-    return getBestAlgorithm<perf_t>(perf_results.get(), args.params.deterministic, perf_count);
+    size_t max_ws_size = getMaxWorkspaceSize(args, algos, num_algos);
+    Workspace ws(max_ws_size);
+    AT_CUDNN_CHECK(cudnnFindConvolutionForwardAlgorithmEx(
+        args.handle,
+        args.idesc.desc(), args.input.data_ptr(),
+        args.wdesc.desc(), args.weight.data_ptr(),
+        args.cdesc.desc(),
+        args.odesc.desc(), args.output.data_ptr(),
+        num_algos,
+        &perf_count,
+        perf_results.get(),
+        ws.data,
+        ws.size));
+    return getBestAlgorithm(perf_results.get(), args.params.deterministic, perf_count);
+  }
+
+  static void getAlgorithm(
+      const ConvolutionArgs& args,
+      algo_t* algo)
+  {
+    cudnnConvolutionFwdPreference_t pref = CUDNN_CONVOLUTION_FWD_PREFER_FASTEST;
+    AT_CUDNN_CHECK(cudnnGetConvolutionForwardAlgorithm(
+        args.handle,
+        args.idesc.desc(),
+        args.wdesc.desc(),
+        args.cdesc.desc(),
+        args.odesc.desc(),
+        pref,
+        0,
+        algo));
   }
 
   static void getWorkspaceSize(
@@ -538,14 +523,14 @@ struct algorithm_search<cudnnConvolutionFwdAlgoPerf_t> {
 };
 
 template<>
-struct algorithm_search<cudnnConvolutionBwdDataAlgoPerf_t> {
+struct algorithm_search<cudnnConvolutionBwdDataAlgo_t> {
   using perf_t = cudnnConvolutionBwdDataAlgoPerf_t;
   using algo_t = cudnnConvolutionBwdDataAlgo_t;
 
   static constexpr auto DEFAULT_ALGO = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1;
-  static BenchmarkCache<perf_t>& cache() { return bwd_data_algos; }
+  static BenchmarkCache<algo_t>& cache() { return bwd_data_algos; }
 
-  static perf_t findAlgorithm(const cudnnDataType_t dataType, const ConvolutionArgs& args, bool benchmark) {
+  static perf_t findAlgorithm(const ConvolutionArgs& args) {
     static const algo_t algos[] = {
         CUDNN_CONVOLUTION_BWD_DATA_ALGO_0,
         CUDNN_CONVOLUTION_BWD_DATA_ALGO_1,
@@ -559,32 +544,32 @@ struct algorithm_search<cudnnConvolutionBwdDataAlgoPerf_t> {
                   "Missing cuDNN convolution backward data algorithms.");
     int perf_count;
     std::unique_ptr<perf_t[]> perf_results(new perf_t[num_algos]);
-    if (!benchmark) {
-      AT_CUDNN_CHECK(cudnnGetConvolutionBackwardDataAlgorithm_v7(
-          args.handle,
-          args.wdesc.desc(),
-          args.odesc.desc(),
-          args.cdesc.desc(),
-          args.idesc.desc(),
-          num_algos,
-          &perf_count,
-          perf_results.get()));
-    } else {
-      size_t max_ws_size = getMaxWorkspaceSize(args, algos, num_algos);
-      Workspace ws(max_ws_size);
-      AT_CUDNN_CHECK(cudnnFindConvolutionBackwardDataAlgorithmEx(
-          args.handle,
-          args.wdesc.desc(), args.weight.data_ptr(),
-          args.odesc.desc(), args.output.data_ptr(),
-          args.cdesc.desc(),
-          args.idesc.desc(), args.input.data_ptr(),
-          num_algos,
-          &perf_count,
-          perf_results.get(),
-          ws.data,
-          ws.size));
-    }
-    return getBestAlgorithm<perf_t>(perf_results.get(), args.params.deterministic, perf_count);
+    size_t max_ws_size = getMaxWorkspaceSize(args, algos, num_algos);
+    Workspace ws(max_ws_size);
+    AT_CUDNN_CHECK(cudnnFindConvolutionBackwardDataAlgorithmEx(
+        args.handle,
+        args.wdesc.desc(), args.weight.data_ptr(),
+        args.odesc.desc(), args.output.data_ptr(),
+        args.cdesc.desc(),
+        args.idesc.desc(), args.input.data_ptr(),
+        num_algos,
+        &perf_count,
+        perf_results.get(),
+        ws.data,
+        ws.size));
+    return getBestAlgorithm(perf_results.get(), args.params.deterministic, perf_count);
+  }
+
+  static void getAlgorithm(const ConvolutionArgs& args, algo_t* algo) {
+    AT_CUDNN_CHECK(cudnnGetConvolutionBackwardDataAlgorithm(
+        args.handle,
+        args.wdesc.desc(),
+        args.odesc.desc(),
+        args.cdesc.desc(),
+        args.idesc.desc(),
+        CUDNN_CONVOLUTION_BWD_DATA_PREFER_FASTEST,
+        0,
+        algo));
   }
 
   static void getWorkspaceSize(
@@ -603,15 +588,15 @@ struct algorithm_search<cudnnConvolutionBwdDataAlgoPerf_t> {
 };
 
 template<>
-struct algorithm_search<cudnnConvolutionBwdFilterAlgoPerf_t> {
+struct algorithm_search<cudnnConvolutionBwdFilterAlgo_t> {
   using perf_t = cudnnConvolutionBwdFilterAlgoPerf_t;
   using algo_t = cudnnConvolutionBwdFilterAlgo_t;
 
   static constexpr auto DEFAULT_ALGO = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1;
-  static BenchmarkCache<perf_t>& cache() { return bwd_filter_algos; }
+  static BenchmarkCache<algo_t>& cache() { return bwd_filter_algos; }
 
-  static perf_t findAlgorithm(const cudnnDataType_t dataType, const ConvolutionArgs& args, bool benchmark) {
+  static perf_t findAlgorithm(const ConvolutionArgs& args) {
     static const algo_t algos[] = {
         CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0,
         CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1,
@@ -625,35 +610,37 @@ struct algorithm_search<cudnnConvolutionBwdFilterAlgoPerf_t> {
     static_assert(sizeof(algos) / sizeof(algos[0]) == num_algos,
                   "Missing cuDNN convolution backward filter algorithms.");
     std::unique_ptr<perf_t[]> perf_results(new perf_t[num_algos]);
+    size_t max_ws_size = getMaxWorkspaceSize(args, algos, num_algos);
    int perf_count;
-    if (!benchmark) {
-      AT_CUDNN_CHECK(cudnnGetConvolutionBackwardFilterAlgorithm_v7(
-          args.handle,
-          args.idesc.desc(),
-          args.odesc.desc(),
-          args.cdesc.desc(),
-          args.wdesc.desc(),
-          num_algos,
-          &perf_count,
-          perf_results.get()));
-    } else {
-      size_t max_ws_size = getMaxWorkspaceSize(args, algos, num_algos);
-      Workspace ws(max_ws_size);
-      AT_CUDNN_CHECK(cudnnFindConvolutionBackwardFilterAlgorithmEx(
-          args.handle,
-          args.idesc.desc(), args.input.data_ptr(),
-          args.odesc.desc(), args.output.data_ptr(),
-          args.cdesc.desc(),
-          args.wdesc.desc(), args.weight.data_ptr(),
-          num_algos,
-          &perf_count,
-          perf_results.get(),
-          ws.data,
-          ws.size));
-    }
+    Workspace ws(max_ws_size);
+
+    AT_CUDNN_CHECK(cudnnFindConvolutionBackwardFilterAlgorithmEx(
+        args.handle,
+        args.idesc.desc(), args.input.data_ptr(),
+        args.odesc.desc(), args.output.data_ptr(),
+        args.cdesc.desc(),
+        args.wdesc.desc(), args.weight.data_ptr(),
+        num_algos,
+        &perf_count,
+        perf_results.get(),
+        ws.data,
+        ws.size));
     return getBestAlgorithm<perf_t>(perf_results.get(), args.params.deterministic, perf_count);
   }
 
+  static void getAlgorithm(const ConvolutionArgs& args, algo_t* algo) {
+    AT_CUDNN_CHECK(cudnnGetConvolutionBackwardFilterAlgorithm(
+        args.handle,
+        args.idesc.desc(),
+        args.odesc.desc(),
+        args.cdesc.desc(),
+        args.wdesc.desc(),
+        CUDNN_CONVOLUTION_BWD_FILTER_PREFER_FASTEST,
+        0,
+        algo)
+    );
+  }
+
   static void getWorkspaceSize(const ConvolutionArgs& args, algo_t algo, size_t* workspaceSize) {
     AT_CUDNN_CHECK(cudnnGetConvolutionBackwardFilterWorkspaceSize(
@@ -667,90 +654,70 @@ struct algorithm_search<cudnnConvolutionBwdFilterAlgoPerf_t> {
   }
 };
 
-template<typename perf_t>
-void findAlgorithm(const cudnnDataType_t dataType, const ConvolutionArgs& args, bool benchmark, perf_t* algoPerf) {
-  using search = algorithm_search<perf_t>;
+template<typename algo_t>
+void findAlgorithm(const ConvolutionArgs& args, bool benchmark, algo_t* algo) {
+  using search = algorithm_search<algo_t>;
   auto& cache = search::cache();
 
-  if (cache.find(args.params, algoPerf)) {
+  if (cache.find(args.params, algo)) {
     return;
   }
 
   if (args.params.deterministic && !benchmark) {
-    algoPerf->algo = search::DEFAULT_ALGO;
-    // Note [cudnnSetConvolutionMathType cannot be called in descriptor]
-    if (dataType == CUDNN_DATA_HALF) {
-      algoPerf->mathType = CUDNN_TENSOR_OP_MATH;
-    } else {
-      algoPerf->mathType = CUDNN_DEFAULT_MATH;
-    }
-    search::getWorkspaceSize(args, algoPerf->algo, &(algoPerf->memory));
+    *algo = search::DEFAULT_ALGO;
     return;
   }
 
-  if (benchmark) {
-    if (cache.find(args.params, algoPerf)) {
-      // re-check cache since another thread may have benchmarked the algorithm
-      return;
-    }
-  }
+  if (!benchmark) {
+    search::getAlgorithm(args, algo);
+    return;
+  }
 
-  auto perfResults = search::findAlgorithm(dataType, args, benchmark);
+  if (cache.find(args.params, algo)) {
+    // re-check cache since another thread may have benchmarked the algorithm
+    return;
+  }
+
+  auto perfResults = search::findAlgorithm(args);
   // for deterministic algo, look at all the perf results and return the best
   // deterministic algo
   if (perfResults.status == CUDNN_STATUS_SUCCESS &&
      !(args.params.deterministic && perfResults.determinism != CUDNN_DETERMINISTIC)) {
-
-    // if benchmarking, map the original params with the found algo+math type for re-use
-    if (benchmark) {
-      cache.insert(args.params, perfResults);
-
-      // Free the cached blocks in our caching allocator. They are
-      // needed here because the above benchmarking uses a huge amount of memory,
-      // e.g. a few GBs.
-      c10::cuda::CUDACachingAllocator::emptyCache();
-    }
-
-    *algoPerf = perfResults;
+    *algo = perfResults.algo;
   } else {
-    algoPerf->algo = search::DEFAULT_ALGO;
-    // Note [cudnnSetConvolutionMathType cannot be called in descriptor]
-    if (dataType == CUDNN_DATA_HALF) {
-      algoPerf->mathType = CUDNN_TENSOR_OP_MATH;
-    } else {
-      algoPerf->mathType = CUDNN_DEFAULT_MATH;
-    }
-    search::getWorkspaceSize(args, algoPerf->algo, &(algoPerf->memory));
+    *algo = search::DEFAULT_ALGO;
  }
+  cache.insert(args.params, *algo);
+
+  // Free the cached blocks in our caching allocator. They are
+  // needed here because the above benchmarking uses a huge amount of memory,
+  // e.g. a few GBs.
+  c10::cuda::CUDACachingAllocator::emptyCache();
 }
 
-template<typename perf_t>
+template<typename algo_t>
 Workspace chooseAlgorithm(
-    const cudnnDataType_t dataType,
     const ConvolutionArgs& args,
     bool benchmark,
-    perf_t* algoPerf)
+    algo_t* algo)
 {
-  findAlgorithm(dataType, args, benchmark, algoPerf);
+  findAlgorithm(args, benchmark, algo);
 
-  using search = algorithm_search<perf_t>;
+  using search = algorithm_search<algo_t>;
+  size_t workspace_size;
+  search::getWorkspaceSize(args, *algo, &workspace_size);
  try {
-    return Workspace(algoPerf->memory);
+    return Workspace(workspace_size);
   } catch (const std::exception& e) {
     cudaGetLastError(); // clear OOM error
 
     // switch to default algorithm and record it in the cache to prevent
     // further OOM errors
-    algoPerf->algo = search::DEFAULT_ALGO;
-    // Note [cudnnSetConvolutionMathType cannot be called in descriptor]
-    if (dataType == CUDNN_DATA_HALF) {
-      algoPerf->mathType = CUDNN_TENSOR_OP_MATH;
-    } else {
-      algoPerf->mathType = CUDNN_DEFAULT_MATH;
-    }
-    search::getWorkspaceSize(args, algoPerf->algo, &(algoPerf->memory));
-    search::cache().insert(args.params, *algoPerf);
-    return Workspace(algoPerf->memory);
+    *algo = search::DEFAULT_ALGO;
+    search::cache().insert(args.params, *algo);
+
+    search::getWorkspaceSize(args, *algo, &workspace_size);
+    return Workspace(workspace_size);
   }
 }
 
@@ -844,13 +811,8 @@ void raw_cudnn_convolution_forward_out(
   // wasteful; we'd rather reuse the workspace. OTOH, legacy group
   // convolution support is already pretty slow, so this might not
   // matter. (This applies to raw_cudnn_convolution_backward_input as well.)
-  cudnnConvolutionFwdAlgoPerf_t fwdAlgPerf;
-  Workspace workspace = chooseAlgorithm(dataType, args, benchmark, &fwdAlgPerf);
-
-  // update convDesc mathType since cudnn now requires both algo + mathType to figure out
-  // whether to use Tensor cores or not
-  // See Note [chooseAlgorithm doesn't respect mathType]
-  AT_CUDNN_CHECK(cudnnSetConvolutionMathType(args.cdesc.mut_desc(), fwdAlgPerf.mathType));
+  cudnnConvolutionFwdAlgo_t fwdAlg;
+  Workspace workspace = chooseAlgorithm(args, benchmark, &fwdAlg);
 
   Constant one(dataType, 1);
   Constant zero(dataType, 0);
@@ -859,7 +821,7 @@ void raw_cudnn_convolution_forward_out(
     args.handle,
     &one, args.idesc.desc(), input.data_ptr(),
     args.wdesc.desc(), weight.data_ptr(),
-    args.cdesc.desc(), fwdAlgPerf.algo, workspace.data, workspace.size,
+    args.cdesc.desc(), fwdAlg, workspace.data, workspace.size,
     &zero, args.odesc.desc(), output.data_ptr()));
 }
 
@@ -968,13 +930,8 @@ void raw_cudnn_convolution_backward_input_out(
   args.odesc.set(grad_output);
   args.cdesc.set(dataType, grad_output.dim() - 2, args.params.padding, args.params.stride, args.params.dilation, args.params.groups);
 
-  cudnnConvolutionBwdDataAlgoPerf_t bwdDataAlgPerf;
-  Workspace workspace = chooseAlgorithm(dataType, args, benchmark, &bwdDataAlgPerf);
-
-  // update convDesc mathType since cudnn now requires both algo + mathType to figure out
-  // whether to use Tensor cores or not
-  // See Note [chooseAlgorithm doesn't respect mathType]
-  AT_CUDNN_CHECK(cudnnSetConvolutionMathType(args.cdesc.mut_desc(), bwdDataAlgPerf.mathType));
+  cudnnConvolutionBwdDataAlgo_t bwdDataAlg;
+  Workspace workspace = chooseAlgorithm(args, benchmark, &bwdDataAlg);
 
   Constant one(dataType, 1);
   Constant zero(dataType, 0);
@@ -983,7 +940,7 @@ void raw_cudnn_convolution_backward_input_out(
     args.handle,
     &one, args.wdesc.desc(), weight.data_ptr(),
     args.odesc.desc(), grad_output.data_ptr(),
-    args.cdesc.desc(), bwdDataAlgPerf.algo, workspace.data, workspace.size,
+    args.cdesc.desc(), bwdDataAlg, workspace.data, workspace.size,
     &zero, args.idesc.desc(), grad_input.data_ptr()));
 }
 
@@ -1109,13 +1066,8 @@ void raw_cudnn_convolution_backward_weight_out(
   args.odesc.set(grad_output);
   args.cdesc.set(dataType, input.dim() - 2, args.params.padding, args.params.stride, args.params.dilation, args.params.groups);
 
-  cudnnConvolutionBwdFilterAlgoPerf_t bwdFilterAlgPerf;
-  Workspace workspace = chooseAlgorithm(dataType, args, benchmark, &bwdFilterAlgPerf);
-
-  // update convDesc mathType since cudnn now requires both algo + mathType to figure out
-  // whether to use Tensor cores or not
-  // See Note [chooseAlgorithm doesn't respect mathType]
-  AT_CUDNN_CHECK(cudnnSetConvolutionMathType(args.cdesc.mut_desc(), bwdFilterAlgPerf.mathType));
+  cudnnConvolutionBwdFilterAlgo_t bwdFilterAlg;
+  Workspace workspace = chooseAlgorithm(args, benchmark, &bwdFilterAlg);
 
   Constant one(dataType, 1);
   Constant zero(dataType, 0);
@@ -1124,7 +1076,7 @@ void raw_cudnn_convolution_backward_weight_out(
     args.handle,
     &one, args.idesc.desc(), input.data_ptr(),
     args.odesc.desc(), grad_output.data_ptr(),
-    args.cdesc.desc(), bwdFilterAlgPerf.algo, workspace.data, workspace.size,
+    args.cdesc.desc(), bwdFilterAlg, workspace.data, workspace.size,
     &zero, args.wdesc.desc(), grad_weight.data_ptr()));
 }
```
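The two `Note` comments deleted at the top of Conv.cpp record why the reverted code had to re-apply the math type after algorithm selection: `cudnnGet*_v7` ranks (algorithm, mathType) pairs, and the chosen algorithm only runs at full speed if the convolution descriptor is then updated to the reported math type. Below is a minimal sketch of that interaction, assuming float data and pre-built descriptors; the helper name is hypothetical and status checks are omitted for brevity:

```cpp
// Sketch of the behavior the deleted notes describe (illustrative only, not
// from this patch): with the _v7 heuristics, the selected algorithm is only
// fast when the convolution descriptor's math type matches the mathType
// reported alongside it.
#include <cudnn.h>

void run_forward_with_v7_choice(
    cudnnHandle_t handle,
    cudnnTensorDescriptor_t xDesc, const void* x,
    cudnnFilterDescriptor_t wDesc, const void* w,
    cudnnConvolutionDescriptor_t convDesc,
    cudnnTensorDescriptor_t yDesc, void* y,
    void* workspace, size_t workspaceSize) {
  int returned = 0;
  cudnnConvolutionFwdAlgoPerf_t perf;
  cudnnGetConvolutionForwardAlgorithm_v7(
      handle, xDesc, wDesc, convDesc, yDesc,
      /*requestedAlgoCount=*/1, &returned, &perf);

  // Without this call, cuDNN may dispatch a slower kernel for perf.algo: the
  // algorithm was ranked under perf.mathType, but the descriptor still carries
  // whatever math type was set earlier ("chooseAlgorithm doesn't respect
  // mathType").
  cudnnSetConvolutionMathType(convDesc, perf.mathType);

  // Alpha/beta as floats, assuming float (or half with float compute) data.
  const float one = 1.f, zero = 0.f;
  cudnnConvolutionForward(handle, &one, xDesc, x, wDesc, w, convDesc,
                          perf.algo, workspace, workspaceSize,
                          &zero, yDesc, y);
}
```

Reverting to the legacy `cudnnGet*Algorithm` queries sidesteps this coupling entirely, which is why every `cudnnSetConvolutionMathType(..., *.mathType)` call site disappears in the diff and the descriptor setter in Descriptors.h once again picks the math type from the data type alone.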