author     Chenguang Xi <cxi@fb.com>   2018-08-31 00:54:05 -0700
committer  Facebook Github Bot <facebook-github-bot@users.noreply.github.com>   2018-08-31 00:55:39 -0700
commit     0555768e0fb4fdbe69f6687c4f6b816b850f1077 (patch)
tree       c379a7a378ff2e343469eb5dcca9ade7f4c840e0 /caffe2/sgd
parent     f1bfe6750f7d52c35a7fe1b9acdbcbca94b7806e (diff)
Support lr adaption for SparseAdam and RowWiseSparseAdam (#10993)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/10993
Add an optional fourth output, output_grad (the effective gradient), to the SparseAdam and RowWiseSparseAdam operators (NumOutputs(3, 4)), matching what the dense Adam operator already exposes, so that learning rate adaptation can consume the effective gradient.
Reviewed By: chocjy
Differential Revision: D9554375
fbshipit-source-id: b88768f470ef7d023dd481c6a97b91594892f422
Diffstat (limited to 'caffe2/sgd')
-rw-r--r--  caffe2/sgd/adam_op.cc    8
-rw-r--r--  caffe2/sgd/adam_op.h   323
2 files changed, 229 insertions, 102 deletions
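
The heart of the change is in adam_op.h: each sparse Adam operator now checks OutputSize() and, when a fourth output is bound, also writes the effective gradient it applied, so a later learning-rate-adaptation step can consume it. The standalone C++ sketch below restates that per-element arithmetic for the block_size == 1 case; it is an illustration, not the Caffe2 code. The function name and signature are hypothetical, and "correction" is assumed to be the Adam bias-correction factor the operator computes before this loop (outside the lines shown in this diff).

#include <cmath>
#include <cstddef>

// Illustrative sketch only (not the Caffe2 source): the per-element update of
// the new four-output path for block_size == 1. Besides the usual Adam moment
// updates, the effective gradient ngi is stored so that lr adaptation can use it.
void sparse_adam_with_output_grad(
    std::size_t n,
    float* param,
    float* m1,
    float* m2,
    float* grad_out,   // corresponds to the optional output_grad
    const float* grad,
    float lr,
    float beta1,
    float beta2,
    float epsilon,
    float correction) {  // assumed: Adam bias-correction factor
  for (std::size_t i = 0; i < n; ++i) {
    float gi = grad[i];
    float mi = m1[i] = m1[i] * beta1 + gi * (1 - beta1);
    float vi = m2[i] = m2[i] * beta2 + gi * gi * (1 - beta2);
    float ngi = grad_out[i] = correction * mi / (std::sqrt(vi) + epsilon);
    // Mirrors the diff: the parameter update adds lr * ngi.
    param[i] = param[i] + lr * ngi;
  }
}
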
diff --git a/caffe2/sgd/adam_op.cc b/caffe2/sgd/adam_op.cc
index 25414622ba..623e93a07e 100644
--- a/caffe2/sgd/adam_op.cc
+++ b/caffe2/sgd/adam_op.cc
@@ -34,7 +34,7 @@ and returns (param_o, m1_o, m2_o, grad_o), in which grad_o is an optional output
     .Output(0, "output_param", "Updated parameters")
     .Output(1, "output_moment_1", "Updated first moment")
     .Output(2, "output_moment_2", "Updated second moment")
-    .Output(3, "output_grad", "Effective grad")
+    .Output(3, "output_grad", "Optional Effective gradient")
     .Arg("beta1", "Default 0.9")
     .Arg("beta2", "Default 0.999")
     .Arg("epsilon", "Default 1e-5");
@@ -42,7 +42,7 @@ and returns (param_o, m1_o, m2_o, grad_o), in which grad_o is an optional output
 REGISTER_CPU_OPERATOR(SparseAdam, SparseAdamOp<float, CPUContext>);
 OPERATOR_SCHEMA(SparseAdam)
     .NumInputs(7)
-    .NumOutputs(3)
+    .NumOutputs(3, 4)
     .EnforceInplace({{0, 0}, {1, 1}, {2, 2}})
     .SetDoc(R"DOC(
@@ -62,6 +62,7 @@ OPERATOR_SCHEMA(SparseAdam)
     .Output(0, "output_param", "Updated parameters")
     .Output(1, "output_moment_1", "Updated first moment")
     .Output(2, "output_moment_2", "Updated second moment")
+    .Output(3, "output_grad", "Optional Effective gradient")
     .Arg("beta1", "Default 0.9")
     .Arg("beta2", "Default 0.999")
     .Arg("epsilon", "Default 1e-5");
@@ -71,7 +72,7 @@ REGISTER_CPU_OPERATOR(
     RowWiseSparseAdamOp<float, CPUContext>);
 OPERATOR_SCHEMA(RowWiseSparseAdam)
     .NumInputs(7)
-    .NumOutputs(3)
+    .NumOutputs(3, 4)
     .EnforceInplace({{0, 0}, {1, 1}, {2, 2}})
     .SetDoc(R"DOC(
@@ -95,6 +96,7 @@ OPERATOR_SCHEMA(RowWiseSparseAdam)
     .Output(0, "output_param", "Updated parameters")
     .Output(1, "output_moment_1", "Updated first moment")
     .Output(2, "output_moment_2", "Updated second moment")
+    .Output(3, "output_grad", "Optional Effective gradient")
     .Arg("beta1", "Default 0.9")
     .Arg("beta2", "Default 0.999")
     .Arg("epsilon", "Default 1e-5");
diff --git a/caffe2/sgd/adam_op.h b/caffe2/sgd/adam_op.h
index dadf7f4ee2..3f7bb95556 100644
--- a/caffe2/sgd/adam_op.h
+++ b/caffe2/sgd/adam_op.h
@@ -195,58 +195,118 @@ class SparseAdamOp final : public Operator<Context> {
     auto* moment1Out = Output(OUTPUT_MOMENT_1)->template mutable_data<T>();
     auto* moment2Out = Output(OUTPUT_MOMENT_2)->template mutable_data<T>();
 
-    for (auto i = 0; i < n; ++i) {
-      auto idx = indices[i];
-
-      if (block_size == 1) {
-        float gi = gradIn[i];
-        float mi = moment1Out[idx] =
-            moment1In[idx] * beta1_ + gi * (1 - beta1_);
-        float vi = moment2Out[idx] =
-            moment2In[idx] * beta2_ + gi * gi * (1 - beta2_);
-        paramOut[idx] =
-            paramIn[idx] + lr[0] * correction * mi / (std::sqrt(vi) + epsilon_);
-
-      } else {
-        auto offsetI = i * block_size;
-        auto offsetIdx = idx * block_size;
+    if (OutputSize() == 3) {
+      for (auto i = 0; i < n; ++i) {
+        auto idx = indices[i];
+
+        if (block_size == 1) {
+          float gi = gradIn[i];
+          float mi = moment1Out[idx] =
+              moment1In[idx] * beta1_ + gi * (1 - beta1_);
+          float vi = moment2Out[idx] =
+              moment2In[idx] * beta2_ + gi * gi * (1 - beta2_);
+          paramOut[idx] = paramIn[idx] +
+              lr[0] * correction * mi / (std::sqrt(vi) + epsilon_);
+
+        } else {
+          auto offsetI = i * block_size;
+          auto offsetIdx = idx * block_size;
+
+#ifndef NDEBUG
+          CAFFE_ENFORCE_GE(
+              Input(PARAM).size(),
+              block_size + offsetIdx,
+              this->debug_def().input(PARAM),
+              ", out of bound, idx:",
+              idx,
+              " for input i:",
+              i,
+              " and block size:",
+              block_size);
+          CAFFE_ENFORCE_GE(
+              Input(GRAD).size(),
+              block_size + offsetI,
+              this->debug_def().input(GRAD),
+              ", out of bound idx, idx:",
+              idx,
+              " for input i:",
+              i);
+#endif
+
+          adam_compute(
+              block_size,
+              paramIn + offsetIdx,
+              gradIn + offsetI,
+              moment1In + offsetIdx,
+              moment2In + offsetIdx,
+              paramOut + offsetIdx,
+              moment1Out + offsetIdx,
+              moment2Out + offsetIdx,
+              beta1_,
+              beta2_,
+              epsilon_,
+              correction,
+              lr,
+              &context_);
+        }
+      }
+    } else {
+      Output(OUTPUT_GRAD)->ResizeLike(Input(GRAD));
+      auto* gradOut = Output(OUTPUT_GRAD)->template mutable_data<T>();
+      for (auto i = 0; i < n; ++i) {
+        auto idx = indices[i];
+
+        if (block_size == 1) {
+          float gi = gradIn[i];
+          float mi = moment1Out[idx] =
+              moment1In[idx] * beta1_ + gi * (1 - beta1_);
+          float vi = moment2Out[idx] =
+              moment2In[idx] * beta2_ + gi * gi * (1 - beta2_);
+          float ngi = gradOut[i] = correction * mi / (std::sqrt(vi) + epsilon_);
+          paramOut[idx] = paramIn[idx] + lr[0] * ngi;
+
+        } else {
+          auto offsetI = i * block_size;
+          auto offsetIdx = idx * block_size;
 
 #ifndef NDEBUG
-        CAFFE_ENFORCE_GE(
-            Input(PARAM).size(),
-            block_size + offsetIdx,
-            this->debug_def().input(PARAM),
-            ", out of bound, idx:",
-            idx,
-            " for input i:",
-            i,
-            " and block size:",
-            block_size);
-        CAFFE_ENFORCE_GE(
-            Input(GRAD).size(),
-            block_size + offsetI,
-            this->debug_def().input(GRAD),
-            ", out of bound idx, idx:",
-            idx,
-            " for input i:",
-            i);
+          CAFFE_ENFORCE_GE(
+              Input(PARAM).size(),
+              block_size + offsetIdx,
+              this->debug_def().input(PARAM),
+              ", out of bound, idx:",
+              idx,
+              " for input i:",
+              i,
+              " and block size:",
+              block_size);
+          CAFFE_ENFORCE_GE(
+              Input(GRAD).size(),
+              block_size + offsetI,
+              this->debug_def().input(GRAD),
+              ", out of bound idx, idx:",
+              idx,
+              " for input i:",
+              i);
#endif
 
-        adam_compute(
-            block_size,
-            paramIn + offsetIdx,
-            gradIn + offsetI,
-            moment1In + offsetIdx,
-            moment2In + offsetIdx,
-            paramOut + offsetIdx,
-            moment1Out + offsetIdx,
-            moment2Out + offsetIdx,
-            beta1_,
-            beta2_,
-            epsilon_,
-            correction,
-            lr,
-            &context_);
+          adam_compute_output_grad(
+              block_size,
+              paramIn + offsetIdx,
+              gradIn + offsetI,
+              moment1In + offsetIdx,
+              moment2In + offsetIdx,
+              paramOut + offsetIdx,
+              moment1Out + offsetIdx,
+              moment2Out + offsetIdx,
+              gradOut + offsetI,
+              beta1_,
+              beta2_,
+              epsilon_,
+              correction,
+              lr,
+              &context_);
+        }
       }
     }
     return true;
@@ -257,7 +317,7 @@ class SparseAdamOp final : public Operator<Context> {
   T beta2_;
   T epsilon_;
   INPUT_TAGS(PARAM, MOMENT_1, MOMENT_2, INDICES, GRAD, LR, ITER);
-  OUTPUT_TAGS(OUTPUT_PARAM, OUTPUT_MOMENT_1, OUTPUT_MOMENT_2);
+  OUTPUT_TAGS(OUTPUT_PARAM, OUTPUT_MOMENT_1, OUTPUT_MOMENT_2, OUTPUT_GRAD);
 };
 
 template <typename T, class Context>
@@ -305,61 +365,126 @@ class RowWiseSparseAdamOp final : public Operator<Context> {
     auto* moment1Out = Output(OUTPUT_MOMENT_1)->template mutable_data<T>();
     auto* moment2Out = Output(OUTPUT_MOMENT_2)->template mutable_data<T>();
 
-    for (auto i = 0; i < n; ++i) {
-      auto idx = indices[i];
+    if (OutputSize() == 3) {
+      for (auto i = 0; i < n; ++i) {
+        auto idx = indices[i];
+
+        if (block_size == 1) {
+          float gi = gradIn[i];
+          float mi = moment1Out[idx] =
+              moment1In[idx] * beta1_ + gi * (1 - beta1_);
+          float vi = moment2Out[idx] =
+              moment2In[idx] * beta2_ + gi * gi * (1 - beta2_);
+          paramOut[idx] = paramIn[idx] +
+              lr[0] * correction * mi / (std::sqrt(vi) + epsilon_);
+
+        } else {
+          auto offsetI = i * block_size;
+          auto offsetIdx = idx * block_size;
 
-      if (block_size == 1) {
-        float gi = gradIn[i];
-        float mi = moment1Out[idx] =
-            moment1In[idx] * beta1_ + gi * (1 - beta1_);
-        float vi = moment2Out[idx] =
-            moment2In[idx] * beta2_ + gi * gi * (1 - beta2_);
-        paramOut[idx] =
-            paramIn[idx] + lr[0] * correction * mi / (std::sqrt(vi) + epsilon_);
+#ifndef NDEBUG
+          CAFFE_ENFORCE_GE(
+              Input(PARAM).size(),
+              block_size + offsetIdx,
+              this->debug_def().input(PARAM),
+              ", out of bound, idx:",
+              idx,
+              " for input i:",
+              i,
+              " and block size:",
+              block_size);
+          CAFFE_ENFORCE_GE(
+              Input(GRAD).size(),
+              block_size + offsetI,
+              this->debug_def().input(GRAD),
+              ", out of bound idx, idx:",
+              idx,
+              " for input i:",
+              i);
+#endif
 
-      } else {
-        auto offsetI = i * block_size;
-        auto offsetIdx = idx * block_size;
+          const float* w = paramIn + offsetIdx;
+          const float* g = gradIn + offsetI;
+          const float* m1 = moment1In + offsetIdx;
+          const float* m2 = moment2In + idx;
+          float* nw = paramOut + offsetIdx;
+          float* nm1 = moment1Out + offsetIdx;
+          float* nm2 = moment2Out + idx;
+
+          float m2_sum = 0.;
+          for (auto j = 0; j < block_size; ++j) {
+            float gj = g[j];
+            m2_sum += gj * gj;
+          }
+          float vi = nm2[0] =
+              m2[0] * beta2_ + (m2_sum / block_size) * (1 - beta2_);
+          for (auto j = 0; j < block_size; ++j) {
+            float mi = nm1[j] = m1[j] * beta1_ + g[j] * (1 - beta1_);
+            nw[j] = w[j] + lr[0] * correction * mi / (std::sqrt(vi) + epsilon_);
+          }
+        }
+      }
+    } else {
+      Output(OUTPUT_GRAD)->ResizeLike(Input(GRAD));
+      auto* gradOut = Output(OUTPUT_GRAD)->template mutable_data<T>();
+      for (auto i = 0; i < n; ++i) {
+        auto idx = indices[i];
+
+        if (block_size == 1) {
+          float gi = gradIn[i];
+          float mi = moment1Out[idx] =
+              moment1In[idx] * beta1_ + gi * (1 - beta1_);
+          float vi = moment2Out[idx] =
+              moment2In[idx] * beta2_ + gi * gi * (1 - beta2_);
+          float ngi = gradOut[i] = correction * mi / (std::sqrt(vi) + epsilon_);
+          paramOut[idx] = paramIn[idx] + lr[0] * ngi;
+
+        } else {
+          auto offsetI = i * block_size;
+          auto offsetIdx = idx * block_size;
 
 #ifndef NDEBUG
-        CAFFE_ENFORCE_GE(
-            Input(PARAM).size(),
-            block_size + offsetIdx,
-            this->debug_def().input(PARAM),
-            ", out of bound, idx:",
-            idx,
-            " for input i:",
-            i,
-            " and block size:",
-            block_size);
-        CAFFE_ENFORCE_GE(
-            Input(GRAD).size(),
-            block_size + offsetI,
-            this->debug_def().input(GRAD),
-            ", out of bound idx, idx:",
-            idx,
-            " for input i:",
-            i);
+          CAFFE_ENFORCE_GE(
+              Input(PARAM).size(),
+              block_size + offsetIdx,
+              this->debug_def().input(PARAM),
+              ", out of bound, idx:",
+              idx,
+              " for input i:",
+              i,
+              " and block size:",
+              block_size);
+          CAFFE_ENFORCE_GE(
+              Input(GRAD).size(),
+              block_size + offsetI,
+              this->debug_def().input(GRAD),
+              ", out of bound idx, idx:",
+              idx,
+              " for input i:",
+              i);
 #endif
 
-        const float* w = paramIn + offsetIdx;
-        const float* g = gradIn + offsetI;
-        const float* m1 = moment1In + offsetIdx;
-        const float* m2 = moment2In + idx;
-        float* nw = paramOut + offsetIdx;
-        float* nm1 = moment1Out + offsetIdx;
-        float* nm2 = moment2Out + idx;
-
-        float m2_sum = 0.;
-        for (auto j = 0; j < block_size; ++j) {
-          float gj = g[j];
-          m2_sum += gj * gj;
-        }
-        float vi = nm2[0] =
-            m2[0] * beta2_ + (m2_sum / block_size) * (1 - beta2_);
-        for (auto j = 0; j < block_size; ++j) {
-          float mi = nm1[j] = m1[j] * beta1_ + g[j] * (1 - beta1_);
-          nw[j] = w[j] + lr[0] * correction * mi / (std::sqrt(vi) + epsilon_);
+          const float* w = paramIn + offsetIdx;
+          const float* g = gradIn + offsetI;
+          const float* m1 = moment1In + offsetIdx;
+          const float* m2 = moment2In + idx;
+          float* nw = paramOut + offsetIdx;
+          float* nm1 = moment1Out + offsetIdx;
+          float* nm2 = moment2Out + idx;
+          float* ng = gradOut + offsetI;
+
+          float m2_sum = 0.;
+          for (auto j = 0; j < block_size; ++j) {
+            float gj = g[j];
+            m2_sum += gj * gj;
+          }
+          float vi = nm2[0] =
+              m2[0] * beta2_ + (m2_sum / block_size) * (1 - beta2_);
+          for (auto j = 0; j < block_size; ++j) {
+            float mi = nm1[j] = m1[j] * beta1_ + g[j] * (1 - beta1_);
+            float ngi = ng[j] = correction * mi / (std::sqrt(vi) + epsilon_);
+            nw[j] = w[j] + lr[0] * ngi;
+          }
         }
       }
     }
@@ -371,7 +496,7 @@ class RowWiseSparseAdamOp final : public Operator<Context> {
   T beta2_;
   T epsilon_;
   INPUT_TAGS(PARAM, MOMENT_1, MOMENT_2, INDICES, GRAD, LR, ITER);
-  OUTPUT_TAGS(OUTPUT_PARAM, OUTPUT_MOMENT_1, OUTPUT_MOMENT_2);
+  OUTPUT_TAGS(OUTPUT_PARAM, OUTPUT_MOMENT_1, OUTPUT_MOMENT_2, OUTPUT_GRAD);
 };
 
 } // namespace caffe2
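
RowWiseSparseAdam differs only in how it tracks the second moment: as the new code path above shows, each parameter row shares a single second-moment scalar, updated from the mean of the squared gradients in that row, while the first moment and the effective gradient written to output_grad stay element-wise. Below is a hedged, standalone restatement of that per-row arithmetic; the helper name and signature are hypothetical, and operator plumbing (index handling, bounds checks, output resizing) is omitted.

#include <cmath>
#include <cstddef>

// Illustrative sketch only (not the Caffe2 source): one row of the
// RowWiseSparseAdam update with the optional effective-gradient output.
void rowwise_adam_row_with_output_grad(
    std::size_t block_size,
    float* w,           // one parameter row
    float* m1,          // first moment, element-wise
    float* m2,          // second moment, a single scalar for the row
    float* ng,          // effective gradient (output_grad) for the row
    const float* g,     // gradient for the row
    float lr,
    float beta1,
    float beta2,
    float epsilon,
    float correction) {  // assumed: Adam bias-correction factor
  // The shared second moment is driven by the mean squared gradient of the row.
  float m2_sum = 0.f;
  for (std::size_t j = 0; j < block_size; ++j) {
    m2_sum += g[j] * g[j];
  }
  float vi = *m2 = (*m2) * beta2 + (m2_sum / block_size) * (1 - beta2);

  // First moment, effective gradient, and parameter update stay element-wise.
  for (std::size_t j = 0; j < block_size; ++j) {
    float mi = m1[j] = m1[j] * beta1 + g[j] * (1 - beta1);
    float ngi = ng[j] = correction * mi / (std::sqrt(vi) + epsilon);
    w[j] = w[j] + lr * ngi;
  }
}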