author    Syed Tousif Ahmed <syed.ahmed.emails@gmail.com>  2019-02-01 12:38:15 -0800
committer Facebook Github Bot <facebook-github-bot@users.noreply.github.com>  2019-02-01 13:07:36 -0800
commit    6d373c02ef3b37375c04cd058e418e4124000513 (patch)
tree      308c92884a5c3f21bb9cffadde5de0c557492723 /aten
parent    638dbe4b4646ea547e483433002d5ba71441da02 (diff)
Revert "Fixes selection of cuDNN algorithm (#15881)" (#16484)
Summary:
There is a regression in cudnnGet*_v7 that causes a slowdown in resnet50 training. I am opening a bug with the cuDNN team about this.

This reverts commit 38374468832e307ca741901870914857a836dd5d.

ezyang :crying_cat_face:

Pull Request resolved: https://github.com/pytorch/pytorch/pull/16484
Differential Revision: D13924755
Pulled By: soumith
fbshipit-source-id: 8c719345fc443f1289539bfae630eea9224ba4a5
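[Editor's note] For context on what is being swapped: the legacy cudnnGetConvolutionForwardAlgorithm (restored by this revert) returns a single algorithm for a stated preference, while cudnnGetConvolutionForwardAlgorithm_v7 (being reverted) returns ranked algo+mathType pairs whose mathType the caller must copy back onto the convolution descriptor. A minimal sketch of the two call styles follows; pickLegacy/pickV7 are illustrative names, the handle and four descriptors are assumed already configured, and error checking is elided:

#include <cudnn.h>

// Legacy heuristic (what this revert restores): cuDNN hands back one
// algorithm for the requested preference; the math type stays whatever
// the convolution descriptor already carries.
cudnnConvolutionFwdAlgo_t pickLegacy(cudnnHandle_t handle,
                                     cudnnTensorDescriptor_t idesc,
                                     cudnnFilterDescriptor_t wdesc,
                                     cudnnConvolutionDescriptor_t cdesc,
                                     cudnnTensorDescriptor_t odesc) {
  cudnnConvolutionFwdAlgo_t algo;
  cudnnGetConvolutionForwardAlgorithm(
      handle, idesc, wdesc, cdesc, odesc,
      CUDNN_CONVOLUTION_FWD_PREFER_FASTEST,
      /*memoryLimitInBytes=*/0, &algo);
  return algo;
}

// v7 heuristic (what is being reverted): cuDNN returns ranked
// algo+mathType pairs, and the caller must write the winning mathType
// back onto the descriptor or dispatch can silently fall off the fast
// path -- the interaction the removed Notes in Conv.cpp describe.
cudnnConvolutionFwdAlgoPerf_t pickV7(cudnnHandle_t handle,
                                     cudnnTensorDescriptor_t idesc,
                                     cudnnFilterDescriptor_t wdesc,
                                     cudnnConvolutionDescriptor_t cdesc,
                                     cudnnTensorDescriptor_t odesc) {
  int returned = 0;
  cudnnConvolutionFwdAlgoPerf_t perf;
  cudnnGetConvolutionForwardAlgorithm_v7(
      handle, idesc, wdesc, cdesc, odesc,
      /*requestedAlgoCount=*/1, &returned, &perf);
  cudnnSetConvolutionMathType(cdesc, perf.mathType);
  return perf;
}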
Diffstat (limited to 'aten')
-rw-r--r--  aten/src/ATen/cudnn/Descriptors.h     3
-rw-r--r--  aten/src/ATen/native/cudnn/Conv.cpp  328
2 files changed, 143 insertions(+), 188 deletions(-)
diff --git a/aten/src/ATen/cudnn/Descriptors.h b/aten/src/ATen/cudnn/Descriptors.h
index be00b9dd2b..6c3970e13d 100644
--- a/aten/src/ATen/cudnn/Descriptors.h
+++ b/aten/src/ATen/cudnn/Descriptors.h
@@ -155,6 +155,9 @@ struct AT_CUDA_API ConvolutionDescriptor
AT_CUDNN_CHECK(cudnnSetConvolutionNdDescriptor(mut_desc(), dim, pad, stride, upscale,
CUDNN_CROSS_CORRELATION, mathType));
AT_CUDNN_CHECK(cudnnSetConvolutionGroupCount(mut_desc(), groups));
+ AT_CUDNN_CHECK(cudnnSetConvolutionMathType(mut_desc(), CUDNN_DEFAULT_MATH));
+ if(dataType == CUDNN_DATA_HALF)
+ AT_CUDNN_CHECK(cudnnSetConvolutionMathType(mut_desc(), CUDNN_TENSOR_OP_MATH));
}
};
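[Editor's note] The re-added lines above move math-type selection back to descriptor-construction time: default math everywhere, with an unconditional opt-in to Tensor Core math for half precision. This is safe once the legacy cudnnGet* calls are restored, because they do not report a mathType that could override the descriptor. As a standalone sketch of that policy (hypothetical helper name, error checking elided):

// Restores the pre-#15881 policy: the descriptor's math type is fixed
// when the descriptor is built, not patched up after algorithm selection.
static void setConvMathType(cudnnConvolutionDescriptor_t desc,
                            cudnnDataType_t dataType) {
  cudnnSetConvolutionMathType(desc, CUDNN_DEFAULT_MATH);
  if (dataType == CUDNN_DATA_HALF) {
    // FP16 convolutions request Tensor Core kernels up front.
    cudnnSetConvolutionMathType(desc, CUDNN_TENSOR_OP_MATH);
  }
}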
diff --git a/aten/src/ATen/native/cudnn/Conv.cpp b/aten/src/ATen/native/cudnn/Conv.cpp
index 2fb265f3b5..f0923a6930 100644
--- a/aten/src/ATen/native/cudnn/Conv.cpp
+++ b/aten/src/ATen/native/cudnn/Conv.cpp
@@ -94,25 +94,6 @@ std::tuple<at::Tensor,at::Tensor,at::Tensor> cudnn_convolution_transpose_backwar
#include <stdint.h>
#include <unordered_map>
-// Note [chooseAlgorithm doesn't respect mathType]
-// You might be wondering, why are we calling cudnnSetConvolutionMathType after
-// calling chooseAlgorithm...
-// Turns out, the mathType returned by the chooseAlgorithm can be different
-// from what we set in the descriptor and hence, we have to explicitly update it
-// after the chooseAlgorithm has found the best pair of algorithm+mathType.
-// Otherwise, even though we'll be calling cudnnConvolutionForward with the
-// fastest algorithm, under the hood, cudnn will run it with the slower kernel
-// since it sees fastest algorithm combination with a sub optimal mathType.
-
-// Note [cudnnSetConvolutionMathType cannot be called in descriptor]
-// When cudnnSetConvolutionMathType is called before cudnnGetConvolutionForwardAlgorithm_v7,
-// cudnnGet finds an algorithm based on the mathType set by cudnnSetConvolutionMathType.
-// That is, if we call cudnnSetConvolutionMathType in the setter of the descriptor
-// (to have some default values, e.g. CUDNN_TENSOR_OP when fp16), cudnnGet*_v7 returns
-// algo1 with CUDNN_TENSOR_OP math type, instead of not caring about what was set by
-// cudnnSetConvolutionMathType before it (and returning algo1 with CUDNN_DEFAULT_MATH
-// which is performant). A bug has been filed internally at NVIDIA.
-
namespace at { namespace native {
// TODO: Go through all the checking code again and make sure
@@ -359,9 +340,9 @@ struct BenchmarkCache {
}
};
-BenchmarkCache<cudnnConvolutionFwdAlgoPerf_t> fwd_algos;
-BenchmarkCache<cudnnConvolutionBwdDataAlgoPerf_t> bwd_data_algos;
-BenchmarkCache<cudnnConvolutionBwdFilterAlgoPerf_t> bwd_filter_algos;
+BenchmarkCache<cudnnConvolutionFwdAlgo_t> fwd_algos;
+BenchmarkCache<cudnnConvolutionBwdDataAlgo_t> bwd_data_algos;
+BenchmarkCache<cudnnConvolutionBwdFilterAlgo_t> bwd_filter_algos;
// TODO: Stop manually allocating CUDA memory; allocate an ATen byte
// tensor instead.
@@ -382,7 +363,7 @@ struct Workspace {
void* data;
};
-template<typename perf_t>
+template<typename algo_t>
struct algorithm_search {
};
@@ -471,14 +452,14 @@ perf_t getBestAlgorithm(perf_t *perfResults, bool deterministic, int n_algo) {
}
template<>
-struct algorithm_search<cudnnConvolutionFwdAlgoPerf_t> {
+struct algorithm_search<cudnnConvolutionFwdAlgo_t> {
using perf_t = cudnnConvolutionFwdAlgoPerf_t;
using algo_t = cudnnConvolutionFwdAlgo_t;
static constexpr auto DEFAULT_ALGO = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM;
- static BenchmarkCache<perf_t>& cache() { return fwd_algos; }
+ static BenchmarkCache<algo_t>& cache() { return fwd_algos; }
- static perf_t findAlgorithm(const cudnnDataType_t dataType, const ConvolutionArgs& args, bool benchmark) {
+ static perf_t findAlgorithm(const ConvolutionArgs& args) {
static const algo_t algos[] = {
CUDNN_CONVOLUTION_FWD_ALGO_GEMM,
CUDNN_CONVOLUTION_FWD_ALGO_FFT,
@@ -494,32 +475,36 @@ struct algorithm_search<cudnnConvolutionFwdAlgoPerf_t> {
"Missing cuDNN convolution forward algorithms");
int perf_count;
std::unique_ptr<perf_t[]> perf_results(new perf_t[num_algos]);
- if (!benchmark) {
- AT_CUDNN_CHECK(cudnnGetConvolutionForwardAlgorithm_v7(
- args.handle,
- args.idesc.desc(),
- args.wdesc.desc(),
- args.cdesc.desc(),
- args.odesc.desc(),
- num_algos,
- &perf_count,
- perf_results.get()));
- } else {
- size_t max_ws_size = getMaxWorkspaceSize(args, algos, num_algos);
- Workspace ws(max_ws_size);
- AT_CUDNN_CHECK(cudnnFindConvolutionForwardAlgorithmEx(
- args.handle,
- args.idesc.desc(), args.input.data_ptr(),
- args.wdesc.desc(), args.weight.data_ptr(),
- args.cdesc.desc(),
- args.odesc.desc(), args.output.data_ptr(),
- num_algos,
- &perf_count,
- perf_results.get(),
- ws.data,
- ws.size));
- }
- return getBestAlgorithm<perf_t>(perf_results.get(), args.params.deterministic, perf_count);
+ size_t max_ws_size = getMaxWorkspaceSize(args, algos, num_algos);
+ Workspace ws(max_ws_size);
+ AT_CUDNN_CHECK(cudnnFindConvolutionForwardAlgorithmEx(
+ args.handle,
+ args.idesc.desc(), args.input.data_ptr(),
+ args.wdesc.desc(), args.weight.data_ptr(),
+ args.cdesc.desc(),
+ args.odesc.desc(), args.output.data_ptr(),
+ num_algos,
+ &perf_count,
+ perf_results.get(),
+ ws.data,
+ ws.size));
+ return getBestAlgorithm(perf_results.get(), args.params.deterministic, perf_count);
+ }
+
+ static void getAlgorithm(
+ const ConvolutionArgs& args,
+ algo_t* algo)
+ {
+ cudnnConvolutionFwdPreference_t pref = CUDNN_CONVOLUTION_FWD_PREFER_FASTEST;
+ AT_CUDNN_CHECK(cudnnGetConvolutionForwardAlgorithm(
+ args.handle,
+ args.idesc.desc(),
+ args.wdesc.desc(),
+ args.cdesc.desc(),
+ args.odesc.desc(),
+ pref,
+ 0,
+ algo));
}
static void getWorkspaceSize(
@@ -538,14 +523,14 @@ struct algorithm_search<cudnnConvolutionFwdAlgoPerf_t> {
};
template<>
-struct algorithm_search<cudnnConvolutionBwdDataAlgoPerf_t> {
+struct algorithm_search<cudnnConvolutionBwdDataAlgo_t> {
using perf_t = cudnnConvolutionBwdDataAlgoPerf_t;
using algo_t = cudnnConvolutionBwdDataAlgo_t;
static constexpr auto DEFAULT_ALGO = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1;
- static BenchmarkCache<perf_t>& cache() { return bwd_data_algos; }
+ static BenchmarkCache<algo_t>& cache() { return bwd_data_algos; }
- static perf_t findAlgorithm(const cudnnDataType_t dataType, const ConvolutionArgs& args, bool benchmark) {
+ static perf_t findAlgorithm(const ConvolutionArgs& args) {
static const algo_t algos[] = {
CUDNN_CONVOLUTION_BWD_DATA_ALGO_0,
CUDNN_CONVOLUTION_BWD_DATA_ALGO_1,
@@ -559,32 +544,32 @@ struct algorithm_search<cudnnConvolutionBwdDataAlgoPerf_t> {
"Missing cuDNN convolution backward data algorithms.");
int perf_count;
std::unique_ptr<perf_t[]> perf_results(new perf_t[num_algos]);
- if (!benchmark) {
- AT_CUDNN_CHECK(cudnnGetConvolutionBackwardDataAlgorithm_v7(
- args.handle,
- args.wdesc.desc(),
- args.odesc.desc(),
- args.cdesc.desc(),
- args.idesc.desc(),
- num_algos,
- &perf_count,
- perf_results.get()));
- } else {
- size_t max_ws_size = getMaxWorkspaceSize(args, algos, num_algos);
- Workspace ws(max_ws_size);
- AT_CUDNN_CHECK(cudnnFindConvolutionBackwardDataAlgorithmEx(
- args.handle,
- args.wdesc.desc(), args.weight.data_ptr(),
- args.odesc.desc(), args.output.data_ptr(),
- args.cdesc.desc(),
- args.idesc.desc(), args.input.data_ptr(),
- num_algos,
- &perf_count,
- perf_results.get(),
- ws.data,
- ws.size));
- }
- return getBestAlgorithm<perf_t>(perf_results.get(), args.params.deterministic, perf_count);
+ size_t max_ws_size = getMaxWorkspaceSize(args, algos, num_algos);
+ Workspace ws(max_ws_size);
+ AT_CUDNN_CHECK(cudnnFindConvolutionBackwardDataAlgorithmEx(
+ args.handle,
+ args.wdesc.desc(), args.weight.data_ptr(),
+ args.odesc.desc(), args.output.data_ptr(),
+ args.cdesc.desc(),
+ args.idesc.desc(), args.input.data_ptr(),
+ num_algos,
+ &perf_count,
+ perf_results.get(),
+ ws.data,
+ ws.size));
+ return getBestAlgorithm(perf_results.get(), args.params.deterministic, perf_count);
+ }
+
+ static void getAlgorithm(const ConvolutionArgs& args, algo_t* algo) {
+ AT_CUDNN_CHECK(cudnnGetConvolutionBackwardDataAlgorithm(
+ args.handle,
+ args.wdesc.desc(),
+ args.odesc.desc(),
+ args.cdesc.desc(),
+ args.idesc.desc(),
+ CUDNN_CONVOLUTION_BWD_DATA_PREFER_FASTEST,
+ 0,
+ algo));
}
static void getWorkspaceSize(
@@ -603,15 +588,15 @@ struct algorithm_search<cudnnConvolutionBwdDataAlgoPerf_t> {
};
template<>
-struct algorithm_search<cudnnConvolutionBwdFilterAlgoPerf_t> {
+struct algorithm_search<cudnnConvolutionBwdFilterAlgo_t> {
using perf_t = cudnnConvolutionBwdFilterAlgoPerf_t;
using algo_t = cudnnConvolutionBwdFilterAlgo_t;
static constexpr auto DEFAULT_ALGO = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1;
- static BenchmarkCache<perf_t>& cache() { return bwd_filter_algos; }
+ static BenchmarkCache<algo_t>& cache() { return bwd_filter_algos; }
- static perf_t findAlgorithm(const cudnnDataType_t dataType, const ConvolutionArgs& args, bool benchmark) {
+ static perf_t findAlgorithm(const ConvolutionArgs& args) {
static const algo_t algos[] = {
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0,
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1,
@@ -625,35 +610,37 @@ struct algorithm_search<cudnnConvolutionBwdFilterAlgoPerf_t> {
static_assert(sizeof(algos) / sizeof(algos[0]) == num_algos,
"Missing cuDNN convolution backward filter algorithms.");
std::unique_ptr<perf_t[]> perf_results(new perf_t[num_algos]);
+ size_t max_ws_size = getMaxWorkspaceSize(args, algos, num_algos);
int perf_count;
- if (!benchmark) {
- AT_CUDNN_CHECK(cudnnGetConvolutionBackwardFilterAlgorithm_v7(
- args.handle,
- args.idesc.desc(),
- args.odesc.desc(),
- args.cdesc.desc(),
- args.wdesc.desc(),
- num_algos,
- &perf_count,
- perf_results.get()));
- } else {
- size_t max_ws_size = getMaxWorkspaceSize(args, algos, num_algos);
- Workspace ws(max_ws_size);
- AT_CUDNN_CHECK(cudnnFindConvolutionBackwardFilterAlgorithmEx(
- args.handle,
- args.idesc.desc(), args.input.data_ptr(),
- args.odesc.desc(), args.output.data_ptr(),
- args.cdesc.desc(),
- args.wdesc.desc(), args.weight.data_ptr(),
- num_algos,
- &perf_count,
- perf_results.get(),
- ws.data,
- ws.size));
- }
+ Workspace ws(max_ws_size);
+
+ AT_CUDNN_CHECK(cudnnFindConvolutionBackwardFilterAlgorithmEx(
+ args.handle,
+ args.idesc.desc(), args.input.data_ptr(),
+ args.odesc.desc(), args.output.data_ptr(),
+ args.cdesc.desc(),
+ args.wdesc.desc(), args.weight.data_ptr(),
+ num_algos,
+ &perf_count,
+ perf_results.get(),
+ ws.data,
+ ws.size));
return getBestAlgorithm<perf_t>(perf_results.get(), args.params.deterministic, perf_count);
}
+ static void getAlgorithm(const ConvolutionArgs& args, algo_t* algo) {
+ AT_CUDNN_CHECK(cudnnGetConvolutionBackwardFilterAlgorithm(
+ args.handle,
+ args.idesc.desc(),
+ args.odesc.desc(),
+ args.cdesc.desc(),
+ args.wdesc.desc(),
+ CUDNN_CONVOLUTION_BWD_FILTER_PREFER_FASTEST,
+ 0,
+ algo)
+ );
+ }
+
static void getWorkspaceSize(const ConvolutionArgs& args, algo_t algo, size_t* workspaceSize)
{
AT_CUDNN_CHECK(cudnnGetConvolutionBackwardFilterWorkspaceSize(
@@ -667,90 +654,70 @@ struct algorithm_search<cudnnConvolutionBwdFilterAlgoPerf_t> {
}
};
-template<typename perf_t>
-void findAlgorithm(const cudnnDataType_t dataType, const ConvolutionArgs& args, bool benchmark, perf_t* algoPerf) {
- using search = algorithm_search<perf_t>;
+template<typename algo_t>
+void findAlgorithm(const ConvolutionArgs& args, bool benchmark, algo_t* algo) {
+ using search = algorithm_search<algo_t>;
auto& cache = search::cache();
- if (cache.find(args.params, algoPerf)) {
+ if (cache.find(args.params, algo)) {
return;
}
if (args.params.deterministic && !benchmark) {
- algoPerf->algo = search::DEFAULT_ALGO;
- // Note [cudnnSetConvolutionMathType cannot be called in descriptor]
- if (dataType == CUDNN_DATA_HALF) {
- algoPerf->mathType = CUDNN_TENSOR_OP_MATH;
- } else {
- algoPerf->mathType = CUDNN_DEFAULT_MATH;
- }
- search::getWorkspaceSize(args, algoPerf->algo, &(algoPerf->memory));
+ *algo = search::DEFAULT_ALGO;
return;
}
- if (benchmark) {
- if (cache.find(args.params, algoPerf)) {
- // re-check cache since another thread may have benchmarked the algorithm
- return;
- }
- }
+ if (!benchmark) {
+ search::getAlgorithm(args, algo);
+ return;
+ }
- auto perfResults = search::findAlgorithm(dataType, args, benchmark);
+ if (cache.find(args.params, algo)) {
+ // re-check cache since another thread may have benchmarked the algorithm
+ return;
+ }
+
+ auto perfResults = search::findAlgorithm(args);
// for deterministic algo, look at all the perf results and return the best
// deterministic algo
if (perfResults.status == CUDNN_STATUS_SUCCESS &&
!(args.params.deterministic && perfResults.determinism != CUDNN_DETERMINISTIC)) {
-
- // if benchmarking, map the original params with the found algo+math type for re-use
- if (benchmark) {
- cache.insert(args.params, perfResults);
-
- // Free the cached blocks in our caching allocator. They are
- // needed here because the above benchmarking uses a huge amount of memory,
- // e.g. a few GBs.
- c10::cuda::CUDACachingAllocator::emptyCache();
- }
-
- *algoPerf = perfResults;
+ *algo = perfResults.algo;
} else {
- algoPerf->algo = search::DEFAULT_ALGO;
- // Note [cudnnSetConvolutionMathType cannot be called in descriptor]
- if (dataType == CUDNN_DATA_HALF) {
- algoPerf->mathType = CUDNN_TENSOR_OP_MATH;
- } else {
- algoPerf->mathType = CUDNN_DEFAULT_MATH;
- }
- search::getWorkspaceSize(args, algoPerf->algo, &(algoPerf->memory));
+ *algo = search::DEFAULT_ALGO;
}
+ cache.insert(args.params, *algo);
+
+ // Free the cached blocks in our caching allocator. They are
+ // needed here because the above benchmarking uses a huge amount of memory,
+ // e.g. a few GBs.
+ c10::cuda::CUDACachingAllocator::emptyCache();
}
-template<typename perf_t>
+template<typename algo_t>
Workspace chooseAlgorithm(
- const cudnnDataType_t dataType,
const ConvolutionArgs& args,
bool benchmark,
- perf_t* algoPerf)
+ algo_t* algo)
{
- findAlgorithm(dataType, args, benchmark, algoPerf);
+ findAlgorithm(args, benchmark, algo);
- using search = algorithm_search<perf_t>;
+ using search = algorithm_search<algo_t>;
+ size_t workspace_size;
+ search::getWorkspaceSize(args, *algo, &workspace_size);
try {
- return Workspace(algoPerf->memory);
+ return Workspace(workspace_size);
} catch (const std::exception& e) {
cudaGetLastError(); // clear OOM error
// switch to default algorithm and record it in the cache to prevent
// further OOM errors
- algoPerf->algo = search::DEFAULT_ALGO;
- // Note [cudnnSetConvolutionMathType cannot be called in descriptor]
- if (dataType == CUDNN_DATA_HALF) {
- algoPerf->mathType = CUDNN_TENSOR_OP_MATH;
- } else {
- algoPerf->mathType = CUDNN_DEFAULT_MATH;
- }
- search::getWorkspaceSize(args, algoPerf->algo, &(algoPerf->memory));
- search::cache().insert(args.params, *algoPerf);
- return Workspace(algoPerf->memory);
+ *algo = search::DEFAULT_ALGO;
+ search::cache().insert(args.params, *algo);
+
+ search::getWorkspaceSize(args, *algo, &workspace_size);
+ return Workspace(workspace_size);
}
}
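[Editor's note] Read end to end, the restored findAlgorithm resolves in a fixed order -- cached result, deterministic default, cuDNN heuristic (getAlgorithm), exhaustive benchmark (the findAlgorithm specialization) -- and chooseAlgorithm wraps it with a workspace allocation that falls back to DEFAULT_ALGO on OOM. A condensed paraphrase of that ladder, keyed to the names in the diff (a reading aid, not a literal excerpt):

template <typename algo_t>
void findAlgorithm(const ConvolutionArgs& args, bool benchmark, algo_t* algo) {
  using search = algorithm_search<algo_t>;
  if (search::cache().find(args.params, algo)) return;  // 1. cached result
  if (args.params.deterministic && !benchmark) {        // 2. fixed default
    *algo = search::DEFAULT_ALGO;
    return;
  }
  if (!benchmark) {                                     // 3. cuDNN heuristic
    search::getAlgorithm(args, algo);
    return;
  }
  if (search::cache().find(args.params, algo)) return;  // 4. re-check: another
                                                        //    thread may have
                                                        //    benchmarked already
  auto perf = search::findAlgorithm(args);              // 5. measure everything
  bool usable = perf.status == CUDNN_STATUS_SUCCESS &&
                !(args.params.deterministic &&
                  perf.determinism != CUDNN_DETERMINISTIC);
  *algo = usable ? perf.algo : search::DEFAULT_ALGO;
  search::cache().insert(args.params, *algo);
  c10::cuda::CUDACachingAllocator::emptyCache();        // benchmarking can use GBs
}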
@@ -844,13 +811,8 @@ void raw_cudnn_convolution_forward_out(
// wasteful; we'd rather reuse the workspace. OTOH, legacy group
// convolution support is already pretty slow, so this might not
// matter. (This applies to raw_cudnn_convolution_backward_input as well.)
- cudnnConvolutionFwdAlgoPerf_t fwdAlgPerf;
- Workspace workspace = chooseAlgorithm(dataType, args, benchmark, &fwdAlgPerf);
-
- // update convDesc mathType since cudnn now requires both algo + mathType to figure out
- // whether to use Tensor cores or not
- // See Note [chooseAlgorithm doesn't respect mathType]
- AT_CUDNN_CHECK(cudnnSetConvolutionMathType(args.cdesc.mut_desc(), fwdAlgPerf.mathType));
+ cudnnConvolutionFwdAlgo_t fwdAlg;
+ Workspace workspace = chooseAlgorithm(args, benchmark, &fwdAlg);
Constant one(dataType, 1);
Constant zero(dataType, 0);
@@ -859,7 +821,7 @@ void raw_cudnn_convolution_forward_out(
args.handle,
&one, args.idesc.desc(), input.data_ptr(),
args.wdesc.desc(), weight.data_ptr(),
- args.cdesc.desc(), fwdAlgPerf.algo, workspace.data, workspace.size,
+ args.cdesc.desc(), fwdAlg, workspace.data, workspace.size,
&zero, args.odesc.desc(), output.data_ptr()));
}
@@ -968,13 +930,8 @@ void raw_cudnn_convolution_backward_input_out(
args.odesc.set(grad_output);
args.cdesc.set(dataType, grad_output.dim() - 2, args.params.padding, args.params.stride, args.params.dilation, args.params.groups);
- cudnnConvolutionBwdDataAlgoPerf_t bwdDataAlgPerf;
- Workspace workspace = chooseAlgorithm(dataType, args, benchmark, &bwdDataAlgPerf);
-
- // update convDesc mathType since cudnn now requires both algo + mathType to figure out
- // whether to use Tensor cores or not
- // See Note [chooseAlgorithm doesn't respect mathType]
- AT_CUDNN_CHECK(cudnnSetConvolutionMathType(args.cdesc.mut_desc(), bwdDataAlgPerf.mathType));
+ cudnnConvolutionBwdDataAlgo_t bwdDataAlg;
+ Workspace workspace = chooseAlgorithm(args, benchmark, &bwdDataAlg);
Constant one(dataType, 1);
Constant zero(dataType, 0);
@@ -983,7 +940,7 @@ void raw_cudnn_convolution_backward_input_out(
args.handle,
&one, args.wdesc.desc(), weight.data_ptr(),
args.odesc.desc(), grad_output.data_ptr(),
- args.cdesc.desc(), bwdDataAlgPerf.algo, workspace.data, workspace.size,
+ args.cdesc.desc(), bwdDataAlg, workspace.data, workspace.size,
&zero, args.idesc.desc(), grad_input.data_ptr()));
}
@@ -1109,13 +1066,8 @@ void raw_cudnn_convolution_backward_weight_out(
args.odesc.set(grad_output);
args.cdesc.set(dataType, input.dim() - 2, args.params.padding, args.params.stride, args.params.dilation, args.params.groups);
- cudnnConvolutionBwdFilterAlgoPerf_t bwdFilterAlgPerf;
- Workspace workspace = chooseAlgorithm(dataType, args, benchmark, &bwdFilterAlgPerf);
-
- // update convDesc mathType since cudnn now requires both algo + mathType to figure out
- // whether to use Tensor cores or not
- // See Note [chooseAlgorithm doesn't respect mathType]
- AT_CUDNN_CHECK(cudnnSetConvolutionMathType(args.cdesc.mut_desc(), bwdFilterAlgPerf.mathType));
+ cudnnConvolutionBwdFilterAlgo_t bwdFilterAlg;
+ Workspace workspace = chooseAlgorithm(args, benchmark, &bwdFilterAlg);
Constant one(dataType, 1);
Constant zero(dataType, 0);
@@ -1124,7 +1076,7 @@ void raw_cudnn_convolution_backward_weight_out(
args.handle,
&one, args.idesc.desc(), input.data_ptr(),
args.odesc.desc(), grad_output.data_ptr(),
- args.cdesc.desc(), bwdFilterAlgPerf.algo, workspace.data, workspace.size,
+ args.cdesc.desc(), bwdFilterAlg, workspace.data, workspace.size,
&zero, args.wdesc.desc(), grad_weight.data_ptr()));
}