summary | refs | log | tree | commit | diff
path: root/src/caffe/solver.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/caffe/solver.cpp')
-rw-r--r-- src/caffe/solver.cpp 44
1 files changed, 19 insertions, 25 deletions
diff --git a/src/caffe/solver.cpp b/src/caffe/solver.cpp
index ece3913e..1c1a9e59 100644
--- a/src/caffe/solver.cpp
+++ b/src/caffe/solver.cpp
@@ -26,16 +26,14 @@ SolverAction::Enum Solver<Dtype>::GetRequestedAction() {
}
template <typename Dtype>
-Solver<Dtype>::Solver(const SolverParameter& param, const Solver* root_solver)
- : net_(), callbacks_(), root_solver_(root_solver),
- requested_early_exit_(false) {
+Solver<Dtype>::Solver(const SolverParameter& param)
+ : net_(), callbacks_(), requested_early_exit_(false) {
Init(param);
}
template <typename Dtype>
-Solver<Dtype>::Solver(const string& param_file, const Solver* root_solver)
- : net_(), callbacks_(), root_solver_(root_solver),
- requested_early_exit_(false) {
+Solver<Dtype>::Solver(const string& param_file)
+ : net_(), callbacks_(), requested_early_exit_(false) {
SolverParameter param;
ReadSolverParamsFromTextFileOrDie(param_file, &param);
Init(param);
@@ -43,15 +41,13 @@ Solver<Dtype>::Solver(const string& param_file, const Solver* root_solver)
template <typename Dtype>
void Solver<Dtype>::Init(const SolverParameter& param) {
- CHECK(Caffe::root_solver() || root_solver_)
- << "root_solver_ needs to be set for all non-root solvers";
LOG_IF(INFO, Caffe::root_solver()) << "Initializing solver from parameters: "
<< std::endl << param.DebugString();
param_ = param;
CHECK_GE(param_.average_loss(), 1) << "average_loss should be non-negative.";
CheckSnapshotWritePermissions();
- if (Caffe::root_solver() && param_.random_seed() >= 0) {
- Caffe::set_random_seed(param_.random_seed());
+ if (param_.random_seed() >= 0) {
+ Caffe::set_random_seed(param_.random_seed() + Caffe::solver_rank());
}
// Scaffolding code
InitTrainNet();
@@ -101,11 +97,7 @@ void Solver<Dtype>::InitTrainNet() {
net_state.MergeFrom(net_param.state());
net_state.MergeFrom(param_.train_state());
net_param.mutable_state()->CopyFrom(net_state);
- if (Caffe::root_solver()) {
- net_.reset(new Net<Dtype>(net_param));
- } else {
- net_.reset(new Net<Dtype>(net_param, root_solver_->net_.get()));
- }
+ net_.reset(new Net<Dtype>(net_param));
}
template <typename Dtype>
@@ -180,12 +172,7 @@ void Solver<Dtype>::InitTestNets() {
net_params[i].mutable_state()->CopyFrom(net_state);
LOG(INFO)
<< "Creating test net (#" << i << ") specified by " << sources[i];
- if (Caffe::root_solver()) {
- test_nets_[i].reset(new Net<Dtype>(net_params[i]));
- } else {
- test_nets_[i].reset(new Net<Dtype>(net_params[i],
- root_solver_->test_nets_[i].get()));
- }
+ test_nets_[i].reset(new Net<Dtype>(net_params[i]));
test_nets_[i]->set_debug_info(param_.debug_info());
}
}
@@ -197,14 +184,16 @@ void Solver<Dtype>::Step(int iters) {
int average_loss = this->param_.average_loss();
losses_.clear();
smoothed_loss_ = 0;
+ iteration_timer_.Start();
while (iter_ < stop_iter) {
// zero-init the params
net_->ClearParamDiffs();
if (param_.test_interval() && iter_ % param_.test_interval() == 0
- && (iter_ > 0 || param_.test_initialization())
- && Caffe::root_solver()) {
- TestAll();
+ && (iter_ > 0 || param_.test_initialization())) {
+ if (Caffe::root_solver()) {
+ TestAll();
+ }
if (requested_early_exit_) {
// Break out of the while loop because stop was requested while testing.
break;
@@ -225,8 +214,13 @@ void Solver<Dtype>::Step(int iters) {
// average the loss across iterations for smoothed reporting
UpdateSmoothedLoss(loss, start_iter, average_loss);
if (display) {
+ float lapse = iteration_timer_.Seconds();
+ float per_s = (iter_ - iterations_last_) / (lapse ? lapse : 1);
LOG_IF(INFO, Caffe::root_solver()) << "Iteration " << iter_
- << ", loss = " << smoothed_loss_;
+ << " (" << per_s << " iter/s, " << lapse << "s/"
+ << param_.display() << " iters), loss = " << smoothed_loss_;
+ iteration_timer_.Start();
+ iterations_last_ = iter_;
const vector<Blob<Dtype>*>& result = net_->output_blobs();
int score_index = 0;
for (int j = 0; j < result.size(); ++j) {