diff options
author | Jithun Nair <jithun.nair@amd.com> | 2019-01-31 14:00:00 -0800 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2019-01-31 14:09:04 -0800 |
commit | 4bdf51cbd60f03b2d4228c2fd576b04548b7cc26 (patch) | |
tree | 26a038b921697c077e4bc89b55a46156d6e775c9 | |
parent | a061e3fd77d2e3b11ba90b81e1e3e4a0c35c7403 (diff) | |
download | pytorch-4bdf51cbd60f03b2d4228c2fd576b04548b7cc26.tar.gz pytorch-4bdf51cbd60f03b2d4228c2fd576b04548b7cc26.tar.bz2 pytorch-4bdf51cbd60f03b2d4228c2fd576b04548b7cc26.zip |
Make the miopen handle part of ConvolutionParams (#16613)
Summary:
so that it's included in the hashed key that decides whether to call Find or not. This is required to ensure that Find is run for all devices
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16613
Differential Revision: D13901769
Pulled By: bddppq
fbshipit-source-id: 7d29ea9e40231cd4eef80847afa1307efeb0945c
-rw-r--r-- | aten/src/ATen/native/miopen/Conv_miopen.cpp | 10 |
1 files changed, 6 insertions, 4 deletions
diff --git a/aten/src/ATen/native/miopen/Conv_miopen.cpp b/aten/src/ATen/native/miopen/Conv_miopen.cpp index 19b29800fc..360e522f83 100644 --- a/aten/src/ATen/native/miopen/Conv_miopen.cpp +++ b/aten/src/ATen/native/miopen/Conv_miopen.cpp @@ -224,6 +224,7 @@ static void convolution_shape_check( // parameters struct ConvolutionParams { + miopenHandle_t handle; miopenDataType_t dataType; int input_size[2 + max_dim]; int input_stride[2 + max_dim]; @@ -242,7 +243,7 @@ struct ConvolutionParams static_assert(std::is_pod<ConvolutionParams>::value, "ConvolutionParams not POD"); void setConvolutionParams( - ConvolutionParams* params, + ConvolutionParams* params, miopenHandle_t handle, const at::Tensor& input, const at::Tensor& weight, IntList padding, IntList stride, IntList dilation, int64_t groups, bool deterministic) { @@ -250,6 +251,7 @@ void setConvolutionParams( miopenDataType_t dataType = getMiopenDataType(input); memset(params, 0, sizeof(ConvolutionParams)); params->dataType = dataType; + params->handle = handle; // ASSERT(weight.dim() == input.dim()) for (int i = 0; i != input.dim(); ++i) { params->input_size[i] = (int) input.size(i); @@ -604,7 +606,7 @@ void raw_miopen_convolution_forward_out( ConvolutionArgs args{ input, output, weight }; args.handle = getMiopenHandle(); - setConvolutionParams(&args.params, input, weight, padding, stride, dilation, groups, deterministic); + setConvolutionParams(&args.params, args.handle, input, weight, padding, stride, dilation, groups, deterministic); args.idesc.set(input); args.wdesc.set(weight); args.odesc.set(output); @@ -721,7 +723,7 @@ void raw_miopen_convolution_backward_input_out( ConvolutionArgs args{ grad_input, grad_output, weight }; args.handle = getMiopenHandle(); - setConvolutionParams(&args.params, grad_input, weight, padding, stride, dilation, groups, deterministic); + setConvolutionParams(&args.params, args.handle, grad_input, weight, padding, stride, dilation, groups, deterministic); args.idesc.set(grad_input); args.wdesc.set(weight); args.odesc.set(grad_output); @@ -847,7 +849,7 @@ void raw_miopen_convolution_backward_weight_out( ConvolutionArgs args{ input, grad_output, grad_weight }; args.handle = getMiopenHandle(); - setConvolutionParams(&args.params, input, grad_weight, padding, stride, dilation, groups, deterministic); + setConvolutionParams(&args.params, args.handle, input, grad_weight, padding, stride, dilation, groups, deterministic); args.idesc.set(input); args.wdesc.set(grad_weight); args.odesc.set(grad_output); |