author    Chenguang Xi <cxi@fb.com>  2018-08-31 00:54:05 -0700
committer Facebook Github Bot <facebook-github-bot@users.noreply.github.com>  2018-08-31 00:55:39 -0700
commit    0555768e0fb4fdbe69f6687c4f6b816b850f1077 (patch)
tree      c379a7a378ff2e343469eb5dcca9ade7f4c840e0 /caffe2/sgd
parent    f1bfe6750f7d52c35a7fe1b9acdbcbca94b7806e (diff)
Support lr adaption for SparseAdam and RowWiseSparseAdam (#10993)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/10993
as title
Reviewed By: chocjy
Differential Revision: D9554375
fbshipit-source-id: b88768f470ef7d023dd481c6a97b91594892f422
Diffstat (limited to 'caffe2/sgd')
-rw-r--r--  caffe2/sgd/adam_op.cc |   8
-rw-r--r--  caffe2/sgd/adam_op.h  | 323
2 files changed, 229 insertions, 102 deletions
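
Before the diff itself: this change lets SparseAdam and RowWiseSparseAdam optionally emit a fourth output, the effective gradient (the bias-corrected Adam step before the learning rate is applied), so that a learning-rate-adaption step downstream can consume it. A minimal scalar sketch of the update, with illustrative names only (not code from this commit):

#include <cmath>

// Scalar sketch of one Adam step that also reports the effective gradient.
// correction is the bias-correction factor computed by the operator,
// roughly sqrt(1 - beta2^t) / (1 - beta1^t) at iteration t.
inline void adam_step_with_effective_grad(
    float& param, float& m, float& v, float grad,
    float lr, float beta1, float beta2, float eps, float correction,
    float* effective_grad /* optional fourth output */) {
  m = m * beta1 + grad * (1 - beta1);                // first moment
  v = v * beta2 + grad * grad * (1 - beta2);         // second moment
  float ng = correction * m / (std::sqrt(v) + eps);  // effective gradient
  if (effective_grad) {
    *effective_grad = ng;
  }
  param = param + lr * ng;  // caffe2 adds; lr carries the sign of the step
}
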
diff --git a/caffe2/sgd/adam_op.cc b/caffe2/sgd/adam_op.cc
index 25414622ba..623e93a07e 100644
--- a/caffe2/sgd/adam_op.cc
+++ b/caffe2/sgd/adam_op.cc
@@ -34,7 +34,7 @@ and returns (param_o, m1_o, m2_o, grad_o), in which grad_o is an optional output
.Output(0, "output_param", "Updated parameters")
.Output(1, "output_moment_1", "Updated first moment")
.Output(2, "output_moment_2", "Updated second moment")
- .Output(3, "output_grad", "Effective grad")
+ .Output(3, "output_grad", "Optional Effective gradient")
.Arg("beta1", "Default 0.9")
.Arg("beta2", "Default 0.999")
.Arg("epsilon", "Default 1e-5");
@@ -42,7 +42,7 @@ and returns (param_o, m1_o, m2_o, grad_o), in which grad_o is an optional output
REGISTER_CPU_OPERATOR(SparseAdam, SparseAdamOp<float, CPUContext>);
OPERATOR_SCHEMA(SparseAdam)
.NumInputs(7)
- .NumOutputs(3)
+ .NumOutputs(3, 4)
.EnforceInplace({{0, 0}, {1, 1}, {2, 2}})
.SetDoc(R"DOC(
@@ -62,6 +62,7 @@ OPERATOR_SCHEMA(SparseAdam)
.Output(0, "output_param", "Updated parameters")
.Output(1, "output_moment_1", "Updated first moment")
.Output(2, "output_moment_2", "Updated second moment")
+ .Output(3, "output_grad", "Optional Effective gradient")
.Arg("beta1", "Default 0.9")
.Arg("beta2", "Default 0.999")
.Arg("epsilon", "Default 1e-5");
@@ -71,7 +72,7 @@ REGISTER_CPU_OPERATOR(
RowWiseSparseAdamOp<float, CPUContext>);
OPERATOR_SCHEMA(RowWiseSparseAdam)
.NumInputs(7)
- .NumOutputs(3)
+ .NumOutputs(3, 4)
.EnforceInplace({{0, 0}, {1, 1}, {2, 2}})
.SetDoc(R"DOC(
@@ -95,6 +96,7 @@ OPERATOR_SCHEMA(RowWiseSparseAdam)
.Output(0, "output_param", "Updated parameters")
.Output(1, "output_moment_1", "Updated first moment")
.Output(2, "output_moment_2", "Updated second moment")
+ .Output(3, "output_grad", "Optional Effective gradient")
.Arg("beta1", "Default 0.9")
.Arg("beta2", "Default 0.999")
.Arg("epsilon", "Default 1e-5");
diff --git a/caffe2/sgd/adam_op.h b/caffe2/sgd/adam_op.h
index dadf7f4ee2..3f7bb95556 100644
--- a/caffe2/sgd/adam_op.h
+++ b/caffe2/sgd/adam_op.h
@@ -195,58 +195,118 @@ class SparseAdamOp final : public Operator<Context> {
auto* moment1Out = Output(OUTPUT_MOMENT_1)->template mutable_data<T>();
auto* moment2Out = Output(OUTPUT_MOMENT_2)->template mutable_data<T>();
- for (auto i = 0; i < n; ++i) {
- auto idx = indices[i];
-
- if (block_size == 1) {
- float gi = gradIn[i];
- float mi = moment1Out[idx] =
- moment1In[idx] * beta1_ + gi * (1 - beta1_);
- float vi = moment2Out[idx] =
- moment2In[idx] * beta2_ + gi * gi * (1 - beta2_);
- paramOut[idx] =
- paramIn[idx] + lr[0] * correction * mi / (std::sqrt(vi) + epsilon_);
-
- } else {
- auto offsetI = i * block_size;
- auto offsetIdx = idx * block_size;
+ if (OutputSize() == 3) {
+ for (auto i = 0; i < n; ++i) {
+ auto idx = indices[i];
+
+ if (block_size == 1) {
+ float gi = gradIn[i];
+ float mi = moment1Out[idx] =
+ moment1In[idx] * beta1_ + gi * (1 - beta1_);
+ float vi = moment2Out[idx] =
+ moment2In[idx] * beta2_ + gi * gi * (1 - beta2_);
+ paramOut[idx] = paramIn[idx] +
+ lr[0] * correction * mi / (std::sqrt(vi) + epsilon_);
+
+ } else {
+ auto offsetI = i * block_size;
+ auto offsetIdx = idx * block_size;
+
+#ifndef NDEBUG
+ CAFFE_ENFORCE_GE(
+ Input(PARAM).size(),
+ block_size + offsetIdx,
+ this->debug_def().input(PARAM),
+ ", out of bound, idx:",
+ idx,
+ " for input i:",
+ i,
+ " and block size:",
+ block_size);
+ CAFFE_ENFORCE_GE(
+ Input(GRAD).size(),
+ block_size + offsetI,
+ this->debug_def().input(GRAD),
+ ", out of bound idx, idx:",
+ idx,
+ " for input i:",
+ i);
+#endif
+
+ adam_compute(
+ block_size,
+ paramIn + offsetIdx,
+ gradIn + offsetI,
+ moment1In + offsetIdx,
+ moment2In + offsetIdx,
+ paramOut + offsetIdx,
+ moment1Out + offsetIdx,
+ moment2Out + offsetIdx,
+ beta1_,
+ beta2_,
+ epsilon_,
+ correction,
+ lr,
+ &context_);
+ }
+ }
+ } else {
+ Output(OUTPUT_GRAD)->ResizeLike(Input(GRAD));
+ auto* gradOut = Output(OUTPUT_GRAD)->template mutable_data<T>();
+ for (auto i = 0; i < n; ++i) {
+ auto idx = indices[i];
+
+ if (block_size == 1) {
+ float gi = gradIn[i];
+ float mi = moment1Out[idx] =
+ moment1In[idx] * beta1_ + gi * (1 - beta1_);
+ float vi = moment2Out[idx] =
+ moment2In[idx] * beta2_ + gi * gi * (1 - beta2_);
+ float ngi = gradOut[i] = correction * mi / (std::sqrt(vi) + epsilon_);
+ paramOut[idx] = paramIn[idx] + lr[0] * ngi;
+
+ } else {
+ auto offsetI = i * block_size;
+ auto offsetIdx = idx * block_size;
#ifndef NDEBUG
- CAFFE_ENFORCE_GE(
- Input(PARAM).size(),
- block_size + offsetIdx,
- this->debug_def().input(PARAM),
- ", out of bound, idx:",
- idx,
- " for input i:",
- i,
- " and block size:",
- block_size);
- CAFFE_ENFORCE_GE(
- Input(GRAD).size(),
- block_size + offsetI,
- this->debug_def().input(GRAD),
- ", out of bound idx, idx:",
- idx,
- " for input i:",
- i);
+ CAFFE_ENFORCE_GE(
+ Input(PARAM).size(),
+ block_size + offsetIdx,
+ this->debug_def().input(PARAM),
+ ", out of bound, idx:",
+ idx,
+ " for input i:",
+ i,
+ " and block size:",
+ block_size);
+ CAFFE_ENFORCE_GE(
+ Input(GRAD).size(),
+ block_size + offsetI,
+ this->debug_def().input(GRAD),
+ ", out of bound idx, idx:",
+ idx,
+ " for input i:",
+ i);
#endif
- adam_compute(
- block_size,
- paramIn + offsetIdx,
- gradIn + offsetI,
- moment1In + offsetIdx,
- moment2In + offsetIdx,
- paramOut + offsetIdx,
- moment1Out + offsetIdx,
- moment2Out + offsetIdx,
- beta1_,
- beta2_,
- epsilon_,
- correction,
- lr,
- &context_);
+ adam_compute_output_grad(
+ block_size,
+ paramIn + offsetIdx,
+ gradIn + offsetI,
+ moment1In + offsetIdx,
+ moment2In + offsetIdx,
+ paramOut + offsetIdx,
+ moment1Out + offsetIdx,
+ moment2Out + offsetIdx,
+ gradOut + offsetI,
+ beta1_,
+ beta2_,
+ epsilon_,
+ correction,
+ lr,
+ &context_);
+ }
}
}
return true;
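
The block-sized path of the new branch calls adam_compute_output_grad, whose definition is not part of this hunk. Inferring from the block_size == 1 branch above, its per-element behavior is presumably along these lines (a sketch under that assumption, not the committed implementation):

#include <cmath>

// Assumed element-wise behavior of adam_compute_output_grad (sketch only).
template <typename T>
void adam_compute_output_grad_sketch(
    int n, const T* w, const T* g, const T* m, const T* v,
    T* nw, T* nm, T* nv, T* ng,
    T beta1, T beta2, T eps, T correction, const T* lr) {
  for (int i = 0; i < n; ++i) {
    T gi = g[i];
    nm[i] = m[i] * beta1 + gi * (1 - beta1);
    nv[i] = v[i] * beta2 + gi * gi * (1 - beta2);
    ng[i] = correction * nm[i] / (std::sqrt(nv[i]) + eps);
    nw[i] = w[i] + lr[0] * ng[i];
  }
}
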
@@ -257,7 +317,7 @@ class SparseAdamOp final : public Operator<Context> {
T beta2_;
T epsilon_;
INPUT_TAGS(PARAM, MOMENT_1, MOMENT_2, INDICES, GRAD, LR, ITER);
- OUTPUT_TAGS(OUTPUT_PARAM, OUTPUT_MOMENT_1, OUTPUT_MOMENT_2);
+ OUTPUT_TAGS(OUTPUT_PARAM, OUTPUT_MOMENT_1, OUTPUT_MOMENT_2, OUTPUT_GRAD);
};
template <typename T, class Context>
@@ -305,61 +365,126 @@ class RowWiseSparseAdamOp final : public Operator<Context> {
auto* moment1Out = Output(OUTPUT_MOMENT_1)->template mutable_data<T>();
auto* moment2Out = Output(OUTPUT_MOMENT_2)->template mutable_data<T>();
- for (auto i = 0; i < n; ++i) {
- auto idx = indices[i];
+ if (OutputSize() == 3) {
+ for (auto i = 0; i < n; ++i) {
+ auto idx = indices[i];
+
+ if (block_size == 1) {
+ float gi = gradIn[i];
+ float mi = moment1Out[idx] =
+ moment1In[idx] * beta1_ + gi * (1 - beta1_);
+ float vi = moment2Out[idx] =
+ moment2In[idx] * beta2_ + gi * gi * (1 - beta2_);
+ paramOut[idx] = paramIn[idx] +
+ lr[0] * correction * mi / (std::sqrt(vi) + epsilon_);
+
+ } else {
+ auto offsetI = i * block_size;
+ auto offsetIdx = idx * block_size;
- if (block_size == 1) {
- float gi = gradIn[i];
- float mi = moment1Out[idx] =
- moment1In[idx] * beta1_ + gi * (1 - beta1_);
- float vi = moment2Out[idx] =
- moment2In[idx] * beta2_ + gi * gi * (1 - beta2_);
- paramOut[idx] =
- paramIn[idx] + lr[0] * correction * mi / (std::sqrt(vi) + epsilon_);
+#ifndef NDEBUG
+ CAFFE_ENFORCE_GE(
+ Input(PARAM).size(),
+ block_size + offsetIdx,
+ this->debug_def().input(PARAM),
+ ", out of bound, idx:",
+ idx,
+ " for input i:",
+ i,
+ " and block size:",
+ block_size);
+ CAFFE_ENFORCE_GE(
+ Input(GRAD).size(),
+ block_size + offsetI,
+ this->debug_def().input(GRAD),
+ ", out of bound idx, idx:",
+ idx,
+ " for input i:",
+ i);
+#endif
- } else {
- auto offsetI = i * block_size;
- auto offsetIdx = idx * block_size;
+ const float* w = paramIn + offsetIdx;
+ const float* g = gradIn + offsetI;
+ const float* m1 = moment1In + offsetIdx;
+ const float* m2 = moment2In + idx;
+ float* nw = paramOut + offsetIdx;
+ float* nm1 = moment1Out + offsetIdx;
+ float* nm2 = moment2Out + idx;
+
+ float m2_sum = 0.;
+ for (auto j = 0; j < block_size; ++j) {
+ float gj = g[j];
+ m2_sum += gj * gj;
+ }
+ float vi = nm2[0] =
+ m2[0] * beta2_ + (m2_sum / block_size) * (1 - beta2_);
+ for (auto j = 0; j < block_size; ++j) {
+ float mi = nm1[j] = m1[j] * beta1_ + g[j] * (1 - beta1_);
+ nw[j] = w[j] + lr[0] * correction * mi / (std::sqrt(vi) + epsilon_);
+ }
+ }
+ }
+ } else {
+ Output(OUTPUT_GRAD)->ResizeLike(Input(GRAD));
+ auto* gradOut = Output(OUTPUT_GRAD)->template mutable_data<T>();
+ for (auto i = 0; i < n; ++i) {
+ auto idx = indices[i];
+
+ if (block_size == 1) {
+ float gi = gradIn[i];
+ float mi = moment1Out[idx] =
+ moment1In[idx] * beta1_ + gi * (1 - beta1_);
+ float vi = moment2Out[idx] =
+ moment2In[idx] * beta2_ + gi * gi * (1 - beta2_);
+ float ngi = gradOut[i] = correction * mi / (std::sqrt(vi) + epsilon_);
+ paramOut[idx] = paramIn[idx] + lr[0] * ngi;
+
+ } else {
+ auto offsetI = i * block_size;
+ auto offsetIdx = idx * block_size;
#ifndef NDEBUG
- CAFFE_ENFORCE_GE(
- Input(PARAM).size(),
- block_size + offsetIdx,
- this->debug_def().input(PARAM),
- ", out of bound, idx:",
- idx,
- " for input i:",
- i,
- " and block size:",
- block_size);
- CAFFE_ENFORCE_GE(
- Input(GRAD).size(),
- block_size + offsetI,
- this->debug_def().input(GRAD),
- ", out of bound idx, idx:",
- idx,
- " for input i:",
- i);
+ CAFFE_ENFORCE_GE(
+ Input(PARAM).size(),
+ block_size + offsetIdx,
+ this->debug_def().input(PARAM),
+ ", out of bound, idx:",
+ idx,
+ " for input i:",
+ i,
+ " and block size:",
+ block_size);
+ CAFFE_ENFORCE_GE(
+ Input(GRAD).size(),
+ block_size + offsetI,
+ this->debug_def().input(GRAD),
+ ", out of bound idx, idx:",
+ idx,
+ " for input i:",
+ i);
#endif
- const float* w = paramIn + offsetIdx;
- const float* g = gradIn + offsetI;
- const float* m1 = moment1In + offsetIdx;
- const float* m2 = moment2In + idx;
- float* nw = paramOut + offsetIdx;
- float* nm1 = moment1Out + offsetIdx;
- float* nm2 = moment2Out + idx;
-
- float m2_sum = 0.;
- for (auto j = 0; j < block_size; ++j) {
- float gj = g[j];
- m2_sum += gj * gj;
- }
- float vi = nm2[0] =
- m2[0] * beta2_ + (m2_sum / block_size) * (1 - beta2_);
- for (auto j = 0; j < block_size; ++j) {
- float mi = nm1[j] = m1[j] * beta1_ + g[j] * (1 - beta1_);
- nw[j] = w[j] + lr[0] * correction * mi / (std::sqrt(vi) + epsilon_);
+ const float* w = paramIn + offsetIdx;
+ const float* g = gradIn + offsetI;
+ const float* m1 = moment1In + offsetIdx;
+ const float* m2 = moment2In + idx;
+ float* nw = paramOut + offsetIdx;
+ float* nm1 = moment1Out + offsetIdx;
+ float* nm2 = moment2Out + idx;
+ float* ng = gradOut + offsetI;
+
+ float m2_sum = 0.;
+ for (auto j = 0; j < block_size; ++j) {
+ float gj = g[j];
+ m2_sum += gj * gj;
+ }
+ float vi = nm2[0] =
+ m2[0] * beta2_ + (m2_sum / block_size) * (1 - beta2_);
+ for (auto j = 0; j < block_size; ++j) {
+ float mi = nm1[j] = m1[j] * beta1_ + g[j] * (1 - beta1_);
+ float ngi = ng[j] = correction * mi / (std::sqrt(vi) + epsilon_);
+ nw[j] = w[j] + lr[0] * ngi;
+ }
}
}
}
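
For contrast with SparseAdam: the row-wise variant stores a single second moment per row (the mean of the squared gradients over the block) rather than one per element, which is what the moment2In + idx / moment2Out + idx indexing above reflects. A compact sketch of one row update mirroring the loop above (illustrative names, not code from this commit):

#include <cmath>

// Sketch: one RowWiseSparseAdam row update with the optional effective gradient.
void rowwise_adam_row_sketch(
    int block_size, const float* w, const float* g,
    const float* m1, const float* m2_row,  // m2_row points at one scalar per row
    float* nw, float* nm1, float* nm2_row, float* ng,
    float beta1, float beta2, float eps, float correction, const float* lr) {
  float sq_sum = 0.f;
  for (int j = 0; j < block_size; ++j) {
    sq_sum += g[j] * g[j];
  }
  // One shared second moment for the whole row.
  float vi = nm2_row[0] =
      m2_row[0] * beta2 + (sq_sum / block_size) * (1 - beta2);
  for (int j = 0; j < block_size; ++j) {
    float mi = nm1[j] = m1[j] * beta1 + g[j] * (1 - beta1);
    float ngi = correction * mi / (std::sqrt(vi) + eps);
    if (ng != nullptr) {
      ng[j] = ngi;  // written only when the fourth output is wired
    }
    nw[j] = w[j] + lr[0] * ngi;
  }
}
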
@@ -371,7 +496,7 @@ class RowWiseSparseAdamOp final : public Operator<Context> {
T beta2_;
T epsilon_;
INPUT_TAGS(PARAM, MOMENT_1, MOMENT_2, INDICES, GRAD, LR, ITER);
- OUTPUT_TAGS(OUTPUT_PARAM, OUTPUT_MOMENT_1, OUTPUT_MOMENT_2);
+ OUTPUT_TAGS(OUTPUT_PARAM, OUTPUT_MOMENT_1, OUTPUT_MOMENT_2, OUTPUT_GRAD);
};
} // namespace caffe2
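
Why expose the effective gradient at all: a learning-rate-adaption step downstream can compare the raw gradient with the effective gradient and adjust the learning rate accordingly. One generic correlation-based rule, sketched purely for illustration (the actual caffe2 lr-adaption operator and its exact formula are not part of this diff):

#include <cmath>

// Illustrative lr-adaption rule (an assumption, not taken from this commit):
// shrink or grow the learning rate based on how well the raw gradient and the
// effective gradient align.
float adapt_lr_sketch(
    float lr, float alpha, const float* grad, const float* eff_grad, int n) {
  float dot = 0.f, g_norm = 0.f, e_norm = 0.f;
  for (int i = 0; i < n; ++i) {
    dot += grad[i] * eff_grad[i];
    g_norm += grad[i] * grad[i];
    e_norm += eff_grad[i] * eff_grad[i];
  }
  float denom = std::sqrt(g_norm) * std::sqrt(e_norm) + 1e-12f;
  return lr - alpha * (dot / denom);  // sign and scale are assumptions
}
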