summary refs log tree commit diff
path: root/src/caffe/solver.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/caffe/solver.cpp')
-rw-r--r-- src/caffe/solver.cpp 76
1 files changed, 76 insertions, 0 deletions
diff --git a/src/caffe/solver.cpp b/src/caffe/solver.cpp
index 32276ac1..43834c0c 100644
--- a/src/caffe/solver.cpp
+++ b/src/caffe/solver.cpp
@@ -859,9 +859,85 @@ void AdaGradSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
}
}
+template <typename Dtype>
+void RMSPropSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
+ const vector<shared_ptr<Blob<Dtype> > >& net_params = this->net_->params();
+ const vector<float>& net_params_lr = this->net_->params_lr();
+ // RMSProp step for parameter param_id; the final update overwrites that parameter's diff.
+ // read delta and rms_decay from the solver param; scale rate by this parameter's lr entry
+ Dtype delta = this->param_.delta();
+ Dtype rms_decay = this->param_.rms_decay();
+ Dtype local_rate = rate * net_params_lr[param_id];
+ // NOTE(review): formulas below assume axpby(n, a, x, b, y) computes y = a*x + b*y, element-wise — confirm
+ switch (Caffe::mode()) {
+ case Caffe::CPU:
+ // update = diff^2 (element-wise square of the gradient, stored in the update buffer)
+ caffe_powx(net_params[param_id]->count(),
+ net_params[param_id]->cpu_diff(), Dtype(2),
+ this->update_[param_id]->mutable_cpu_data());
+
+ // history = (1 - rms_decay) * update + rms_decay * history
+ caffe_cpu_axpby(net_params[param_id] -> count(),
+ Dtype(1-rms_decay), this->update_[param_id]->cpu_data(),
+ rms_decay, this->history_[param_id]-> mutable_cpu_data());
+
+ // update = history^0.5 (the update buffer is reused as scratch space)
+ caffe_powx(net_params[param_id]->count(),
+ this->history_[param_id]->cpu_data(), Dtype(0.5),
+ this->update_[param_id]->mutable_cpu_data());
+ // update += delta, applied before update is used as the divisor below
+ caffe_add_scalar(net_params[param_id]->count(),
+ delta, this->update_[param_id]->mutable_cpu_data());
+ // update = diff / update
+ caffe_div(net_params[param_id]->count(),
+ net_params[param_id]->cpu_diff(), this->update_[param_id]->cpu_data(),
+ this->update_[param_id]->mutable_cpu_data());
+
+ // diff = local_rate * update (b = 0, so the previous diff does not contribute)
+ caffe_cpu_axpby(net_params[param_id]->count(), local_rate,
+ this->update_[param_id]->cpu_data(), Dtype(0),
+ net_params[param_id]->mutable_cpu_diff());
+ break;
+ case Caffe::GPU:
+#ifndef CPU_ONLY
+ // update = diff^2 (same sequence as the CPU branch, via gpu_* accessors and caffe_gpu_* calls)
+ caffe_gpu_powx(net_params[param_id]->count(),
+ net_params[param_id]->gpu_diff(), Dtype(2),
+ this->update_[param_id]->mutable_gpu_data());
+
+ // history = (1 - rms_decay) * update + rms_decay * history
+ caffe_gpu_axpby(net_params[param_id] -> count(),
+ Dtype(1-rms_decay), this->update_[param_id]->gpu_data(),
+ rms_decay, this->history_[param_id]-> mutable_gpu_data());
+
+ // update = history^0.5
+ caffe_gpu_powx(net_params[param_id]->count(),
+ this->history_[param_id]->gpu_data(), Dtype(0.5),
+ this->update_[param_id]->mutable_gpu_data());
+ // update += delta
+ caffe_gpu_add_scalar(net_params[param_id]->count(),
+ delta, this->update_[param_id]->mutable_gpu_data());
+ // update = diff / update
+ caffe_gpu_div(net_params[param_id]->count(),
+ net_params[param_id]->gpu_diff(), this->update_[param_id]->gpu_data(),
+ this->update_[param_id]->mutable_gpu_data());
+ // diff = local_rate * update
+ caffe_gpu_axpby(net_params[param_id]->count(), local_rate,
+ this->update_[param_id]->gpu_data(), Dtype(0),
+ net_params[param_id]->mutable_gpu_diff());
+#else
+ NO_GPU;  // NOTE(review): project macro used when built with CPU_ONLY — behavior not visible here, confirm
+#endif
+ break;
+ default:
+ LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
+ }
+}
+
INSTANTIATE_CLASS(Solver);
INSTANTIATE_CLASS(SGDSolver);
INSTANTIATE_CLASS(NesterovSolver);
INSTANTIATE_CLASS(AdaGradSolver);
+INSTANTIATE_CLASS(RMSPropSolver);  // presumably emits the explicit template instantiations, like the solvers above — confirm against the macro
} // namespace caffe