diff options
Diffstat (limited to 'src/caffe/solver.cpp')
-rw-r--r-- | src/caffe/solver.cpp | 44 |
1 files changed, 19 insertions, 25 deletions
diff --git a/src/caffe/solver.cpp b/src/caffe/solver.cpp index ece3913e..1c1a9e59 100644 --- a/src/caffe/solver.cpp +++ b/src/caffe/solver.cpp @@ -26,16 +26,14 @@ SolverAction::Enum Solver<Dtype>::GetRequestedAction() { } template <typename Dtype> -Solver<Dtype>::Solver(const SolverParameter& param, const Solver* root_solver) - : net_(), callbacks_(), root_solver_(root_solver), - requested_early_exit_(false) { +Solver<Dtype>::Solver(const SolverParameter& param) + : net_(), callbacks_(), requested_early_exit_(false) { Init(param); } template <typename Dtype> -Solver<Dtype>::Solver(const string& param_file, const Solver* root_solver) - : net_(), callbacks_(), root_solver_(root_solver), - requested_early_exit_(false) { +Solver<Dtype>::Solver(const string& param_file) + : net_(), callbacks_(), requested_early_exit_(false) { SolverParameter param; ReadSolverParamsFromTextFileOrDie(param_file, ¶m); Init(param); @@ -43,15 +41,13 @@ Solver<Dtype>::Solver(const string& param_file, const Solver* root_solver) template <typename Dtype> void Solver<Dtype>::Init(const SolverParameter& param) { - CHECK(Caffe::root_solver() || root_solver_) - << "root_solver_ needs to be set for all non-root solvers"; LOG_IF(INFO, Caffe::root_solver()) << "Initializing solver from parameters: " << std::endl << param.DebugString(); param_ = param; CHECK_GE(param_.average_loss(), 1) << "average_loss should be non-negative."; CheckSnapshotWritePermissions(); - if (Caffe::root_solver() && param_.random_seed() >= 0) { - Caffe::set_random_seed(param_.random_seed()); + if (param_.random_seed() >= 0) { + Caffe::set_random_seed(param_.random_seed() + Caffe::solver_rank()); } // Scaffolding code InitTrainNet(); @@ -101,11 +97,7 @@ void Solver<Dtype>::InitTrainNet() { net_state.MergeFrom(net_param.state()); net_state.MergeFrom(param_.train_state()); net_param.mutable_state()->CopyFrom(net_state); - if (Caffe::root_solver()) { - net_.reset(new Net<Dtype>(net_param)); - } else { - net_.reset(new Net<Dtype>(net_param, root_solver_->net_.get())); - } + net_.reset(new Net<Dtype>(net_param)); } template <typename Dtype> @@ -180,12 +172,7 @@ void Solver<Dtype>::InitTestNets() { net_params[i].mutable_state()->CopyFrom(net_state); LOG(INFO) << "Creating test net (#" << i << ") specified by " << sources[i]; - if (Caffe::root_solver()) { - test_nets_[i].reset(new Net<Dtype>(net_params[i])); - } else { - test_nets_[i].reset(new Net<Dtype>(net_params[i], - root_solver_->test_nets_[i].get())); - } + test_nets_[i].reset(new Net<Dtype>(net_params[i])); test_nets_[i]->set_debug_info(param_.debug_info()); } } @@ -197,14 +184,16 @@ void Solver<Dtype>::Step(int iters) { int average_loss = this->param_.average_loss(); losses_.clear(); smoothed_loss_ = 0; + iteration_timer_.Start(); while (iter_ < stop_iter) { // zero-init the params net_->ClearParamDiffs(); if (param_.test_interval() && iter_ % param_.test_interval() == 0 - && (iter_ > 0 || param_.test_initialization()) - && Caffe::root_solver()) { - TestAll(); + && (iter_ > 0 || param_.test_initialization())) { + if (Caffe::root_solver()) { + TestAll(); + } if (requested_early_exit_) { // Break out of the while loop because stop was requested while testing. break; @@ -225,8 +214,13 @@ void Solver<Dtype>::Step(int iters) { // average the loss across iterations for smoothed reporting UpdateSmoothedLoss(loss, start_iter, average_loss); if (display) { + float lapse = iteration_timer_.Seconds(); + float per_s = (iter_ - iterations_last_) / (lapse ? lapse : 1); LOG_IF(INFO, Caffe::root_solver()) << "Iteration " << iter_ - << ", loss = " << smoothed_loss_; + << " (" << per_s << " iter/s, " << lapse << "s/" + << param_.display() << " iters), loss = " << smoothed_loss_; + iteration_timer_.Start(); + iterations_last_ = iter_; const vector<Blob<Dtype>*>& result = net_->output_blobs(); int score_index = 0; for (int j = 0; j < result.size(); ++j) { |