author    Peter Yeh <petrex@users.noreply.github.com>    2018-06-13 04:00:39 -0700
committer bddppq <bai@in.tum.de>                         2018-06-13 04:00:39 -0700
commit    c37e5b7137cb8c7a11f15a10185aac1f673b07db (patch)
tree      953db0f50e9db46f4d84dba1b3632882596797ce /caffe2/operators/hip
parent    36bf89bf09469fdad8661c3930ca31875f9cdbf3 (diff)
[Caffe2] Enable AMD/MIOPEN ops for Caffe2 (#8306)
* Add hip support for caffe2 core
* Add MIOPEN header/wrapper to caffe2 core
* Add HIP device into caffe2 PB
* top level makefile change for rocm/hip
* makefile scaffolding for AMD/RocM/HIP
* Makefile scaffolding for AMD/RocM/HIP; add makefile/utility for HIP files
* caffe2 PB update for AMD/ROCM HIP device
* Add AMD/RocM/Thrust dependency
* HIP threadpool update
* Fix makefile macro
* makefile fix: duplicate test/binary name
* makefile clean-up
* makefile clean-up
* add HIP operator registry
* add utilities for hip device
* Add USE_HIP to config summary
* makefile fix for BUILD_TEST
* merge latest
* Fix indentation
* code clean-up
* Guard builds without HIP and use the same cmake script as PyTorch to find HIP
* Set up rocm environment variables in build.sh (ideally should be done in the docker images)
* setup locale
* set HIP_PLATFORM
* Revert "set HIP_PLATFORM"
This reverts commit 8ec58db2b390c9259220c49fa34cd403568300ad.
* continue the build script environment variables mess
* HCC_AMDGPU_TARGET
* Clean up the mess; this has been fixed in the latest docker images
* Assign protobuf field hip_gpu_id a new field number for backward compatibility
* change name to avoid conflict
* Fix duplicated thread pool flag
* Refactor cmake files to not add hip includes and libs globally
* Fix the wrong usage of environment variables detection in cmake
* Add MIOPEN CNN operators
* Revert "Add MIOPEN CNN operators"
This reverts commit 6e89ad4385b5b8967a7854c4adda52c012cee42a.
* Add MIOPEN pooling operator
* Add MIOPEN activation operator
* Add MIOPEN softmax operator
* Add MIOPEN spatial batch norm operator
* Add MIOPEN local response normalization operator
* Add MIOPEN conv operator
* Clean-up LRN ops
* enable fp16 in MIOPEN pool ops
* Enable fp16 for MIOPEN relu op
* Enable fp16 for MIOPEN spatial batch norm op
* code clean-up
* revert float16 support
* Create Caffe2 python binding for AMD/ROCM/HIP
* Add op fallback for HIP operator (see the registration sketch after this list)
* add hip src/test files in cmake
* exclude hip src/test files
* fix python binding for hip backend
* fix MIOPEN pooling op workspace
* hack to compile miopen operators
* fix include path for MIOPEN ops
* Fix include path
* Add HIP math utilities
* Fix path for HIP math utils
* cmake fix
* Cmake fix / hipcc for hip files
* suppress hipcc warning
* cmake fix / replace USE_HIP with USE_ROCM
* revert LoadHIP.cmake change
* fix include for thrust/cub-hip
* include path fix for conversion.h
* Updated with latest upstream changes
* clang format fixes
* Context_hip updates
* Fixed typo in rocblas handle get function
* Updated hipified math utils
* Updated math hip test util
* Updated context hip test
* Updated common_hip
* Updated net async dag for HIP
* Added MIOPEN in operator hip test
* fix
* C2 dependencies clean-up
* fix include path for building custom protobuf
* Decouple miopen pool op and conv_pool_op base
* cmake refactor
* fix operator_hip_test
* move all hip/miopen ops files into caffe2/operators/hip
* sanitize cmake
* permission issue
* remove extra parenthesis
* remove artifact from resolving merge conflict
* cont. sanitize cmake files
* fix syntax error
* sanitize conversion.h
* .
* Revert "."
This reverts commit 56020cb0e996a31ae27bf1f8f491955ed0b121b9.
* clang-format
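The HIP operator fallback added above lets an operator without a native HIP kernel run by bouncing through its CPU implementation. A condensed sketch of the registration pattern, taken from operator_fallback_hip.h and operator_fallback_hip_test.cc in this diff (the CPU operator body is elided here):

#include "caffe2/core/operator.h"
#include "caffe2/operators/hip/operator_fallback_hip.h"

namespace caffe2 {

// A CPU-only operator; its full definition appears in operator_fallback_hip_test.cc.
class IncrementByOneOp final : public Operator<CPUContext> { /* ... */ };

// Register the CPU implementation as usual.
REGISTER_CPU_OPERATOR(IncrementByOne, IncrementByOneOp);

// Register the HIP side through the fallback wrapper: HIP inputs are copied to
// a local CPU workspace, the CPU op runs there, and outputs are copied back.
REGISTER_HIP_OPERATOR(IncrementByOne, GPUFallbackOp<IncrementByOneOp>);

// Outputs that should stay on the CPU can be excluded from the copy-back:
//   REGISTER_HIP_OPERATOR(IncrementByOne,
//                         GPUFallbackOp<IncrementByOneOp, SkipIndices<0>>);

} // namespace caffe2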
Diffstat (limited to 'caffe2/operators/hip')
-rw-r--r--  caffe2/operators/hip/conv_op_miopen.cc                           859
-rw-r--r--  caffe2/operators/hip/local_response_normalization_op_miopen.cc   248
-rw-r--r--  caffe2/operators/hip/operator_fallback_hip.h                     114
-rw-r--r--  caffe2/operators/hip/operator_fallback_hip_test.cc                80
-rw-r--r--  caffe2/operators/hip/pool_op_miopen.cc                           310
-rw-r--r--  caffe2/operators/hip/relu_op_miopen.cc                           205
-rw-r--r--  caffe2/operators/hip/softmax_op_miopen.cc                        138
-rw-r--r--  caffe2/operators/hip/spatial_batch_norm_op_miopen.cc             318
8 files changed, 2272 insertions, 0 deletions
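Each of the MIOPEN source files listed above registers its kernel with REGISTER_MIOPEN_OPERATOR, so it is picked up when an operator running on the HIP device asks for that engine. A hedged sketch of constructing such an operator, modeled on the pattern in operator_fallback_hip_test.cc; the Relu op is registered in relu_op_miopen.cc, while the blob names and the use of set_engine here are illustrative assumptions:

#include <memory>

#include "caffe2/core/hip/context_hip.h"
#include "caffe2/core/operator.h"

namespace caffe2 {

void RunReluWithMIOpen(Workspace* ws) {
  // Target the HIP device and request the MIOPEN engine so that dispatch
  // resolves to MIOPENReluOp rather than a generic HIP implementation
  // (assumed behavior, analogous to the CUDNN engine on the CUDA side).
  OperatorDef op_def = CreateOperatorDef(
      "Relu", "", std::vector<std::string>{"X"}, std::vector<std::string>{"Y"});
  op_def.set_engine("MIOPEN");
  op_def.mutable_device_option()->set_device_type(HIP);

  std::unique_ptr<OperatorBase> op(CreateOperator(op_def, ws));
  CAFFE_ENFORCE(op.get() != nullptr);
  CAFFE_ENFORCE(op->Run());
}

} // namespace caffe2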
diff --git a/caffe2/operators/hip/conv_op_miopen.cc b/caffe2/operators/hip/conv_op_miopen.cc new file mode 100644 index 0000000000..84d9a75cfa --- /dev/null +++ b/caffe2/operators/hip/conv_op_miopen.cc @@ -0,0 +1,859 @@ +/** + * Copyright (c) 2016-present, Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "caffe2/core/hip/context_hip.h" +#include "caffe2/core/hip/miopen_wrapper.h" +#include "caffe2/operators/conv_op.h" +#include "caffe2/operators/conv_pool_op_base.h" + +namespace caffe2 { + +// Earlier in the days Caffe sets the default miopen workspace to 8MB. We bump +// it up to 64MB in Caffe2, as this enables the use of Winograd in many cases, +// something very beneficial to more recent CNN models. +static constexpr size_t kCONV_MIOPEN_WORKSPACE_LIMIT_BYTES = 64 * 1024 * 1024; + +class MIOPENConvOpBase : public ConvPoolOpBase<HIPContext> { + public: + MIOPENConvOpBase(const OperatorDef& operator_def, Workspace* ws) + : ConvPoolOpBase<HIPContext>(operator_def, ws), + miopen_wrapper_(&context_), + miopen_ws_nbytes_limit_(OperatorBase::GetSingleArgument<size_t>( + "ws_nbytes_limit", + kCONV_MIOPEN_WORKSPACE_LIMIT_BYTES)), + alpha_(OperatorBase::GetSingleArgument<float>("alpha", 1.0)), + beta_(OperatorBase::GetSingleArgument<float>("beta", 0.0)), + exhaustive_search_( + OperatorBase::GetSingleArgument<bool>("exhaustive_search", false)) { + MIOPEN_ENFORCE(miopenCreateTensorDescriptor(&bottom_desc_)); + MIOPEN_ENFORCE(miopenCreateTensorDescriptor(&bias_desc_)); + MIOPEN_ENFORCE(miopenCreateTensorDescriptor(&weight_desc_)); + MIOPEN_ENFORCE(miopenCreateTensorDescriptor(&top_desc_)); + MIOPEN_ENFORCE(miopenCreateTensorDescriptor(&top_desc_for_bias_)); + MIOPEN_ENFORCE(miopenCreateConvolutionDescriptor(&conv_desc_)); + + if ((operator_def.type().substr(0, 6) == "Conv") || + (operator_def.type().substr(0, 14) == "ConvGradient")) { + mode_ = miopenConvolution; + } else if ( + (operator_def.type().substr(0, 7) == "Trans") || + (operator_def.type().substr(0, 15) == "TransGradient")) { + mode_ = miopenTranspose; + } else { + LOG(FATAL) << "Unsupported convolution method: " << operator_def.type(); + } + + MIOPEN_ENFORCE(miopenInitConvolutionDescriptor( + conv_desc_, + mode_, + pad_t(), + pad_l(), + stride_h(), + stride_w(), + dilation_h(), + dilation_w())); + } + + ~MIOPENConvOpBase() { + MIOPEN_ENFORCE(miopenDestroyTensorDescriptor(bottom_desc_)); + MIOPEN_ENFORCE(miopenDestroyTensorDescriptor(bias_desc_)); + MIOPEN_ENFORCE(miopenDestroyTensorDescriptor(weight_desc_)); + MIOPEN_ENFORCE(miopenDestroyTensorDescriptor(top_desc_)); + MIOPEN_ENFORCE(miopenDestroyTensorDescriptor(top_desc_for_bias_)); + MIOPEN_ENFORCE(miopenDestroyConvolutionDescriptor(conv_desc_)); + } + + protected: + MIOPENWrapper miopen_wrapper_; + miopenTensorDescriptor_t bottom_desc_; + miopenTensorDescriptor_t bias_desc_; + miopenTensorDescriptor_t weight_desc_; + miopenTensorDescriptor_t top_desc_; + miopenTensorDescriptor_t top_desc_for_bias_; + miopenConvolutionDescriptor_t conv_desc_; + 
miopenConvolutionMode_t mode_; + const size_t miopen_ws_nbytes_limit_; + bool exhaustive_search_; + const float alpha_; + const float beta_; +}; + +class MIOPENConvOp final : public MIOPENConvOpBase { + public: + MIOPENConvOp(const OperatorDef& operator_def, Workspace* ws) + : MIOPENConvOpBase(operator_def, ws), + requestAlgoCount_( + OperatorBase::GetSingleArgument<int>("requestAlgoCount_", 1)), + returnedAlgoCount_( + OperatorBase::GetSingleArgument<int>("returnedAlgoCount_", 1)), + bestAlgoFound_( + OperatorBase::GetSingleArgument<bool>("bestAlgoFound_", false)), + fwdConvWs_(nullptr), + fwdConvWsSize_(0), + fwdAlgo_(miopenConvolutionFwdAlgoGEMM) {} + + ~MIOPENConvOp() { + if (fwdConvWs_) { + hipFree(fwdConvWs_); + fwdConvWs_ = nullptr; + fwdConvWsSize_ = 0; + } + } + + template < + typename T_X, + typename T_W, + typename T_B, + typename MATH, + typename T_Y> + bool DoRunWithType(); + bool RunOnDevice() override; + + private: + const int requestAlgoCount_; + int returnedAlgoCount_; + bool bestAlgoFound_; + char* fwdConvWs_; + size_t fwdConvWsSize_; + miopenConvFwdAlgorithm_t fwdAlgo_; + // Input: X, W, b + // Output: Y + INPUT_TAGS(INPUT, FILTER, BIAS); +}; + +class MIOPENConvGradientOp final : public MIOPENConvOpBase { + public: + MIOPENConvGradientOp(const OperatorDef& operator_def, Workspace* ws) + : MIOPENConvOpBase(operator_def, ws), + no_bias_(OperatorBase::GetSingleArgument<int>("no_bias", 0)), + requestAlgoCount_( + OperatorBase::GetSingleArgument<int>("requestAlgoCount_", 1)), + returnedAlgoCount_( + OperatorBase::GetSingleArgument<int>("returnedAlgoCount_", 1)), + bestDataAlgoFound_( + OperatorBase::GetSingleArgument<bool>("bestAlgoFound", false)), + bestWeightAlgoFound_( + OperatorBase::GetSingleArgument<bool>("bestAlgoFound", false)), + bwdWeightWs_(nullptr), + bwdWeightWsSize_(0), + bwdDataWs_(nullptr), + bwdDataWsSize_(0), + bwdWeiAlgo_(miopenConvolutionBwdWeightsAlgoGEMM), + bwdDataAlgo_(miopenConvolutionBwdDataAlgoGEMM) { + OPERATOR_NEEDS_FEATURE( + group_ == 1, + "Group convolution not supported yet for MIOpen ConvGradient."); + CAFFE_ENFORCE( + !(no_bias_ && OutputSize() == 3), + "If bias is not present, you should not have 3 grad output."); + } + + ~MIOPENConvGradientOp() { + if (bwdWeightWs_) { + hipFree(bwdWeightWs_); + bwdWeightWs_ = nullptr; + bwdWeightWsSize_ = 0; + } + if (bwdDataWs_) { + hipFree(bwdDataWs_); + bwdDataWs_ = nullptr; + bwdDataWsSize_ = 0; + } + } + + template < + typename T_X, + typename T_DY, + typename T_W, + typename T_B, + typename MATH, + typename T_DX, + typename T_DW, + typename T_DB> + bool DoRunWithType(); + bool RunOnDevice() override; + + private: + bool no_bias_; + const int requestAlgoCount_; + int returnedAlgoCount_; + bool bestDataAlgoFound_; + bool bestWeightAlgoFound_; + miopenConvBwdWeightsAlgorithm_t bwdWeiAlgo_; + miopenConvBwdDataAlgorithm_t bwdDataAlgo_; + size_t bwdWeightWsSize_; + size_t bwdDataWsSize_; + char* bwdWeightWs_; + char* bwdDataWs_; + // input: X, W, dY + // output: dW, db, and optionally dX + INPUT_TAGS(INPUT, FILTER, OUTPUT_GRAD); + OUTPUT_TAGS(FILTER_GRAD, BIAS_OR_INPUT_GRAD, INPUT_GRAD); +}; + +//////////////////////////////////////////////////////////////////////////////// +// Implementations +//////////////////////////////////////////////////////////////////////////////// + +template <typename T_X, typename T_W, typename T_B, typename MATH, typename T_Y> +bool MIOPENConvOp::DoRunWithType() { + auto& X = Input(INPUT); + auto& Weight = Input(FILTER); + auto* Y = Output(0); + + // Figure out the output 
shape + CAFFE_ENFORCE(X.ndim() >= 3 && X.ndim() <= 5); + CAFFE_ENFORCE( + Weight.ndim() == 4, + "Conv/Trans op with MIOpen engine is supported only for 2D convolutions"); + + const int M = Weight.dim32(0); + ConvPoolOpBase<HIPContext>::SetOutputSize(X, Y, M); + + int N = X.dim32(0); + int C = X.dim32(1); + int H = X.dim32(2); + int W = X.ndim() > 3 ? X.dim32(3) : 1; + int D = X.ndim() > 4 ? X.dim32(4) : 1; + + int N_out = Y->dim32(0); + int C_out = Y->dim32(1); + int H_out = Y->dim32(2); + int W_out = Y->ndim() > 3 ? Y->dim32(3) : 1; + int D_out = Y->ndim() > 4 ? Y->dim32(4) : 1; + CAFFE_ENFORCE_EQ(Weight.dim32(1), C / group_); + + CAFFE_ENFORCE( + C % group_ == 0, + "If you set group, the number of input channels should be divisible " + "by group."); + CAFFE_ENFORCE( + M % group_ == 0, + "If you set group, the number of output channels should be divisible " + "by group."); + + if (group_ > 1) { + int group_offset_filter = Weight.size() / group_; + + MIOPEN_ENFORCE(miopenSet4dTensorDescriptor( + weight_desc_, + miopenTypeWrapper<T_W>::type, + M / group_, + C / group_, + kernel_h(), + kernel_w())); + + MIOPEN_ENFORCE(miopenSet4dTensorDescriptor( + bottom_desc_, miopenTypeWrapper<T_X>::type, 1, C / group_, H, W)); + + MIOPEN_ENFORCE(miopenGetConvolutionForwardOutputDim( + conv_desc_, + bottom_desc_, + weight_desc_, + &N_out, + &C_out, + &H_out, + &W_out)); + + MIOPEN_ENFORCE(miopenSet4dTensorDescriptor( + top_desc_, miopenTypeWrapper<T_X>::type, N_out, C_out, H_out, W_out)); + + if (InputSize() == 3) { + MIOPEN_ENFORCE(miopenSet4dTensorDescriptor( + bias_desc_, miopenTypeWrapper<T_B>::type, 1, Y->dim32(1), 1, 1)); + MIOPEN_ENFORCE(miopenSet4dTensorDescriptor( + top_desc_for_bias_, + miopenTypeWrapper<T_X>::type, + Y->dim32(0), + Y->dim32(1), + H_out, + W_out)); + } + + MIOPEN_ENFORCE(miopenConvolutionForwardGetWorkSpaceSize( + miopen_wrapper_.inline_miopen_handle(), + weight_desc_, + bottom_desc_, + conv_desc_, + top_desc_, + &fwdConvWsSize_)); + + int group_offset_X = C / group_ * H * W * D; + int batch_offset_X = group_offset_X * group_; + int group_offset_Y = M / group_ * H_out * W_out * D_out; + int batch_offset_Y = group_offset_Y * group_; + + if ((fwdConvWsSize_ > 0) && (fwdConvWs_ == nullptr)) { + HIP_CHECK(hipMalloc(&fwdConvWs_, fwdConvWsSize_)); + } + + while (!bestAlgoFound_) { + miopenConvAlgoPerf_t perf; + MIOPEN_ENFORCE(miopenFindConvolutionForwardAlgorithm( + miopen_wrapper_.inline_miopen_handle(), + bottom_desc_, + X.template data<T_X>(), + weight_desc_, + Weight.template data<T_W>(), + conv_desc_, + top_desc_, + Y->template mutable_data<T_Y>(), + requestAlgoCount_, + &returnedAlgoCount_, + &perf, + fwdConvWs_, + fwdConvWsSize_, + false)); + bestAlgoFound_ = true; + fwdAlgo_ = perf.fwd_algo; + } + + for (int b = 0; b < N; b++) { + for (int g = 0; g < group_; g++) { + MIOPEN_ENFORCE(miopenConvolutionForward( + miopen_wrapper_.inline_miopen_handle(), + &alpha_, + bottom_desc_, + X.template data<T_X>() + (b * batch_offset_X) + + (g * group_offset_X), + weight_desc_, + Weight.template data<T_W>() + g * group_offset_filter, + conv_desc_, + fwdAlgo_, + &beta_, + top_desc_, + Y->template mutable_data<T_Y>() + (b * batch_offset_Y) + + (g * group_offset_Y), + fwdConvWs_, + fwdConvWsSize_)); + } + } + hipDeviceSynchronize(); + + // BIAS + if (InputSize() == 3) { + auto& bias = Input(BIAS); + + CAFFE_ENFORCE_EQ(bias.ndim(), 1); + CAFFE_ENFORCE_EQ(bias.dim32(0), M); + MIOPEN_ENFORCE(miopenConvolutionForwardBias( + miopen_wrapper_.inline_miopen_handle(), + &alpha_, + bias_desc_, + 
bias.template data<T_B>(), + &beta_, + top_desc_for_bias_, + Y->template mutable_data<T_Y>())); + } + + hipDeviceSynchronize(); + } else { + MIOPEN_ENFORCE(miopenSet4dTensorDescriptor( + weight_desc_, + miopenTypeWrapper<T_W>::type, + M, + C, + kernel_h(), + kernel_w())); + + MIOPEN_ENFORCE(miopenSet4dTensorDescriptor( + bottom_desc_, miopenTypeWrapper<T_X>::type, N, C, H, W)); + + MIOPEN_ENFORCE(miopenGetConvolutionForwardOutputDim( + conv_desc_, + bottom_desc_, + weight_desc_, + &N_out, + &C_out, + &H_out, + &W_out)); + + MIOPEN_ENFORCE(miopenSet4dTensorDescriptor( + top_desc_, miopenTypeWrapper<T_X>::type, N_out, C_out, H_out, W_out)); + + if (InputSize() == 3) { + MIOPEN_ENFORCE(miopenSet4dTensorDescriptor( + bias_desc_, miopenTypeWrapper<T_B>::type, 1, C_out, 1, 1)); + } + + MIOPEN_ENFORCE(miopenConvolutionForwardGetWorkSpaceSize( + miopen_wrapper_.inline_miopen_handle(), + weight_desc_, + bottom_desc_, + conv_desc_, + top_desc_, + &fwdConvWsSize_)); + + if ((fwdConvWsSize_ > 0) && (fwdConvWs_ == nullptr)) { + HIP_CHECK(hipMalloc(&fwdConvWs_, fwdConvWsSize_)); + } + + while (!bestAlgoFound_) { + miopenConvAlgoPerf_t perf; + MIOPEN_ENFORCE(miopenFindConvolutionForwardAlgorithm( + miopen_wrapper_.inline_miopen_handle(), + bottom_desc_, + X.template data<T_X>(), + weight_desc_, + Weight.template data<T_W>(), + conv_desc_, + top_desc_, + Y->template mutable_data<T_Y>(), + requestAlgoCount_, + &returnedAlgoCount_, + &perf, + fwdConvWs_, + fwdConvWsSize_, + false)); + bestAlgoFound_ = true; + fwdAlgo_ = perf.fwd_algo; + } + MIOPEN_ENFORCE(miopenConvolutionForward( + miopen_wrapper_.inline_miopen_handle(), + &alpha_, + bottom_desc_, + X.template data<T_X>(), + weight_desc_, + Weight.template data<T_W>(), + conv_desc_, + fwdAlgo_, + &beta_, + top_desc_, + Y->template mutable_data<T_Y>(), + fwdConvWs_, + fwdConvWsSize_)); + + // BIAS + if (InputSize() == 3) { + auto& bias = Input(BIAS); + + CAFFE_ENFORCE_EQ(bias.ndim(), 1); + CAFFE_ENFORCE_EQ(bias.dim32(0), M); + MIOPEN_ENFORCE(miopenConvolutionForwardBias( + miopen_wrapper_.inline_miopen_handle(), + &alpha_, + bias_desc_, + bias.template data<T_B>(), + &beta_, + top_desc_, + Y->template mutable_data<T_Y>())); + } + + hipDeviceSynchronize(); + } + + return true; +} +// TODO : enable fp16 support. +bool MIOPENConvOp::RunOnDevice() { + if (Input(0).IsType<float>()) { + return DoRunWithType< + float, // X + float, // W + float, // B + float, // Math + float>(); // Y + } else { + LOG(FATAL) << "Only float (32bit) is supported by " + << "miopen convolution, but input " << debug_def().input(0) + << " has [" << Input(0).meta().name() << "]"; + } + return true; +} + +template < + typename T_X, + typename T_DY, + typename T_W, + typename T_B, + typename MATH, + typename T_DX, + typename T_DW, + typename T_DB> +bool MIOPENConvGradientOp::DoRunWithType() { + auto& X = Input(INPUT); + auto& Weight = Input(FILTER); + auto& dY = Input(OUTPUT_GRAD); + auto* dW = Output(FILTER_GRAD); + auto* dX = Output(no_bias_ ? BIAS_OR_INPUT_GRAD : INPUT_GRAD); + dX->ResizeLike(X); + dW->ResizeLike(Weight); + + CAFFE_ENFORCE(X.ndim() >= 3 && X.ndim() <= 5); + CAFFE_ENFORCE( + Weight.ndim() == 4, + "ConvGradient/TransGradient op with MIOpen engine is supported only for 2D convolutions"); + + const int M = Weight.dim32(0); + int N = 0, C = 0, H = 0, W = 0, D = 0, N_out = 0, C_out = 0, H_out = 0, + W_out = 0, D_out = 0; + + N = X.dim32(0); + C = X.dim32(1); + H = X.dim32(2); + W = X.ndim() > 3 ? X.dim32(3) : 1; + D = X.ndim() > 4 ? 
X.dim32(4) : 1; + + N_out = dY.dim32(0); + C_out = dY.dim32(1); + H_out = dY.dim32(2); + W_out = dY.ndim() > 3 ? dY.dim32(3) : 1; + D_out = dY.ndim() > 4 ? dY.dim32(4) : 1; + + CAFFE_ENFORCE_EQ(Weight.dim32(1), C / group_); + + CAFFE_ENFORCE( + C % group_ == 0, + "If you set group, the number of input channels should be divisible " + "by group."); + CAFFE_ENFORCE( + M % group_ == 0, + "If you set group, the number of output channels should be divisible " + "by group."); + + if (group_ > 1) { + int group_offset_filter = Weight.size() / group_; + MIOPEN_ENFORCE(miopenSet4dTensorDescriptor( + weight_desc_, + miopenTypeWrapper<T_X>::type, + M / group_, + C / group_, + kernel_h(), + kernel_w())); + + MIOPEN_ENFORCE(miopenSet4dTensorDescriptor( + bottom_desc_, miopenTypeWrapper<T_X>::type, 1, C / group_, H, W)); + + MIOPEN_ENFORCE(miopenGetConvolutionForwardOutputDim( + conv_desc_, + bottom_desc_, + weight_desc_, + &N_out, + &C_out, + &H_out, + &W_out)); + + MIOPEN_ENFORCE(miopenSet4dTensorDescriptor( + top_desc_, miopenTypeWrapper<T_X>::type, N_out, C_out, H_out, W_out)); + + if (!no_bias_) { + MIOPEN_ENFORCE(miopenSet4dTensorDescriptor( + bias_desc_, miopenTypeWrapper<T_B>::type, 1, M, 1, 1)); + MIOPEN_ENFORCE(miopenSet4dTensorDescriptor( + top_desc_for_bias_, + miopenTypeWrapper<T_X>::type, + dY.dim32(0), + M, + H_out, + W_out)); + } + + MIOPEN_ENFORCE(miopenConvolutionBackwardDataGetWorkSpaceSize( + miopen_wrapper_.inline_miopen_handle(), + top_desc_, + weight_desc_, + conv_desc_, + bottom_desc_, + &bwdDataWsSize_)); + + int group_offset_X = C / group_ * H * W * D; + int batch_offset_X = group_offset_X * group_; + int group_offset_Y = M / group_ * H_out * W_out * D_out; + int batch_offset_Y = group_offset_Y * group_; + + if ((bwdDataWsSize_ > 0) && (bwdDataWs_ == nullptr)) { + HIP_CHECK(hipMalloc(&bwdDataWs_, bwdDataWsSize_)); + } + + MIOPEN_ENFORCE(miopenConvolutionBackwardWeightsGetWorkSpaceSize( + miopen_wrapper_.inline_miopen_handle(), + top_desc_, + bottom_desc_, + conv_desc_, + weight_desc_, + &bwdWeightWsSize_)); + + if ((bwdWeightWsSize_ > 0) && (bwdWeightWs_ == nullptr)) { + HIP_CHECK(hipMalloc(&bwdWeightWs_, bwdWeightWsSize_)); + } + + while (!bestDataAlgoFound_) { + miopenConvAlgoPerf_t perf; + MIOPEN_ENFORCE(miopenFindConvolutionBackwardDataAlgorithm( + miopen_wrapper_.inline_miopen_handle(), + top_desc_, + dY.template data<T_DY>(), + weight_desc_, + Weight.template data<T_W>(), + conv_desc_, + bottom_desc_, + dX->template mutable_data<T_DX>(), + requestAlgoCount_, + &returnedAlgoCount_, + &perf, + bwdDataWs_, + bwdDataWsSize_, + false)); + + bestDataAlgoFound_ = true; + bwdDataAlgo_ = perf.bwd_data_algo; + } + + while (!bestWeightAlgoFound_) { + miopenConvAlgoPerf_t perf; + MIOPEN_ENFORCE(miopenFindConvolutionBackwardWeightsAlgorithm( + miopen_wrapper_.inline_miopen_handle(), + top_desc_, + dY.template data<T_DY>(), + bottom_desc_, + X.template data<T_X>(), + conv_desc_, + weight_desc_, + dW->template mutable_data<T_DW>(), + requestAlgoCount_, + &returnedAlgoCount_, + &perf, + bwdWeightWs_, + bwdWeightWsSize_, + false)); + bestWeightAlgoFound_ = true; + bwdWeiAlgo_ = perf.bwd_weights_algo; + } + + for (int b = 0; b < N; b++) { + for (int g = 0; g < group_; g++) { + MIOPEN_ENFORCE(miopenConvolutionBackwardData( + miopen_wrapper_.inline_miopen_handle(), + &alpha_, + top_desc_, + dY.template data<T_DY>() + (b * batch_offset_Y) + + (g * group_offset_Y), + weight_desc_, + Weight.template data<T_W>() + g * group_offset_filter, + conv_desc_, + bwdDataAlgo_, + &beta_, + bottom_desc_, + 
dX->template mutable_data<T_DX>() + (b * batch_offset_X) + + (g * group_offset_X), + bwdDataWs_, + bwdDataWsSize_)); + + MIOPEN_ENFORCE(miopenConvolutionBackwardWeights( + miopen_wrapper_.inline_miopen_handle(), + &alpha_, + top_desc_, + dY.template data<T_DY>() + (b * batch_offset_Y) + + (g * group_offset_Y), + bottom_desc_, + X.template data<T_X>() + (b * batch_offset_X) + + (g * group_offset_X), + conv_desc_, + bwdWeiAlgo_, + &beta_, + weight_desc_, + dW->template mutable_data<T_DW>() + g * group_offset_filter, + bwdWeightWs_, + bwdWeightWsSize_)); + } + } + + // Synchronize the work across groups. + hipDeviceSynchronize(); + + ////////////////////////////////////// BIAS /////////////////////////// + if (!no_bias_) { + auto* dbias = Output(BIAS_OR_INPUT_GRAD); + dbias->Resize(M); + MIOPEN_ENFORCE(miopenConvolutionBackwardBias( + miopen_wrapper_.inline_miopen_handle(), + &alpha_, + top_desc_for_bias_, + dY.template data<T_DY>(), + &beta_, + bias_desc_, + dbias->template mutable_data<T_DB>())); + } + } else // No group + { + MIOPEN_ENFORCE(miopenSet4dTensorDescriptor( + weight_desc_, + miopenTypeWrapper<T_X>::type, + M, + C, + kernel_h(), + kernel_w())); + + MIOPEN_ENFORCE(miopenSet4dTensorDescriptor( + bottom_desc_, miopenTypeWrapper<T_X>::type, N, C, H, W)); + + MIOPEN_ENFORCE(miopenGetConvolutionForwardOutputDim( + conv_desc_, + bottom_desc_, + weight_desc_, + &N_out, + &C_out, + &H_out, + &W_out)); + + MIOPEN_ENFORCE(miopenSet4dTensorDescriptor( + top_desc_, miopenTypeWrapper<T_X>::type, N_out, C_out, H_out, W_out)); + + if (!no_bias_) { + MIOPEN_ENFORCE(miopenSet4dTensorDescriptor( + bias_desc_, miopenTypeWrapper<T_B>::type, 1, M, 1, 1)); + } + + MIOPEN_ENFORCE(miopenConvolutionBackwardDataGetWorkSpaceSize( + miopen_wrapper_.inline_miopen_handle(), + top_desc_, + weight_desc_, + conv_desc_, + bottom_desc_, + &bwdDataWsSize_)); + + if ((bwdDataWsSize_ > 0) && (bwdDataWs_ == nullptr)) { + HIP_CHECK(hipMalloc(&bwdDataWs_, bwdDataWsSize_)); + } + + MIOPEN_ENFORCE(miopenConvolutionBackwardWeightsGetWorkSpaceSize( + miopen_wrapper_.inline_miopen_handle(), + top_desc_, + bottom_desc_, + conv_desc_, + weight_desc_, + &bwdWeightWsSize_)); + + if ((bwdWeightWsSize_ > 0) && (bwdWeightWs_ == nullptr)) { + HIP_CHECK(hipMalloc(&bwdWeightWs_, bwdWeightWsSize_)); + } + + while (!bestDataAlgoFound_) { + miopenConvAlgoPerf_t perf; + MIOPEN_ENFORCE(miopenFindConvolutionBackwardDataAlgorithm( + miopen_wrapper_.inline_miopen_handle(), + top_desc_, + dY.template data<T_DY>(), + weight_desc_, + Weight.template data<T_W>(), + conv_desc_, + bottom_desc_, + dX->template mutable_data<T_DX>(), + requestAlgoCount_, + &returnedAlgoCount_, + &perf, + bwdDataWs_, + bwdDataWsSize_, + false)); + + bestDataAlgoFound_ = true; + bwdDataAlgo_ = perf.bwd_data_algo; + } + + while (!bestWeightAlgoFound_) { + miopenConvAlgoPerf_t perf; + MIOPEN_ENFORCE(miopenFindConvolutionBackwardWeightsAlgorithm( + miopen_wrapper_.inline_miopen_handle(), + top_desc_, + dY.template data<T_DY>(), + bottom_desc_, + X.template data<T_X>(), + conv_desc_, + weight_desc_, + dW->template mutable_data<T_DW>(), + requestAlgoCount_, + &returnedAlgoCount_, + &perf, + bwdWeightWs_, + bwdWeightWsSize_, + false)); + bestWeightAlgoFound_ = true; + bwdWeiAlgo_ = perf.bwd_weights_algo; + } + + MIOPEN_ENFORCE(miopenConvolutionBackwardData( + miopen_wrapper_.inline_miopen_handle(), + &alpha_, + top_desc_, + dY.template data<T_DY>(), + weight_desc_, + Weight.template data<T_W>(), + conv_desc_, + bwdDataAlgo_, + &beta_, + bottom_desc_, + dX->template 
mutable_data<T_DX>(), + bwdDataWs_, + bwdDataWsSize_)); + + MIOPEN_ENFORCE(miopenConvolutionBackwardWeights( + miopen_wrapper_.inline_miopen_handle(), + &alpha_, + top_desc_, + dY.template data<T_DY>(), + bottom_desc_, + X.template data<T_X>(), + conv_desc_, + bwdWeiAlgo_, + &beta_, + weight_desc_, + dW->template mutable_data<T_DW>(), + bwdWeightWs_, + bwdWeightWsSize_)); + + // Synchronize the work across groups. + hipDeviceSynchronize(); + + ////////////////////////////////////// BIAS /////////////////////////// + if (!no_bias_) { + auto* dbias = Output(BIAS_OR_INPUT_GRAD); + dbias->Resize(M); + MIOPEN_ENFORCE(miopenConvolutionBackwardBias( + miopen_wrapper_.inline_miopen_handle(), + &alpha_, + top_desc_, + dY.template data<T_DY>(), + &beta_, + bias_desc_, + dbias->template mutable_data<T_DB>())); + } + } + + return true; +} + +bool MIOPENConvGradientOp::RunOnDevice() { + if (Input(0).IsType<float>()) { + return DoRunWithType< + float, // X + float, // dY + float, // W + float, // b + float, // Math + float, // dX + float, // dW + float>(); // db + } else { + LOG(FATAL) << "Unsupported input types"; + } + return true; +} + +REGISTER_MIOPEN_OPERATOR(Conv, MIOPENConvOp); +REGISTER_MIOPEN_OPERATOR(ConvGradient, MIOPENConvGradientOp); +REGISTER_MIOPEN_OPERATOR(Trans, MIOPENConvOp); +REGISTER_MIOPEN_OPERATOR(TransGradient, MIOPENConvGradientOp); +} // namespace caffe2 diff --git a/caffe2/operators/hip/local_response_normalization_op_miopen.cc b/caffe2/operators/hip/local_response_normalization_op_miopen.cc new file mode 100644 index 0000000000..26da9bf5b8 --- /dev/null +++ b/caffe2/operators/hip/local_response_normalization_op_miopen.cc @@ -0,0 +1,248 @@ +/** + * Copyright (c) 2016-present, Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "caffe2/core/hip/context_hip.h" +#include "caffe2/core/hip/miopen_wrapper.h" +#include "caffe2/core/operator.h" +#include "caffe2/core/types.h" + +namespace caffe2 { + +class MIOPEN_LRNOP final : public Operator<HIPContext> { + public: + USE_OPERATOR_FUNCTIONS(HIPContext); + + MIOPEN_LRNOP(const OperatorDef& operator_def, Workspace* ws) + : Operator<HIPContext>(operator_def, ws), + miopen_wrapper_(&context_), + mode_(miopenLRNCrossChannel), + size_(OperatorBase::GetSingleArgument<int>("size", 0)), + alpha_(OperatorBase::GetSingleArgument<float>("alpha", 0)), + beta_(OperatorBase::GetSingleArgument<float>("beta", 0)), + bias_(OperatorBase::GetSingleArgument<float>("bias", 1)) { + MIOPEN_ENFORCE(miopenCreateTensorDescriptor(&data_desc_)); + MIOPEN_ENFORCE(miopenCreateLRNDescriptor(&norm_desc_)); + MIOPEN_ENFORCE( + miopenSetLRNDescriptor(norm_desc_, mode_, size_, alpha_, beta_, bias_)); + } + + ~MIOPEN_LRNOP() { + MIOPEN_ENFORCE(miopenDestroyTensorDescriptor(data_desc_)); + MIOPEN_ENFORCE(miopenDestroyLRNDescriptor(norm_desc_)); + } + + template <typename T, typename M> + bool DoRunWithType(); + bool RunOnDevice() override; + + protected: + MIOPENWrapper miopen_wrapper_; + miopenTensorDescriptor_t data_desc_; + miopenLRNDescriptor_t norm_desc_; + vector<TIndex> miopen_input_dims_; + const miopenLRNMode_t mode_; + const int size_; + const float alpha_; + const float beta_; + const float bias_; + // Input: X, Output: Y +}; + +class MIOPENLRNGradientOp final : public Operator<HIPContext> { + public: + USE_OPERATOR_FUNCTIONS(HIPContext); + MIOPENLRNGradientOp(const OperatorDef& operator_def, Workspace* ws) + : Operator<HIPContext>(operator_def, ws), + miopen_wrapper_(&context_), + mode_(miopenLRNCrossChannel), + size_(OperatorBase::GetSingleArgument<int>("size", 0)), + alpha_(OperatorBase::GetSingleArgument<float>("alpha", 0)), + beta_(OperatorBase::GetSingleArgument<float>("beta", 0)), + bias_(OperatorBase::GetSingleArgument<float>("bias", 1)), + do_backward_( + OperatorBase::GetSingleArgument<bool>("do_backward", false)), + bwdLRNWs_(nullptr), + bwdLRNScratch_(nullptr) { + MIOPEN_ENFORCE(miopenCreateTensorDescriptor(&data_desc_)); + MIOPEN_ENFORCE(miopenCreateLRNDescriptor(&norm_desc_)); + MIOPEN_ENFORCE( + miopenSetLRNDescriptor(norm_desc_, mode_, size_, alpha_, beta_, bias_)); + } + + ~MIOPENLRNGradientOp() { + MIOPEN_ENFORCE(miopenDestroyTensorDescriptor(data_desc_)); + MIOPEN_ENFORCE(miopenDestroyLRNDescriptor(norm_desc_)); + + if (bwdLRNWs_) { + hipFree(bwdLRNWs_); + bwdLRNWs_ = nullptr; + } + if (bwdLRNScratch_) { + hipFree(bwdLRNScratch_); + bwdLRNScratch_ = nullptr; + } + } + + template <typename T, typename M> + bool DoRunWithType(); + bool RunOnDevice() override; + + protected: + MIOPENWrapper miopen_wrapper_; + miopenTensorDescriptor_t data_desc_; + miopenLRNDescriptor_t norm_desc_; + vector<TIndex> miopen_input_dims_; + const miopenLRNMode_t mode_; + const int size_; + const float alpha_; + const float beta_; + const float bias_; + const bool do_backward_; + float* bwdLRNWs_; + float* bwdLRNScratch_; + // Input: X, Y, dY + // Output: dX +}; + +template <typename T, typename M> +bool MIOPEN_LRNOP::DoRunWithType() { + const auto& X = Input(0); + auto* Y = Output(0); + + // Reshape tensor descriptors if necessary + if (X.dims() != miopen_input_dims_) { + VLOG(1) << "Setting descriptors"; + miopen_input_dims_ = X.dims(); + int C = 1, H = 1, W = 1; + // Normal 4-dimensional tensors for images. 
+ C = X.dim32(1); + H = X.dim32(2); + W = X.dim32(3); + MIOPEN_ENFORCE(miopenSet4dTensorDescriptor( + data_desc_, miopenTypeWrapper<T>::type, X.dim32(0), C, H, W)); + } + + // now actually run the computation + MIOPEN_ENFORCE(miopenLRNForward( + miopen_wrapper_.inline_miopen_handle(), + norm_desc_, + &alpha_, + data_desc_, + X.template data<T>(), + &beta_, + data_desc_, + Y->template mutable_data<T>(), + false, + nullptr)); + + return true; +} + +bool MIOPEN_LRNOP::RunOnDevice() { + // dispatch based on contents of tensor(s) + const auto& X = Input(0); + auto* Y = Output(0); + Y->ResizeLike(X); + + if (X.IsType<float>()) { + return DoRunWithType<float, float>(); + } else { + CAFFE_THROW("Unsupported input type"); + } + return false; +} + +template <typename T, typename M> +bool MIOPENLRNGradientOp::DoRunWithType() { + const auto& X = Input(0); + const auto& Y = Input(1); + const auto& dY = Input(2); + auto* dX = Output(0); + + if (dY.dims() != miopen_input_dims_) { + VLOG(1) << "Setting descriptors"; + miopen_input_dims_ = dY.dims(); + int C = 1, H = 1, W = 1; + // Normal 4-dimensional tensors for images. + C = dY.dim32(1); + H = dY.dim32(2); + W = dY.dim32(3); + MIOPEN_ENFORCE(miopenSet4dTensorDescriptor( + data_desc_, miopenTypeWrapper<T>::type, dY.dim32(0), C, H, W)); + } + + size_t ws_size = 0; + MIOPEN_ENFORCE(miopenLRNGetWorkSpaceSize(data_desc_, &ws_size)); + if (bwdLRNWs_ == nullptr) { + HIP_CHECK(hipMalloc(&bwdLRNWs_, ws_size)); + } + + // Run fwd pass to populate workspace + if (bwdLRNScratch_ == nullptr) { + HIP_CHECK(hipMalloc(&bwdLRNScratch_, X.size() * sizeof(float))); + } + MIOPEN_ENFORCE(miopenLRNForward( + miopen_wrapper_.inline_miopen_handle(), + norm_desc_, + &alpha_, + data_desc_, + X.template data<T>(), + &beta_, + data_desc_, + bwdLRNScratch_, + true, + bwdLRNWs_)); + + // Run the bwd computation + MIOPEN_ENFORCE(miopenLRNBackward( + miopen_wrapper_.inline_miopen_handle(), + norm_desc_, + &alpha_, + data_desc_, + Y.template data<T>(), + data_desc_, + dY.template data<T>(), + data_desc_, + X.template data<T>(), + &beta_, + data_desc_, + dX->template mutable_data<T>(), + bwdLRNWs_)); + return true; +} + +bool MIOPENLRNGradientOp::RunOnDevice() { + // dispatch based on contents of tensor(s) + const auto& X = Input(0); + const auto& Y = Input(1); + const auto& dY = Input(2); + auto* dX = Output(0); + + dX->ResizeLike(dY); + + if (dY.IsType<float>()) { + return DoRunWithType<float, float>(); + } else { + CAFFE_THROW("Unsupported input type"); + } + return false; +} + +namespace { +REGISTER_MIOPEN_OPERATOR(LRN, MIOPEN_LRNOP); +REGISTER_MIOPEN_OPERATOR(LRNGradient, MIOPENLRNGradientOp); +} // namespace + +}; // namespace caffe2 diff --git a/caffe2/operators/hip/operator_fallback_hip.h b/caffe2/operators/hip/operator_fallback_hip.h new file mode 100644 index 0000000000..62e5fe8f01 --- /dev/null +++ b/caffe2/operators/hip/operator_fallback_hip.h @@ -0,0 +1,114 @@ +#ifndef CAFFE2_OPERATORS_OPERATOR_FALLBACK_H_ +#define CAFFE2_OPERATORS_OPERATOR_FALLBACK_H_ + +#include "caffe2/core/common.h" +#include "caffe2/core/context.h" +#include "caffe2/core/hip/context_hip.h" +#include "caffe2/core/operator.h" +#include "caffe2/proto/caffe2.pb.h" + +namespace caffe2 { + +/** + * @brief A templated class to allow one to wrap a CPU operator as a CUDA + * operator. + * + * This class can be used when one does not have the CUDA implementation ready + * yet for an operator. Essentially, what this op does is to automatically + * deal with data copy for you. 
Plausibly, this causes a lot of overhead and + * is not optimal, so you should use this operator mostly for quick prototyping + * purpose. + * + * All the input and output of the original operator should be TensorCPU. + * + * Example usage: if you have a class MyMagicOp that is CPU based, and you use + * the registration code + * REGISTER_CPU_OPERATOR(MyMagic, MyMagicOp); + * to register the CPU side, you can create its corresponding GPU operator + * (with performance hits of course) via + * REGISTER_HIP_OPERATOR(MyMagic, + * GPUFallbackOp<MyMagicOp>); + * + * Advanced usage: if you want to have some specific outputs never copied, you + * can use the SkipOutputCopy template argument to do that. For example, if + * MyMagic produces two outputs and the first output is always going to live on + * the CPU, you can do + * REGISTER_HIP_OPERATOR(MyMagic, + * GPUFallbackOp<MyMagicOp, SkipIndices<0>>); + */ +template <class CPUOp, typename SkipOutputCopy = SkipIndices<>> +class GPUFallbackOp final : public Operator<HIPContext> { + public: + USE_OPERATOR_FUNCTIONS(HIPContext); + GPUFallbackOp(const OperatorDef& def, Workspace* ws) + : Operator<HIPContext>(def, ws) { + CAFFE_ENFORCE_EQ(def.device_option().device_type(), HIP); + OperatorDef base_def_(def); + // base_def_ runs on CPU, so we will set its device option to CPU. + base_def_.clear_device_option(); + base_def_.mutable_device_option()->set_device_type(CPU); + // Set up the symbols for the local workspace. + for (const string& name : def.input()) { + local_input_blobs_.push_back(local_ws_.CreateBlob(name)); + CHECK_NOTNULL(local_input_blobs_.back()); + } + base_op_.reset(new CPUOp(base_def_, &local_ws_)); + for (const string& name : def.output()) { + local_output_blobs_.push_back(local_ws_.GetBlob(name)); + CHECK_NOTNULL(local_output_blobs_.back()); + } + } + + bool RunOnDevice() override { + bool need_sync = false; + for (int i = 0; i < InputSize(); ++i) { + if (OperatorBase::InputIsType<TensorHIP>(i)) { + local_input_blobs_[i]->template GetMutable<TensorCPU>()->CopyFrom( + Input(i), &context_); + need_sync = true; + } else { + VLOG(1) << "Input " << i << " is not TensorHIP. Skipping copy."; + // Note(jiayq): This removes a const but conceptually + // local_input_blobs will only be used as const blob input for the + // base op so we are still fine. + local_input_blobs_[i]->ShareExternal( + const_cast<void*>(OperatorBase::Inputs()[i]->GetRaw()), + OperatorBase::Inputs()[i]->meta()); + } + } + + // Sync to make sure copies are done. + if (need_sync) { + context_.FinishDeviceComputation(); + } + + if (!base_op_->Run()) { + LOG(ERROR) << "Base op run failed in GPUFallbackOp. 
Def: " + << ProtoDebugString(this->debug_def()); + return false; + } + for (int i = 0; i < OutputSize(); ++i) { + if (SkipOutputCopy::Contains(i)) { + VLOG(1) << "Copy output: index " << i << " skipped."; + continue; + } + CAFFE_ENFORCE( + local_output_blobs_[i]->template IsType<TensorCPU>(), + "GPU fallback op currently does not support non-TensorCPU " + "output type who needs copying."); + Output(i)->CopyFrom( + local_output_blobs_[i]->template Get<TensorCPU>(), &context_); + } + return true; + } + + protected: + Workspace local_ws_; + vector<Blob*> local_input_blobs_; + vector<Blob*> local_output_blobs_; + std::unique_ptr<CPUOp> base_op_; +}; + +} // namespace caffe2 + +#endif // CAFFE2_OPERATORS_OPERATOR_FALLBACK_H_ diff --git a/caffe2/operators/hip/operator_fallback_hip_test.cc b/caffe2/operators/hip/operator_fallback_hip_test.cc new file mode 100644 index 0000000000..4a074c35f8 --- /dev/null +++ b/caffe2/operators/hip/operator_fallback_hip_test.cc @@ -0,0 +1,80 @@ +#include <iostream> + +#include <gtest/gtest.h> +#include "caffe2/core/operator.h" +#include "caffe2/operators/hip/operator_fallback_hip.h" + +namespace caffe2 { + +class IncrementByOneOp final : public Operator<CPUContext> { + public: + IncrementByOneOp(const OperatorDef& def, Workspace* ws) + : Operator<CPUContext>(def, ws) {} + bool RunOnDevice() { + const auto& in = Input(0); + auto* out = Output(0); + out->ResizeLike(in); + const float* in_data = in.template data<float>(); + float* out_data = out->template mutable_data<float>(); + for (int i = 0; i < in.size(); ++i) { + out_data[i] = in_data[i] + 1.f; + } + return true; + } +}; + +OPERATOR_SCHEMA(IncrementByOne) + .NumInputs(1) + .NumOutputs(1) + .AllowInplace({{0, 0}}); + +REGISTER_CPU_OPERATOR(IncrementByOne, IncrementByOneOp); +REGISTER_HIP_OPERATOR(IncrementByOne, GPUFallbackOp<IncrementByOneOp>); + +TEST(OperatorFallbackTest, IncrementByOneOp) { + OperatorDef op_def = CreateOperatorDef( + "IncrementByOne", "", vector<string>{"X"}, vector<string>{"X"}); + Workspace ws; + TensorCPU source_tensor(vector<TIndex>{2, 3}); + for (int i = 0; i < 6; ++i) { + source_tensor.mutable_data<float>()[i] = i; + } + ws.CreateBlob("X")->GetMutable<TensorCPU>()->CopyFrom(source_tensor); + unique_ptr<OperatorBase> op(CreateOperator(op_def, &ws)); + EXPECT_TRUE(op.get() != nullptr); + EXPECT_TRUE(op->Run()); + const TensorCPU& output = ws.GetBlob("X")->Get<TensorCPU>(); + EXPECT_EQ(output.ndim(), 2); + EXPECT_EQ(output.dim(0), 2); + EXPECT_EQ(output.dim(1), 3); + for (int i = 0; i < 6; ++i) { + EXPECT_EQ(output.data<float>()[i], i + 1); + } +} + +TEST(OperatorFallbackTest, GPUIncrementByOneOp) { + if (!HasHipGPU()) + return; + OperatorDef op_def = CreateOperatorDef( + "IncrementByOne", "", vector<string>{"X"}, vector<string>{"X"}); + op_def.mutable_device_option()->set_device_type(HIP); + Workspace ws; + TensorCPU source_tensor(vector<TIndex>{2, 3}); + for (int i = 0; i < 6; ++i) { + source_tensor.mutable_data<float>()[i] = i; + } + ws.CreateBlob("X")->GetMutable<TensorHIP>()->CopyFrom(source_tensor); + unique_ptr<OperatorBase> op(CreateOperator(op_def, &ws)); + EXPECT_TRUE(op.get() != nullptr); + EXPECT_TRUE(op->Run()); + const TensorHIP& output = ws.GetBlob("X")->Get<TensorHIP>(); + TensorCPU output_cpu(output); + EXPECT_EQ(output.ndim(), 2); + EXPECT_EQ(output.dim(0), 2); + EXPECT_EQ(output.dim(1), 3); + for (int i = 0; i < 6; ++i) { + EXPECT_EQ(output_cpu.data<float>()[i], i + 1); + } +} + +} // namespace caffe2 diff --git a/caffe2/operators/hip/pool_op_miopen.cc 
b/caffe2/operators/hip/pool_op_miopen.cc new file mode 100644 index 0000000000..9f9e7f930e --- /dev/null +++ b/caffe2/operators/hip/pool_op_miopen.cc @@ -0,0 +1,310 @@ +/** + * Copyright (c) 2016-present, Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "caffe2/core/hip/context_hip.h" +#include "caffe2/core/hip/miopen_wrapper.h" +#include "caffe2/operators/conv_pool_op_base.h" + +namespace caffe2 { +class MIOPENPoolOp : public ConvPoolOpBase<HIPContext> { + public: + MIOPENPoolOp(const OperatorDef& operator_def, Workspace* ws) + : ConvPoolOpBase<HIPContext>(operator_def, ws), + miopen_wrapper_(&context_), + alpha_(OperatorBase::GetSingleArgument<float>("alpha", 1.0)), + beta_(OperatorBase::GetSingleArgument<float>("beta", 0.0)), + do_backward_( + OperatorBase::GetSingleArgument<bool>("do_backward", true)), + poolWs_(nullptr), + poolWsSize_(0) + + { + MIOPEN_ENFORCE(miopenCreateTensorDescriptor(&bottom_desc_)); + MIOPEN_ENFORCE(miopenCreateTensorDescriptor(&top_desc_)); + MIOPEN_ENFORCE(miopenCreatePoolingDescriptor(&pooling_desc_)); + + if ((operator_def.type().substr(0, 9) == "MaxPool") || + (operator_def.type().substr(0, 17) == "MaxPoolGradient")) { + mode_ = miopenPoolingMax; + } else if ( + (operator_def.type().substr(0, 13) == "AveragePool") || + (operator_def.type().substr(0, 21) == "AveragePoolGradient")) { + mode_ = miopenPoolingAverage; + } else { + LOG(FATAL) << "Unsupported pooling method: " << operator_def.type(); + } + } + + ~MIOPENPoolOp() { + MIOPEN_ENFORCE(miopenDestroyTensorDescriptor(bottom_desc_)); + MIOPEN_ENFORCE(miopenDestroyTensorDescriptor(top_desc_)); + MIOPEN_ENFORCE(miopenDestroyPoolingDescriptor(pooling_desc_)); + poolWsSize_ = 0; + + if (poolWs_ != nullptr) { + hipFree(poolWs_); + poolWs_ = nullptr; + } + } + + template <typename T, typename M> + bool DoRunWithType() { + auto& X = Input(0); + auto* Y = Output(0); + int N = 0, C = 0, H = 0, W = 0, D = 0; + int N_out = 0, C_out = 0, H_out = 0, W_out = 0; + CAFFE_ENFORCE(X.ndim() >= 4 && X.ndim() <= 5); + N = X.dim32(0); + C = X.dim32(1); + H = X.dim32(2); + W = X.ndim() > 3 ? X.dim32(3) : 1; + ConvPoolOpBase::SetOutputSize(X, Y, C); + + N_out = Y->dim32(0); + C_out = Y->dim32(1); + H_out = Y->dim32(2); + W_out = Y->ndim() > 3 ? 
Y->dim32(3) : 1; + + CAFFE_ENFORCE(kernel_.size() == 2, "MIOpen supports only 2D pooling"); + MIOPEN_ENFORCE(miopenSet2dPoolingDescriptor( + pooling_desc_, + mode_, + kernel_h(), + kernel_w(), + pad_t(), + pad_l(), + stride_h(), + stride_w())); + + MIOPEN_ENFORCE(miopenSet4dTensorDescriptor( + bottom_desc_, miopenTypeWrapper<T>::type, N, C, H, W)); + + MIOPEN_ENFORCE(miopenSet4dTensorDescriptor( + top_desc_, miopenTypeWrapper<T>::type, N_out, C_out, H_out, W_out)); + + MIOPEN_ENFORCE(miopenPoolingGetWorkSpaceSize(top_desc_, &poolWsSize_)); + + if ((poolWsSize_ > 0) && (poolWs_ == nullptr)) { + HIP_CHECK(hipMalloc(&poolWs_, poolWsSize_)); + } + + const T* Xdata = X.template data<T>(); + T* Ydata = Y->template mutable_data<T>(); + MIOPEN_ENFORCE(miopenPoolingForward( + miopen_wrapper_.inline_miopen_handle(), + pooling_desc_, + &alpha_, + bottom_desc_, + Xdata, + &beta_, + top_desc_, + Ydata, + do_backward_, + poolWs_, + poolWsSize_)); + + return true; + } + + bool RunOnDevice() final { + auto& X = Input(0); + auto* Y = Output(0); + // TODO enable fp16 + if (X.IsType<float>()) { + return DoRunWithType<float, float>(); + } else { + LOG(FATAL) << "Unsupported input types"; + } + return true; + } + + protected: + size_t poolWsSize_; + char* poolWs_; + MIOPENWrapper miopen_wrapper_; + miopenTensorDescriptor_t bottom_desc_; + miopenTensorDescriptor_t top_desc_; + miopenPoolingDescriptor_t pooling_desc_; + miopenPoolingMode_t mode_; + bool do_backward_; + const float alpha_; + const float beta_; +}; + +class MIOPENPoolGradientOp : public ConvPoolOpBase<HIPContext> { + public: + MIOPENPoolGradientOp(const OperatorDef& operator_def, Workspace* ws) + : ConvPoolOpBase<HIPContext>(operator_def, ws), + miopen_wrapper_(&context_), + alpha_(OperatorBase::GetSingleArgument<float>("alpha", 1.0)), + beta_(OperatorBase::GetSingleArgument<float>("beta", 0.0)), + poolWs_(nullptr), + poolWsSize_(0), + bwdPoolScratch_(nullptr) { + MIOPEN_ENFORCE(miopenCreateTensorDescriptor(&bottom_desc_)); + MIOPEN_ENFORCE(miopenCreateTensorDescriptor(&top_desc_)); + MIOPEN_ENFORCE(miopenCreatePoolingDescriptor(&pooling_desc_)); + + if (operator_def.type().substr(0, 7) == "MaxPool") { + mode_ = miopenPoolingMax; + } else if (operator_def.type().substr(0, 11) == "AveragePool") { + mode_ = miopenPoolingAverage; + } else { + LOG(FATAL) << "Unsupported pooling method: " << operator_def.type(); + } + } + + ~MIOPENPoolGradientOp() { + MIOPEN_ENFORCE(miopenDestroyTensorDescriptor(bottom_desc_)); + MIOPEN_ENFORCE(miopenDestroyTensorDescriptor(top_desc_)); + MIOPEN_ENFORCE(miopenDestroyPoolingDescriptor(pooling_desc_)); + poolWsSize_ = 0; + + if (poolWs_ != nullptr) { + hipFree(poolWs_); + poolWs_ = nullptr; + } + + if (bwdPoolScratch_) { + hipFree(bwdPoolScratch_); + bwdPoolScratch_ = nullptr; + } + } + + template <typename T, typename M> + bool DoRunWithType() { + auto& X = Input(0); + auto& Y = Input(1); + auto& dY = Input(2); + auto* dX = Output(0); + + // cuDNN pooling support only 2 and 3 spatial dimensions. + CAFFE_ENFORCE(X.ndim() >= 4 && X.ndim() <= 5); + + dX->ResizeLike(X); + int N = 0, C = 0, H = 0, W = 0, D = 0; + int N_out = 0, C_out = 0, H_out = 0, W_out = 0, D_out = 0; + N = X.dim32(0); + C = X.dim32(1); + H = X.dim32(2); + W = X.ndim() > 3 ? X.dim32(3) : 1; + D = X.ndim() > 4 ? X.dim32(4) : 1; + N_out = Y.dim32(0); + C_out = Y.dim32(1); + H_out = Y.dim32(2); + W_out = Y.ndim() > 3 ? Y.dim32(3) : 1; + D_out = Y.ndim() > 4 ? 
Y.dim32(4) : 1; + + CAFFE_ENFORCE(kernel_.size() == 2, "MIOpen supports only 2D pooling"); + MIOPEN_ENFORCE(miopenSet2dPoolingDescriptor( + pooling_desc_, + mode_, + kernel_h(), + kernel_w(), + pad_t(), + pad_l(), + stride_h(), + stride_w())); + + MIOPEN_ENFORCE(miopenSet4dTensorDescriptor( + bottom_desc_, miopenTypeWrapper<T>::type, N, C, H, W)); + + MIOPEN_ENFORCE(miopenSet4dTensorDescriptor( + top_desc_, miopenTypeWrapper<T>::type, N_out, C_out, H_out, W_out)); + + MIOPEN_ENFORCE(miopenPoolingGetWorkSpaceSize(top_desc_, &poolWsSize_)); + + if ((poolWsSize_ > 0) && (poolWs_ == nullptr)) { + HIP_CHECK(hipMalloc(&poolWs_, poolWsSize_)); + } + + if (bwdPoolScratch_ == nullptr) { + HIP_CHECK(hipMalloc(&bwdPoolScratch_, Y.size() * sizeof(float))); + } + + // Carry out the pooling computation. + const T* Xdata = X.template data<T>(); + const T* Ydata = Y.template data<T>(); + const T* dYdata = dY.template data<T>(); + T* dXdata = dX->template mutable_data<T>(); + + MIOPEN_ENFORCE(miopenPoolingForward( + miopen_wrapper_.inline_miopen_handle(), + pooling_desc_, + &alpha_, + bottom_desc_, + Xdata, + &beta_, + top_desc_, + bwdPoolScratch_, + true, + poolWs_, + poolWsSize_)); + + MIOPEN_ENFORCE(miopenPoolingBackward( + miopen_wrapper_.inline_miopen_handle(), + pooling_desc_, + &alpha_, + top_desc_, + Ydata, + top_desc_, + dYdata, + bottom_desc_, + Xdata, + &beta_, + bottom_desc_, + dXdata, + poolWs_)); + + return true; + } + + bool RunOnDevice() final { + auto& X = Input(0); + auto& Y = Input(1); + auto& dY = Input(2); + auto* dX = Output(0); + dX->ResizeLike(X); + + if (X.IsType<float>()) { + return DoRunWithType<float, float>(); + } else { + LOG(FATAL) << "Unsupported input types"; + } + return true; + } + + protected: + size_t poolWsSize_; + char* poolWs_; + MIOPENWrapper miopen_wrapper_; + miopenTensorDescriptor_t bottom_desc_; + miopenTensorDescriptor_t top_desc_; + miopenPoolingDescriptor_t pooling_desc_; + miopenPoolingMode_t mode_; + const float alpha_; + const float beta_; + float* bwdPoolScratch_; +}; + +namespace { +REGISTER_MIOPEN_OPERATOR(AveragePool, MIOPENPoolOp); +REGISTER_MIOPEN_OPERATOR(AveragePoolGradient, MIOPENPoolGradientOp); + +REGISTER_MIOPEN_OPERATOR(MaxPool, MIOPENPoolOp); +REGISTER_MIOPEN_OPERATOR(MaxPoolGradient, MIOPENPoolGradientOp); +} // namespace +} // namespace caffe2 diff --git a/caffe2/operators/hip/relu_op_miopen.cc b/caffe2/operators/hip/relu_op_miopen.cc new file mode 100644 index 0000000000..5a8a147ff2 --- /dev/null +++ b/caffe2/operators/hip/relu_op_miopen.cc @@ -0,0 +1,205 @@ +/** + * Copyright (c) 2016-present, Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "caffe2/core/hip/context_hip.h" +#include "caffe2/core/hip/miopen_wrapper.h" +#include "caffe2/core/operator.h" +#include "caffe2/core/types.h" + +namespace caffe2 { + +class MIOPENReluOp final : public Operator<HIPContext> { + public: + MIOPENReluOp(const OperatorDef& operator_def, Workspace* ws) + : Operator<HIPContext>(operator_def, ws), + miopen_wrapper_(&context_), + alpha_(OperatorBase::GetSingleArgument<float>("alpha", 1.0)), + beta_(OperatorBase::GetSingleArgument<float>("beta", 0.0)), + power_(OperatorBase::GetSingleArgument<double>("power", 1.0)) { + MIOPEN_ENFORCE(miopenCreateTensorDescriptor(&data_desc_)); + MIOPEN_ENFORCE(miopenCreateActivationDescriptor(&activ_desc_)); + MIOPEN_ENFORCE(miopenSetActivationDescriptor( + activ_desc_, miopenActivationRELU, alpha_, beta_, power_)); + } + + ~MIOPENReluOp() { + MIOPEN_ENFORCE(miopenDestroyTensorDescriptor(data_desc_)); + MIOPEN_ENFORCE(miopenDestroyActivationDescriptor(activ_desc_)); + } + + template <typename T> + bool DoRunWithType() { + const auto& X = Input(0); + auto* Y = Output(0); + + // Return if X is empty + if (X.size() == 0) { + Y->mutable_data<T>(); + return true; + } + + // See if we need to reshape. + if (X.dims() != miopen_input_dims_) { + VLOG(1) << "Setting descriptors."; + miopen_input_dims_ = X.dims(); + int C = 1, H = 1, W = 1; + if (X.ndim() == 4) { + // Normal 4-dimensional tensors for images. + C = X.dim32(1); + H = X.dim32(2); + W = X.dim32(3); + } else { + // If X is not 4-dimensional, we will simply use H = 1 and W = 1 + // and wrap everything into C. + C = X.size() / X.dim32(0); + } + MIOPEN_ENFORCE(miopenSet4dTensorDescriptor( + data_desc_, miopenTypeWrapper<T>::type, X.dim32(0), C, H, W)); + } + MIOPEN_ENFORCE(miopenActivationForward( + miopen_wrapper_.inline_miopen_handle(), + activ_desc_, + &alpha_, + data_desc_, + X.template data<T>(), + &beta_, + data_desc_, + Y->template mutable_data<T>())); + return true; + } + + bool RunOnDevice() override { + // dispatch based on contents of tensor(s) + const auto& X = Input(0); + auto* Y = Output(0); + Y->ResizeLike(X); + if (X.IsType<float>()) { + return DoRunWithType<float>(); + } else { + LOG(FATAL) << "Unsupported input types"; + } + return true; + } + + protected: + MIOPENWrapper miopen_wrapper_; + miopenTensorDescriptor_t data_desc_; + miopenActivationDescriptor_t activ_desc_; + vector<TIndex> miopen_input_dims_; + const float alpha_; + const float beta_; + const double power_; +}; + +// Note: You can see that in MIOPENReluGradientOp, we abused the miopen +// interface by passing in the output tensor for both bottom and top. This is +// dependent on the assumption that the Relu gradient actually does not rely on +// the bottom data, or it treats input=0 the same way as input<0. This is of +// course not very safe, but we have been running in this way in Caffe for a +// while so it *might* be safe to assume so. 
+class MIOPENReluGradientOp final : public Operator<HIPContext> { + public: + MIOPENReluGradientOp(const OperatorDef& operator_def, Workspace* ws) + : Operator<HIPContext>(operator_def, ws), + miopen_wrapper_(&context_), + alpha_(OperatorBase::GetSingleArgument<float>("alpha", 1.0)), + beta_(OperatorBase::GetSingleArgument<float>("beta", 0.0)), + power_(OperatorBase::GetSingleArgument<double>("power", 1.0)) { + MIOPEN_ENFORCE(miopenCreateTensorDescriptor(&data_desc_)); + MIOPEN_ENFORCE(miopenCreateActivationDescriptor(&activ_desc_)); + MIOPEN_ENFORCE(miopenSetActivationDescriptor( + activ_desc_, miopenActivationRELU, alpha_, beta_, power_)); + } + + ~MIOPENReluGradientOp() { + MIOPEN_ENFORCE(miopenDestroyTensorDescriptor(data_desc_)); + MIOPEN_ENFORCE(miopenDestroyActivationDescriptor(activ_desc_)); + } + + template <typename T> + bool DoRunWithType() { + const auto& Y = Input(0); + const auto& dY = Input(1); + auto* dX = Output(0); + + // Return if Y is empty + if (Y.size() == 0) { + dX->mutable_data<T>(); + return true; + } + + // See if we need to reshape. + if (Y.dims() != miopen_input_dims_) { + VLOG(1) << "Setting descriptors."; + miopen_input_dims_ = Y.dims(); + int C = 1, H = 1, W = 1; + if (Y.ndim() == 4) { + // Normal 4-dimensional tensors for images. + C = Y.dim32(1); + H = Y.dim32(2); + W = Y.dim32(3); + } else { + // If Y is not 4-dimensional, we will simply use H = 1 and W = 1 + // and wrap everything into C. + C = Y.size() / Y.dim32(0); + } + MIOPEN_ENFORCE(miopenSet4dTensorDescriptor( + data_desc_, miopenTypeWrapper<T>::type, Y.dim32(0), C, H, W)); + } + MIOPEN_ENFORCE(miopenActivationBackward( + miopen_wrapper_.inline_miopen_handle(), + activ_desc_, + &alpha_, + data_desc_, + Y.template data<T>(), + data_desc_, + dY.template data<T>(), + data_desc_, + Y.template data<T>(), + &beta_, + data_desc_, + dX->template mutable_data<T>())); + return true; + } + + bool RunOnDevice() override { + const auto& Y = Input(0); + auto* dX = Output(0); + dX->ResizeLike(Y); + if (Y.IsType<float>()) { + return DoRunWithType<float>(); + } else { + LOG(FATAL) << "Unsupported input types"; + } + return true; + } + + protected: + MIOPENWrapper miopen_wrapper_; + miopenTensorDescriptor_t data_desc_; + miopenActivationDescriptor_t activ_desc_; + vector<TIndex> miopen_input_dims_; + const float alpha_; + const float beta_; + const double power_; + // Input: Y, dY; Output: dX +}; + +namespace { +REGISTER_MIOPEN_OPERATOR(Relu, MIOPENReluOp); +REGISTER_MIOPEN_OPERATOR(ReluGradient, MIOPENReluGradientOp); +} // namespace +} // namespace caffe2 diff --git a/caffe2/operators/hip/softmax_op_miopen.cc b/caffe2/operators/hip/softmax_op_miopen.cc new file mode 100644 index 0000000000..08c43a8aa2 --- /dev/null +++ b/caffe2/operators/hip/softmax_op_miopen.cc @@ -0,0 +1,138 @@ +/** + * Copyright (c) 2016-present, Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
diff --git a/caffe2/operators/hip/softmax_op_miopen.cc b/caffe2/operators/hip/softmax_op_miopen.cc
new file mode 100644
index 0000000000..08c43a8aa2
--- /dev/null
+++ b/caffe2/operators/hip/softmax_op_miopen.cc
@@ -0,0 +1,138 @@
+/**
+ * Copyright (c) 2016-present, Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "caffe2/core/hip/context_hip.h"
+#include "caffe2/core/hip/miopen_wrapper.h"
+#include "caffe2/core/types.h"
+#include "caffe2/operators/softmax_op.h"
+
+namespace caffe2 {
+class MIOpenSoftmaxOp final : public Operator<HIPContext> {
+ public:
+  explicit MIOpenSoftmaxOp(const OperatorDef& def, Workspace* ws)
+      : Operator<HIPContext>(def, ws),
+        miopen_wrapper_(&context_),
+        axis_(OperatorBase::GetSingleArgument<int>("axis", 1)),
+        alpha_(OperatorBase::GetSingleArgument<float>("alpha", 1.0)),
+        beta_(OperatorBase::GetSingleArgument<float>("beta", 0.0)) {
+    MIOPEN_ENFORCE(miopenCreateTensorDescriptor(&desc_));
+  }
+
+  ~MIOpenSoftmaxOp() {
+    MIOPEN_ENFORCE(miopenDestroyTensorDescriptor(desc_));
+  }
+
+  template <typename T>
+  bool DoRunWithType() {
+    auto& X = Input(0);
+    auto* Y = Output(0);
+    const auto canonical_axis = X.canonical_axis_index(axis_);
+    const int N = X.size_to_dim(canonical_axis);
+    const int D = X.size_from_dim(canonical_axis);
+
+    Y->ResizeLike(X);
+    if (dims_ != X.dims()) {
+      MIOPEN_ENFORCE(miopenSet4dTensorDescriptor(
+          desc_, miopenTypeWrapper<T>::type, N, D, 1, 1));
+      dims_ = X.dims();
+    }
+    MIOPEN_ENFORCE(miopenSoftmaxForward(
+        miopen_wrapper_.inline_miopen_handle(),
+        &alpha_,
+        desc_,
+        X.template data<T>(),
+        &beta_,
+        desc_,
+        Y->template mutable_data<T>()));
+    return true;
+  }
+
+  bool RunOnDevice() override {
+    return DispatchHelper<TensorTypes<float, float16>>::call(this, Input(0));
+  }
+
+ protected:
+  MIOPENWrapper miopen_wrapper_;
+  miopenTensorDescriptor_t desc_;
+  vector<TIndex> dims_;
+  const int axis_;
+  const float alpha_;
+  const float beta_;
+};
+
+class MIOpenSoftmaxGradientOp final : public Operator<HIPContext> {
+ public:
+  explicit MIOpenSoftmaxGradientOp(const OperatorDef& def, Workspace* ws)
+      : Operator<HIPContext>(def, ws),
+        miopen_wrapper_(&context_),
+        axis_(OperatorBase::GetSingleArgument<int>("axis", 1)),
+        alpha_(OperatorBase::GetSingleArgument<float>("alpha", 1.0)),
+        beta_(OperatorBase::GetSingleArgument<float>("beta", 0.0)) {
+    MIOPEN_ENFORCE(miopenCreateTensorDescriptor(&desc_));
+  }
+
+  ~MIOpenSoftmaxGradientOp() {
+    MIOPEN_ENFORCE(miopenDestroyTensorDescriptor(desc_));
+  }
+
+  template <typename T>
+  bool DoRunWithType() {
+    auto& Y = Input(0);
+    auto& dY = Input(1);
+    auto* dX = Output(0);
+    const auto canonical_axis = Y.canonical_axis_index(axis_);
+    const int N = Y.size_to_dim(canonical_axis);
+    const int D = Y.size_from_dim(canonical_axis);
+
+    CHECK_EQ(Y.dims(), dY.dims());
+    dX->ResizeLike(Y);
+    if (dims_ != Y.dims()) {
+      MIOPEN_ENFORCE(miopenSet4dTensorDescriptor(
+          desc_, miopenTypeWrapper<T>::type, N, D, 1, 1));
+      dims_ = Y.dims();
+    }
+    MIOPEN_ENFORCE(miopenSoftmaxBackward(
+        miopen_wrapper_.inline_miopen_handle(),
+        &alpha_,
+        desc_,
+        Y.template data<T>(),
+        desc_,
+        dY.template data<T>(),
+        &beta_,
+        desc_,
+        dX->template mutable_data<T>()));
+    return true;
+  }
+
+  bool RunOnDevice() override {
+    return DispatchHelper<TensorTypes<float, float16>>::call(this, Input(0));
+  }
+
+ protected:
+  MIOPENWrapper miopen_wrapper_;
+  const int axis_;
+  const float alpha_;
+  const float beta_;
+  miopenTensorDescriptor_t desc_;
+  vector<TIndex> dims_;
+};
+
+namespace {
+REGISTER_MIOPEN_OPERATOR(Softmax, MIOpenSoftmaxOp);
+REGISTER_MIOPEN_OPERATOR(SoftmaxGradient, MIOpenSoftmaxGradientOp);
+} // namespace
+
+} // namespace caffe2
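The backward call above is the standard softmax Jacobian-vector product applied row-by-row over the flattened (N, D) view of the tensor. A host-side sketch of the identity it computes is below (illustrative only; it assumes MIOpen's default softmax semantics match the usual definition).

    #include <cstddef>
    #include <vector>

    // dX[i] = Y[i] * (dY[i] - sum_j(dY[j] * Y[j])) for one row of length D.
    // Written out only to document the math behind miopenSoftmaxBackward.
    void SoftmaxGradientRow(
        const std::vector<float>& Y,
        const std::vector<float>& dY,
        std::vector<float>* dX) {
      const std::size_t D = Y.size();
      float dot = 0.0f;
      for (std::size_t j = 0; j < D; ++j) {
        dot += dY[j] * Y[j];
      }
      dX->resize(D);
      for (std::size_t i = 0; i < D; ++i) {
        (*dX)[i] = Y[i] * (dY[i] - dot);
      }
    }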
diff --git a/caffe2/operators/hip/spatial_batch_norm_op_miopen.cc b/caffe2/operators/hip/spatial_batch_norm_op_miopen.cc
new file mode 100644
index 0000000000..77f35b1334
--- /dev/null
+++ b/caffe2/operators/hip/spatial_batch_norm_op_miopen.cc
@@ -0,0 +1,318 @@
+/**
+ * Copyright (c) 2016-present, Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <cfloat>
+#include "caffe2/core/hip/context_hip.h"
+#include "caffe2/core/hip/miopen_wrapper.h"
+#include "caffe2/operators/spatial_batch_norm_op.h"
+#include "caffe2/utils/math.h"
+
+const double MIOPEN_BN_MIN_EPSILON = 1e-6;
+
+namespace caffe2 {
+
+class MIOpenSpatialBNOp final : public SpatialBNOp<HIPContext> {
+ public:
+  USE_OPERATOR_FUNCTIONS(HIPContext);
+  MIOpenSpatialBNOp(const OperatorDef& operator_def, Workspace* ws)
+      : SpatialBNOp<HIPContext>(operator_def, ws),
+        miopen_wrapper_(&context_),
+        alpha_(OperatorBase::GetSingleArgument<float>("alpha", 1.0)),
+        beta_(OperatorBase::GetSingleArgument<float>("beta", 0.0)),
+        mode_(miopenBNSpatial) {
+    MIOPEN_ENFORCE(miopenCreateTensorDescriptor(&data_desc_));
+    MIOPEN_ENFORCE(miopenCreateTensorDescriptor(&bn_param_desc_));
+    if (epsilon_ <= MIOPEN_BN_MIN_EPSILON) {
+      LOG(ERROR) << "Provided epsilon is smaller than "
+                 << "MIOPEN_BN_MIN_EPSILON. Setting it to "
+                 << "MIOPEN_BN_MIN_EPSILON instead.";
+    }
+    epsilon_ = std::max(epsilon_, MIOPEN_BN_MIN_EPSILON);
+  }
+
+  ~MIOpenSpatialBNOp() {
+    MIOPEN_ENFORCE(miopenDestroyTensorDescriptor(data_desc_));
+    MIOPEN_ENFORCE(miopenDestroyTensorDescriptor(bn_param_desc_));
+  }
+
+  template <typename T, typename M>
+  bool DoRunWithType();
+  bool RunOnDevice() override;
+
+ protected:
+  MIOPENWrapper miopen_wrapper_;
+  miopenTensorDescriptor_t data_desc_;
+  miopenTensorDescriptor_t bn_param_desc_;
+  vector<TIndex> miopen_input_dims_;
+  float alpha_;
+  float beta_;
+  miopenBatchNormMode_t mode_;
+};
+
+class MIOpenSpatialBNGradientOp final : public SpatialBNGradientOp<HIPContext> {
+ public:
+  USE_OPERATOR_FUNCTIONS(HIPContext);
+  MIOpenSpatialBNGradientOp(const OperatorDef& operator_def, Workspace* ws)
+      : SpatialBNGradientOp<HIPContext>(operator_def, ws),
+        miopen_wrapper_(&context_),
+        alpha_(OperatorBase::GetSingleArgument<float>("alpha", 1.0)),
+        beta_(OperatorBase::GetSingleArgument<float>("beta", 0.0)),
+        mode_(miopenBNSpatial) {
+    MIOPEN_ENFORCE(miopenCreateTensorDescriptor(&data_desc_));
+    MIOPEN_ENFORCE(miopenCreateTensorDescriptor(&bn_param_desc_));
+    if (epsilon_ <= MIOPEN_BN_MIN_EPSILON) {
+      LOG(ERROR) << "Provided epsilon is smaller than "
+                 << "MIOPEN_BN_MIN_EPSILON. Setting it to "
+                 << "MIOPEN_BN_MIN_EPSILON instead.";
+    }
+    epsilon_ = std::max(epsilon_, MIOPEN_BN_MIN_EPSILON);
+  }
+
+  ~MIOpenSpatialBNGradientOp() {
+    MIOPEN_ENFORCE(miopenDestroyTensorDescriptor(data_desc_));
+    MIOPEN_ENFORCE(miopenDestroyTensorDescriptor(bn_param_desc_));
+  }
+
+  template <typename T, typename M>
+  bool DoRunWithType();
+
+  bool RunOnDevice() override;
+
+ protected:
+  MIOPENWrapper miopen_wrapper_;
+  miopenTensorDescriptor_t data_desc_;
+  miopenTensorDescriptor_t bn_param_desc_;
+  vector<TIndex> miopen_input_dims_;
+  float alpha_;
+  float beta_;
+  miopenBatchNormMode_t mode_;
+};
+
+////////////////////////////////////////////////////////////////////////////////
+// Implementations
+////////////////////////////////////////////////////////////////////////////////
+
+template <typename T, typename M>
+bool MIOpenSpatialBNOp::DoRunWithType() {
+  // Shorthand for the batch-norm parameter data type.
+  typedef typename miopenTypeWrapper<T>::BNParamType BNParamType;
+
+  auto& X = Input(INPUT);
+  auto& scale = Input(SCALE);
+  auto& bias = Input(BIAS);
+
+  CAFFE_ENFORCE_GE(X.ndim(), 3);
+  const int N = X.dim32(0);
+  const int C = X.dim32(1);
+  const int H = X.dim32(2);
+  const int W = X.ndim() > 3 ? X.dim32(3) : 1;
+  const int D = X.ndim() > 4 ? X.dim32(4) : 1;
+  CAFFE_ENFORCE_EQ(scale.ndim(), 1);
+  CAFFE_ENFORCE_EQ(bias.ndim(), 1);
+  CAFFE_ENFORCE_EQ(scale.dim32(0), C);
+  CAFFE_ENFORCE_EQ(bias.dim32(0), C);
+  // See if we need to reshape.
+  if (X.dims() != miopen_input_dims_) {
+    VLOG(1) << "Setting descriptors.";
+    miopen_input_dims_ = X.dims();
+    // Note: dims/strides are currently unused; the descriptor below is set
+    // as a plain 4-D NCHW descriptor.
+    vector<int> dims = {N, C, H, W, D};
+    vector<int> strides = {C * H * W * D, H * W * D, W * D, D, 1};
+    MIOPEN_ENFORCE(miopenSet4dTensorDescriptor(
+        data_desc_, miopenTypeWrapper<T>::type, N, C, H, W));
+
+    MIOPEN_ENFORCE(
+        miopenDeriveBNTensorDescriptor(bn_param_desc_, data_desc_, mode_));
+  }
+
+  // Now, depending on whether we are running in test mode or not, we have
+  // two paths.
+  if (is_test_) {
+    // Run inference mode.
+    auto& est_mean = Input(EST_MEAN);
+    auto& est_var = Input(EST_VAR);
+    CAFFE_ENFORCE_EQ(est_mean.ndim(), 1);
+    CAFFE_ENFORCE_EQ(est_var.ndim(), 1);
+    CAFFE_ENFORCE_EQ(est_mean.dim32(0), C);
+    CAFFE_ENFORCE_EQ(est_var.dim32(0), C);
+
+    auto* Y = Output(OUTPUT);
+    Y->ResizeLike(X);
+    MIOPEN_ENFORCE(miopenBatchNormalizationForwardInference(
+        miopen_wrapper_.inline_miopen_handle(),
+        // Note: PERSISTENT not implemented for inference
+        mode_,
+        &alpha_,
+        &beta_,
+        data_desc_,
+        X.template data<T>(),
+        data_desc_,
+        Y->template mutable_data<T>(),
+        bn_param_desc_,
+        const_cast<float*>(scale.template data<BNParamType>()),
+        const_cast<float*>(bias.template data<BNParamType>()),
+        const_cast<float*>(est_mean.template data<BNParamType>()),
+        const_cast<float*>(est_var.template data<BNParamType>()),
+        epsilon_));
+  } else {
+    // Run training mode.
+    auto* Y = Output(OUTPUT);
+    Y->ResizeLike(X);
+    // Obtain the running mean and running inv var, and see if we need to
+    // initialize them.
+    auto* running_mean = Output(RUNNING_MEAN);
+    auto* running_var = Output(RUNNING_VAR);
+    double this_factor = 1. - momentum_;
+    BNParamType* running_mean_data = nullptr;
+    BNParamType* running_var_data = nullptr;
+    if (!running_mean->size()) {
+      // If the input mean and var are not initialized yet, this is the first
+      // run and we will initialize the storage.
+      VLOG(1) << "Initializing running mean and var.";
+      running_mean->Resize(C);
+      running_var->Resize(C);
+      running_mean_data = running_mean->template mutable_data<BNParamType>();
+      running_var_data = running_var->template mutable_data<BNParamType>();
+      // In principle, setting this_factor to 1 would wipe any existing data,
+      // but if MIOpen does not handle 0 * NaN correctly we would still
+      // propagate garbage. To be safe, we explicitly zero the buffers.
+      math::Set<BNParamType, HIPContext>(C, 0, running_mean_data, &context_);
+      math::Set<BNParamType, HIPContext>(C, 0, running_var_data, &context_);
+    } else {
+      // Already initialized; just sanity-check the shapes.
+      CAFFE_ENFORCE_EQ(running_mean->ndim(), 1);
+      CAFFE_ENFORCE_EQ(running_var->ndim(), 1);
+      CAFFE_ENFORCE_EQ(running_mean->dim32(0), C);
+      CAFFE_ENFORCE_EQ(running_var->dim32(0), C);
+      running_mean_data = running_mean->template mutable_data<BNParamType>();
+      running_var_data = running_var->template mutable_data<BNParamType>();
+    }
+    // Save the mean and inv var results.
+    auto* save_mean = Output(SAVED_MEAN);
+    auto* save_var = Output(SAVED_INV_VAR);
+    save_mean->Resize(C);
+    save_var->Resize(C);
+    void* save_mean_data = save_mean->template mutable_data<BNParamType>();
+    void* save_var_data = save_var->template mutable_data<BNParamType>();
+
+    MIOPEN_ENFORCE(miopenBatchNormalizationForwardTraining(
+        miopen_wrapper_.inline_miopen_handle(),
+        mode_,
+        &alpha_,
+        &beta_,
+        data_desc_,
+        X.template data<T>(),
+        data_desc_,
+        Y->template mutable_data<T>(),
+        bn_param_desc_,
+        const_cast<float*>(scale.template data<BNParamType>()),
+        const_cast<float*>(bias.template data<BNParamType>()),
+        this_factor,
+        const_cast<float*>(running_mean_data),
+        const_cast<float*>(running_var_data),
+        epsilon_,
+        save_mean_data,
+        save_var_data));
+  }
+  return true;
+}
+bool MIOpenSpatialBNOp::RunOnDevice() {
+  if (Input(0).IsType<float>()) {
+    return DoRunWithType<float, float>();
+  } else {
+    LOG(FATAL) << "Unsupported input types";
+  }
+  return true;
+}
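Conceptually, the training call above performs the following per-channel computation; the sketch below spells it out in plain C++ for a single channel (illustrative only: the exponential-average convention for the running statistics, with factor = this_factor = 1 - momentum_, and the use of plain variance rather than inverse variance are assumptions based on the cuDNN-style interface).

    #include <cmath>
    #include <cstddef>
    #include <vector>

    // One channel of a batch-norm training step: compute batch mean/variance,
    // update the running statistics, then normalize, scale, and shift.
    void BatchNormTrainChannel(
        const std::vector<float>& x,
        float scale,
        float bias,
        float factor,
        double epsilon,
        float* running_mean,
        float* running_var,
        std::vector<float>* y) {
      const float n = static_cast<float>(x.size());
      float mean = 0.0f, var = 0.0f;
      for (float v : x) {
        mean += v;
      }
      mean /= n;
      for (float v : x) {
        var += (v - mean) * (v - mean);
      }
      var /= n;
      // Exponential moving average of the running statistics (assumed form).
      *running_mean = (1.0f - factor) * (*running_mean) + factor * mean;
      *running_var = (1.0f - factor) * (*running_var) + factor * var;
      // Normalize, then apply the learned scale and bias.
      const float inv_std = 1.0f / std::sqrt(var + static_cast<float>(epsilon));
      y->resize(x.size());
      for (std::size_t i = 0; i < x.size(); ++i) {
        (*y)[i] = scale * ((x[i] - mean) * inv_std) + bias;
      }
    }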
+
+template <typename T, typename M>
+bool MIOpenSpatialBNGradientOp::DoRunWithType() {
+  typedef typename miopenTypeWrapper<T>::BNParamType BNParamType;
+
+  auto& X = Input(INPUT);
+  auto& scale = Input(SCALE);
+  auto& dY = Input(OUTPUT_GRAD);
+
+  CAFFE_ENFORCE_GE(X.ndim(), 3);
+  const int N = X.dim32(0);
+  const int C = X.dim32(1);
+  const int H = X.dim32(2);
+  const int W = X.ndim() > 3 ? X.dim32(3) : 1;
+  const int D = X.ndim() > 4 ? X.dim32(4) : 1;
+  CAFFE_ENFORCE_EQ(scale.ndim(), 1);
+  CAFFE_ENFORCE_EQ(scale.dim32(0), C);
+  // See if we need to reshape.
+  if (X.dims() != miopen_input_dims_) {
+    vector<int> dims = {N, C, H, W, D};
+    vector<int> strides = {C * H * W * D, H * W * D, W * D, D, 1};
+    MIOPEN_ENFORCE(miopenSet4dTensorDescriptor(
+        data_desc_, miopenTypeWrapper<T>::type, N, C, H, W));
+
+    MIOPEN_ENFORCE(
+        miopenDeriveBNTensorDescriptor(bn_param_desc_, data_desc_, mode_));
+  }
+
+  auto* dX = Output(INPUT_GRAD);
+  auto* dScale = Output(SCALE_GRAD);
+  auto* dBias = Output(BIAS_GRAD);
+  dX->ResizeLike(X);
+  dScale->ResizeLike(scale);
+  dBias->ResizeLike(scale);
+
+  const auto& saved_mean = Input(SAVED_MEAN);
+  const auto& saved_var = Input(SAVED_INV_VAR);
+  const void* saved_mean_data = saved_mean.template data<BNParamType>();
+  const void* saved_var_data = saved_var.template data<BNParamType>();
+
+  MIOPEN_ENFORCE(miopenBatchNormalizationBackward(
+      miopen_wrapper_.inline_miopen_handle(),
+      mode_,
+      &alpha_,
+      &beta_,
+      &alpha_,
+      &beta_,
+      data_desc_,
+      X.template data<T>(),
+      data_desc_,
+      dY.template data<T>(),
+      data_desc_,
+      dX->template mutable_data<T>(),
+      bn_param_desc_,
+      scale.template data<BNParamType>(),
+      dScale->template mutable_data<BNParamType>(),
+      dBias->template mutable_data<BNParamType>(),
+      epsilon_,
+      saved_mean_data,
+      saved_var_data));
+  return true;
+}
+bool MIOpenSpatialBNGradientOp::RunOnDevice() {
+  if (Input(0).IsType<float>()) {
+    return DoRunWithType<float, float>();
+  } else {
+    LOG(FATAL) << "Unsupported input types";
+  }
+  return true;
+}
+
+// Since there is no default implementation for spatial batch normalization,
+// we will register the miopen version as the default as well.
+REGISTER_HIP_OPERATOR(SpatialBN, MIOpenSpatialBNOp);
+REGISTER_HIP_OPERATOR(SpatialBNGradient, MIOpenSpatialBNGradientOp);
+
+REGISTER_MIOPEN_OPERATOR(SpatialBN, MIOpenSpatialBNOp);
+REGISTER_MIOPEN_OPERATOR(SpatialBNGradient, MIOpenSpatialBNGradientOp);
+} // namespace caffe2
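For reference, the scale and bias gradients produced by miopenBatchNormalizationBackward reduce to simple per-channel sums. The sketch below spells them out for a single channel (again an assumption based on the cuDNN-style semantics, with SAVED_MEAN / SAVED_INV_VAR holding the batch mean and inverse standard deviation from the forward pass); the dX formula is omitted.

    #include <cstddef>
    #include <vector>

    // Per-channel parameter gradients for batch norm:
    //   dBias  = sum_i dY[i]
    //   dScale = sum_i dY[i] * (X[i] - saved_mean) * saved_inv_std
    void BatchNormParamGradChannel(
        const std::vector<float>& x,
        const std::vector<float>& dy,
        float saved_mean,
        float saved_inv_std,
        float* d_scale,
        float* d_bias) {
      *d_scale = 0.0f;
      *d_bias = 0.0f;
      for (std::size_t i = 0; i < x.size(); ++i) {
        *d_bias += dy[i];
        *d_scale += dy[i] * (x[i] - saved_mean) * saved_inv_std;
      }
    }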