summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJithun Nair <jithun.nair@amd.com>2019-01-31 14:00:00 -0800
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2019-01-31 14:09:04 -0800
commit4bdf51cbd60f03b2d4228c2fd576b04548b7cc26 (patch)
tree26a038b921697c077e4bc89b55a46156d6e775c9
parenta061e3fd77d2e3b11ba90b81e1e3e4a0c35c7403 (diff)
downloadpytorch-4bdf51cbd60f03b2d4228c2fd576b04548b7cc26.tar.gz
pytorch-4bdf51cbd60f03b2d4228c2fd576b04548b7cc26.tar.bz2
pytorch-4bdf51cbd60f03b2d4228c2fd576b04548b7cc26.zip
Make the miopen handle part of ConvolutionParams (#16613)
Summary: so that it's included in the hashed key that decides whether to call Find or not. This is required to ensure that Find is run for all devices Pull Request resolved: https://github.com/pytorch/pytorch/pull/16613 Differential Revision: D13901769 Pulled By: bddppq fbshipit-source-id: 7d29ea9e40231cd4eef80847afa1307efeb0945c
-rw-r--r--aten/src/ATen/native/miopen/Conv_miopen.cpp10
1 files changed, 6 insertions, 4 deletions
diff --git a/aten/src/ATen/native/miopen/Conv_miopen.cpp b/aten/src/ATen/native/miopen/Conv_miopen.cpp
index 19b29800fc..360e522f83 100644
--- a/aten/src/ATen/native/miopen/Conv_miopen.cpp
+++ b/aten/src/ATen/native/miopen/Conv_miopen.cpp
@@ -224,6 +224,7 @@ static void convolution_shape_check(
// parameters
struct ConvolutionParams
{
+ miopenHandle_t handle;
miopenDataType_t dataType;
int input_size[2 + max_dim];
int input_stride[2 + max_dim];
@@ -242,7 +243,7 @@ struct ConvolutionParams
static_assert(std::is_pod<ConvolutionParams>::value, "ConvolutionParams not POD");
void setConvolutionParams(
- ConvolutionParams* params,
+ ConvolutionParams* params, miopenHandle_t handle,
const at::Tensor& input, const at::Tensor& weight,
IntList padding, IntList stride, IntList dilation,
int64_t groups, bool deterministic) {
@@ -250,6 +251,7 @@ void setConvolutionParams(
miopenDataType_t dataType = getMiopenDataType(input);
memset(params, 0, sizeof(ConvolutionParams));
params->dataType = dataType;
+ params->handle = handle;
// ASSERT(weight.dim() == input.dim())
for (int i = 0; i != input.dim(); ++i) {
params->input_size[i] = (int) input.size(i);
@@ -604,7 +606,7 @@ void raw_miopen_convolution_forward_out(
ConvolutionArgs args{ input, output, weight };
args.handle = getMiopenHandle();
- setConvolutionParams(&args.params, input, weight, padding, stride, dilation, groups, deterministic);
+ setConvolutionParams(&args.params, args.handle, input, weight, padding, stride, dilation, groups, deterministic);
args.idesc.set(input);
args.wdesc.set(weight);
args.odesc.set(output);
@@ -721,7 +723,7 @@ void raw_miopen_convolution_backward_input_out(
ConvolutionArgs args{ grad_input, grad_output, weight };
args.handle = getMiopenHandle();
- setConvolutionParams(&args.params, grad_input, weight, padding, stride, dilation, groups, deterministic);
+ setConvolutionParams(&args.params, args.handle, grad_input, weight, padding, stride, dilation, groups, deterministic);
args.idesc.set(grad_input);
args.wdesc.set(weight);
args.odesc.set(grad_output);
@@ -847,7 +849,7 @@ void raw_miopen_convolution_backward_weight_out(
ConvolutionArgs args{ input, grad_output, grad_weight };
args.handle = getMiopenHandle();
- setConvolutionParams(&args.params, input, grad_weight, padding, stride, dilation, groups, deterministic);
+ setConvolutionParams(&args.params, args.handle, input, grad_weight, padding, stride, dilation, groups, deterministic);
args.idesc.set(input);
args.wdesc.set(grad_weight);
args.odesc.set(grad_output);