summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorYaYaB <bezzayassine@gmail.com>2017-12-12 16:16:59 (GMT)
committerYaYaB <bezzayassine@gmail.com>2017-12-12 16:16:59 (GMT)
commitc23b3563f0fa2999578c1a8b3f32dc9cdec5a037 (patch)
tree2af8f158d2c109629b53974d26ac8c53d9bab742
parent99466224dac86ddb86296b1e727794fb836bd80f (diff)
downloadcaffe-c23b3563f0fa2999578c1a8b3f32dc9cdec5a037.zip
caffe-c23b3563f0fa2999578c1a8b3f32dc9cdec5a037.tar.gz
caffe-c23b3563f0fa2999578c1a8b3f32dc9cdec5a037.tar.bz2
Add check values of gamma and stepsize to avoid unexplained core dump
-rw-r--r--src/caffe/solvers/sgd_solver.cpp7
1 files changed, 7 insertions, 0 deletions
diff --git a/src/caffe/solvers/sgd_solver.cpp b/src/caffe/solvers/sgd_solver.cpp
index ad6abe5..1d52beb 100644
--- a/src/caffe/solvers/sgd_solver.cpp
+++ b/src/caffe/solvers/sgd_solver.cpp
@@ -30,12 +30,16 @@ Dtype SGDSolver<Dtype>::GetLearningRate() {
if (lr_policy == "fixed") {
rate = this->param_.base_lr();
} else if (lr_policy == "step") {
+ CHECK_GT(this->param_.stepsize(), 0);
this->current_step_ = this->iter_ / this->param_.stepsize();
+ CHECK_GE(this->param_.gamma(), 0);
rate = this->param_.base_lr() *
pow(this->param_.gamma(), this->current_step_);
} else if (lr_policy == "exp") {
+ CHECK_GE(this->param_.gamma(), 0);
rate = this->param_.base_lr() * pow(this->param_.gamma(), this->iter_);
} else if (lr_policy == "inv") {
+ CHECK_GE(this->param_.gamma(), 0);
rate = this->param_.base_lr() *
pow(Dtype(1) + this->param_.gamma() * this->iter_,
- this->param_.power());
@@ -46,6 +50,7 @@ Dtype SGDSolver<Dtype>::GetLearningRate() {
LOG(INFO) << "MultiStep Status: Iteration " <<
this->iter_ << ", step = " << this->current_step_;
}
+ CHECK_GE(this->param_.gamma(), 0);
rate = this->param_.base_lr() *
pow(this->param_.gamma(), this->current_step_);
} else if (lr_policy == "poly") {
@@ -53,6 +58,8 @@ Dtype SGDSolver<Dtype>::GetLearningRate() {
(Dtype(this->iter_) / Dtype(this->param_.max_iter())),
this->param_.power());
} else if (lr_policy == "sigmoid") {
+ CHECK_GE(this->param_.gamma(), 0);
+ CHECK_GT(this->param_.stepsize(), 0);
rate = this->param_.base_lr() * (Dtype(1.) /
(Dtype(1.) + exp(-this->param_.gamma() * (Dtype(this->iter_) -
Dtype(this->param_.stepsize())))));