summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorrohithkrn <rohith.nallamaddi@gmail.com>2018-12-14 16:31:34 -0800
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2018-12-14 16:33:45 -0800
commit763b9954f3fbe0b2058eee9fe8055dfdc58ee615 (patch)
tree44dd4ac008bcf9beb956ab286df615bccccabd66
parente596d2313761ce0c93f2f6d379344aa22620b06b (diff)
downloadpytorch-763b9954f3fbe0b2058eee9fe8055dfdc58ee615.tar.gz
pytorch-763b9954f3fbe0b2058eee9fe8055dfdc58ee615.tar.bz2
pytorch-763b9954f3fbe0b2058eee9fe8055dfdc58ee615.zip
FP16MomentumSGDUpdate Op fix and enable for ROCm (#15150)
Summary: 1. Fix a bug in FP16MomentumSGDUpdate operator. 2. Enable operator for ROCm. Pull Request resolved: https://github.com/pytorch/pytorch/pull/15150 Differential Revision: D13473145 Pulled By: bddppq fbshipit-source-id: 4c5c5f30cb9bba658e3639dbe193fa08a304d306
-rw-r--r--caffe2/python/operator_test/momentum_sgd_test.py8
-rw-r--r--caffe2/sgd/fp16_momentum_sgd_op.cu6
2 files changed, 8 insertions, 6 deletions
diff --git a/caffe2/python/operator_test/momentum_sgd_test.py b/caffe2/python/operator_test/momentum_sgd_test.py
index 27dcb78c14..bcd0631951 100644
--- a/caffe2/python/operator_test/momentum_sgd_test.py
+++ b/caffe2/python/operator_test/momentum_sgd_test.py
@@ -3,6 +3,7 @@ from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
+from caffe2.proto import caffe2_pb2
from caffe2.python import core, workspace
import caffe2.python.hypothesis_test_util as hu
import caffe2.python.serialized_test.serialized_test_util as serial
@@ -143,7 +144,7 @@ class TestMomentumSGD(serial.SerializedTestCase):
def test_fp16momentum_sgd(self, n, nesterov, gc, dc):
assume(core.IsGPUDeviceType(gc.device_type))
gpuvers = workspace.GetDeviceProperties(0)["major"]
- if gpuvers < 6:
+ if gc.device_type == caffe2_pb2.CUDA and gpuvers < 6:
print("No FP16 support because major version {} < 6".format(gpuvers))
return
@@ -152,7 +153,6 @@ class TestMomentumSGD(serial.SerializedTestCase):
lr = np.random.rand(1).astype(np.float32)
param_momentum = np.random.rand(n).astype(np.float16)
momentum = 0.9
- nesterov = True
def momentum_sgd(grad, param_momentum, lr, param=None):
if not nesterov:
@@ -174,11 +174,13 @@ class TestMomentumSGD(serial.SerializedTestCase):
weight_decay=0.0,
)
+ threshold = 1e-3 if (gc.device_type == caffe2_pb2.HIP) else 1e-4
self.assertReferenceChecks(
device_option=gc,
op=op,
inputs=[grad, param_momentum, lr, param],
- reference=momentum_sgd
+ reference=momentum_sgd,
+ threshold=threshold
)
diff --git a/caffe2/sgd/fp16_momentum_sgd_op.cu b/caffe2/sgd/fp16_momentum_sgd_op.cu
index b7ac0a7b76..8ec1c85fd5 100644
--- a/caffe2/sgd/fp16_momentum_sgd_op.cu
+++ b/caffe2/sgd/fp16_momentum_sgd_op.cu
@@ -22,7 +22,7 @@ __global__ void FP16MomentumSGDKernel(
bool nesterov,
const float wd,
half2* param) {
-#if __CUDA_ARCH__ >= 530
+#if __CUDA_ARCH__ >= 530 || defined(__HIP_PLATFORM_HCC__)
const float lr2 = lr[0];
const half2 LR = __float2half2_rn(lr2);
const half2 momentum = __float2half2_rn(mom);
@@ -87,7 +87,7 @@ __global__ void FP16MomentumSGDKernel(
__hfma(mi_new_half, __high2half(momentum), mi_new_half),
mom_mi_half);
if (param) {
- param_half[N - 1] = __hsub(param_half[i], ng_half[N - 1]);
+ param_half[N - 1] = __hsub(param_half[N - 1], ng_half[N - 1]);
}
}
}
@@ -109,7 +109,7 @@ __global__ void FP16MomentumSGDFP32Kernel(
bool nesterov,
const float wd,
half2* param) {
-#if __CUDA_ARCH__ >= 530
+#if __CUDA_ARCH__ >= 530 || defined(__HIP_PLATFORM_HCC__)
const float lr2 = lr[0];
const float LR = lr2;
const float momentum = mom;