diff options
Diffstat (limited to 'tools/caffe.cpp')
-rw-r--r-- | tools/caffe.cpp | 23 |
1 files changed, 7 insertions, 16 deletions
diff --git a/tools/caffe.cpp b/tools/caffe.cpp index 3587d8aa..389cfb8a 100644 --- a/tools/caffe.cpp +++ b/tools/caffe.cpp @@ -146,20 +146,6 @@ int device_query() { } RegisterBrewFunction(device_query); -// Load the weights from the specified caffemodel(s) into the train and -// test nets. -void CopyLayers(caffe::Solver<float>* solver, const std::string& model_list) { - std::vector<std::string> model_names; - boost::split(model_names, model_list, boost::is_any_of(",") ); - for (int i = 0; i < model_names.size(); ++i) { - LOG(INFO) << "Finetuning from " << model_names[i]; - solver->net()->CopyTrainedLayersFrom(model_names[i]); - for (int j = 0; j < solver->test_nets().size(); ++j) { - solver->test_nets()[j]->CopyTrainedLayersFrom(model_names[i]); - } - } -} - // Translate the signal effect the user specified on the command-line to the // corresponding enumeration. caffe::SolverAction::Enum GetRequestedAction( @@ -233,6 +219,13 @@ int train() { GetRequestedAction(FLAGS_sigint_effect), GetRequestedAction(FLAGS_sighup_effect)); + if (FLAGS_snapshot.size()) { + solver_param.clear_weights(); + } else if (FLAGS_weights.size()) { + solver_param.clear_weights(); + solver_param.add_weights(FLAGS_weights); + } + shared_ptr<caffe::Solver<float> > solver(caffe::SolverRegistry<float>::CreateSolver(solver_param)); @@ -241,8 +234,6 @@ int train() { if (FLAGS_snapshot.size()) { LOG(INFO) << "Resuming from " << FLAGS_snapshot; solver->Restore(FLAGS_snapshot.c_str()); - } else if (FLAGS_weights.size()) { - CopyLayers(solver.get(), FLAGS_weights); } LOG(INFO) << "Starting Optimization"; |