author | rohithkrn <rohith.nallamaddi@gmail.com> | 2018-12-14 16:31:34 -0800
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2018-12-14 16:33:45 -0800
commit | 763b9954f3fbe0b2058eee9fe8055dfdc58ee615 (patch)
tree | 44dd4ac008bcf9beb956ab286df615bccccabd66
parent | e596d2313761ce0c93f2f6d379344aa22620b06b (diff)
FP16MomentumSGDUpdate Op fix and enable for ROCm (#15150)
Summary:
1. Fix a bug in the FP16MomentumSGDUpdate operator
2. Enable the operator for ROCm
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15150
Differential Revision: D13473145
Pulled By: bddppq
fbshipit-source-id: 4c5c5f30cb9bba658e3639dbe193fa08a304d306
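For context, FP16MomentumSGDUpdate applies a momentum SGD step to half-precision parameters. Below is a minimal NumPy sketch of the plain and Nesterov momentum updates along the lines of the test's `momentum_sgd` reference function; the exact reference body is not shown in this diff, so treat this as the textbook formulation rather than a verbatim copy, and note that it does not model the operator's FP16 rounding. Function and variable names are illustrative.

```python
import numpy as np

def momentum_sgd_reference(grad, momentum_buf, lr, momentum=0.9, nesterov=False):
    """One momentum-SGD step; returns (adjusted_grad, new_momentum_buf).

    The caller applies the update as `param -= adjusted_grad`, matching the
    usual Caffe2 momentum convention (assumed here); the FP16 operator does
    the same arithmetic in half precision (half2 pairs on GPU).
    """
    if not nesterov:
        # Plain momentum: fold the previous momentum into the scaled gradient.
        adjusted = lr * grad + momentum * momentum_buf
        return adjusted, adjusted
    # Nesterov momentum: step along the look-ahead direction.
    m_new = momentum * momentum_buf + lr * grad
    adjusted = (1.0 + momentum) * m_new - momentum * momentum_buf
    return adjusted, m_new

# Example: run the FP32 reference on FP16 inputs, as a reference check would.
g = np.random.rand(8).astype(np.float16)
m = np.random.rand(8).astype(np.float16)
step, m_out = momentum_sgd_reference(g.astype(np.float32), m.astype(np.float32), lr=0.01)
```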
-rw-r--r-- | caffe2/python/operator_test/momentum_sgd_test.py | 8
-rw-r--r-- | caffe2/sgd/fp16_momentum_sgd_op.cu | 6
2 files changed, 8 insertions, 6 deletions
diff --git a/caffe2/python/operator_test/momentum_sgd_test.py b/caffe2/python/operator_test/momentum_sgd_test.py
index 27dcb78c14..bcd0631951 100644
--- a/caffe2/python/operator_test/momentum_sgd_test.py
+++ b/caffe2/python/operator_test/momentum_sgd_test.py
@@ -3,6 +3,7 @@ from __future__ import division
 from __future__ import print_function
 from __future__ import unicode_literals
 
+from caffe2.proto import caffe2_pb2
 from caffe2.python import core, workspace
 import caffe2.python.hypothesis_test_util as hu
 import caffe2.python.serialized_test.serialized_test_util as serial
@@ -143,7 +144,7 @@ class TestMomentumSGD(serial.SerializedTestCase):
     def test_fp16momentum_sgd(self, n, nesterov, gc, dc):
         assume(core.IsGPUDeviceType(gc.device_type))
         gpuvers = workspace.GetDeviceProperties(0)["major"]
-        if gpuvers < 6:
+        if gc.device_type == caffe2_pb2.CUDA and gpuvers < 6:
             print("No FP16 support because major version {} < 6".format(gpuvers))
             return
 
@@ -152,7 +153,6 @@ class TestMomentumSGD(serial.SerializedTestCase):
         lr = np.random.rand(1).astype(np.float32)
         param_momentum = np.random.rand(n).astype(np.float16)
         momentum = 0.9
-        nesterov = True
 
         def momentum_sgd(grad, param_momentum, lr, param=None):
             if not nesterov:
@@ -174,11 +174,13 @@ class TestMomentumSGD(serial.SerializedTestCase):
             weight_decay=0.0,
         )
 
+        threshold = 1e-3 if (gc.device_type == caffe2_pb2.HIP) else 1e-4
         self.assertReferenceChecks(
             device_option=gc,
             op=op,
             inputs=[grad, param_momentum, lr, param],
-            reference=momentum_sgd
+            reference=momentum_sgd,
+            threshold=threshold
         )
diff --git a/caffe2/sgd/fp16_momentum_sgd_op.cu b/caffe2/sgd/fp16_momentum_sgd_op.cu
index b7ac0a7b76..8ec1c85fd5 100644
--- a/caffe2/sgd/fp16_momentum_sgd_op.cu
+++ b/caffe2/sgd/fp16_momentum_sgd_op.cu
@@ -22,7 +22,7 @@ __global__ void FP16MomentumSGDKernel(
     bool nesterov,
     const float wd,
     half2* param) {
-#if __CUDA_ARCH__ >= 530
+#if __CUDA_ARCH__ >= 530 || defined(__HIP_PLATFORM_HCC__)
   const float lr2 = lr[0];
   const half2 LR = __float2half2_rn(lr2);
   const half2 momentum = __float2half2_rn(mom);
@@ -87,7 +87,7 @@ __global__ void FP16MomentumSGDKernel(
         __hfma(mi_new_half, __high2half(momentum), mi_new_half),
         mom_mi_half);
     if (param) {
-      param_half[N - 1] = __hsub(param_half[i], ng_half[N - 1]);
+      param_half[N - 1] = __hsub(param_half[N - 1], ng_half[N - 1]);
     }
   }
 }
@@ -109,7 +109,7 @@ __global__ void FP16MomentumSGDFP32Kernel(
     bool nesterov,
     const float wd,
     half2* param) {
-#if __CUDA_ARCH__ >= 530
+#if __CUDA_ARCH__ >= 530 || defined(__HIP_PLATFORM_HCC__)
   const float lr2 = lr[0];
   const float LR = lr2;
   const float momentum = mom;
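The test-side changes gate the compute-capability check on CUDA only, so ROCm devices are no longer skipped, and loosen the reference-check tolerance under HIP. A rough Python sketch of that pattern, using the `caffe2_pb2` device-type constants and the `workspace.GetDeviceProperties` call that appear in the test (the helper names below are illustrative, not part of the Caffe2 API):

```python
from caffe2.proto import caffe2_pb2
from caffe2.python import workspace

def fp16_supported(device_option):
    """True if the FP16 momentum SGD kernels can run on this device.

    Only CUDA devices are gated on compute capability (major >= 6, as in
    the test); HIP/ROCm devices skip that check once the op is enabled.
    """
    if device_option.device_type != caffe2_pb2.CUDA:
        return True
    return workspace.GetDeviceProperties(0)["major"] >= 6

def reference_threshold(device_option):
    # Half-precision accumulation differs slightly between backends, so
    # the HIP path uses a looser tolerance (values taken from the test).
    return 1e-3 if device_option.device_type == caffe2_pb2.HIP else 1e-4
```

On the kernel side, the `#if __CUDA_ARCH__ >= 530` guard now also compiles under `__HIP_PLATFORM_HCC__`, and the odd-length tail element of the half2-packed parameter array is updated in place at `param_half[N - 1]` rather than being read from the loop's half2 index.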