summaryrefslogtreecommitdiff
path: root/tools/caffe.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'tools/caffe.cpp')
-rw-r--r--tools/caffe.cpp23
1 files changed, 7 insertions, 16 deletions
diff --git a/tools/caffe.cpp b/tools/caffe.cpp
index 3587d8aa..389cfb8a 100644
--- a/tools/caffe.cpp
+++ b/tools/caffe.cpp
@@ -146,20 +146,6 @@ int device_query() {
}
RegisterBrewFunction(device_query);
-// Load the weights from the specified caffemodel(s) into the train and
-// test nets.
-void CopyLayers(caffe::Solver<float>* solver, const std::string& model_list) {
- std::vector<std::string> model_names;
- boost::split(model_names, model_list, boost::is_any_of(",") );
- for (int i = 0; i < model_names.size(); ++i) {
- LOG(INFO) << "Finetuning from " << model_names[i];
- solver->net()->CopyTrainedLayersFrom(model_names[i]);
- for (int j = 0; j < solver->test_nets().size(); ++j) {
- solver->test_nets()[j]->CopyTrainedLayersFrom(model_names[i]);
- }
- }
-}
-
// Translate the signal effect the user specified on the command-line to the
// corresponding enumeration.
caffe::SolverAction::Enum GetRequestedAction(
@@ -233,6 +219,13 @@ int train() {
GetRequestedAction(FLAGS_sigint_effect),
GetRequestedAction(FLAGS_sighup_effect));
+ if (FLAGS_snapshot.size()) {
+ solver_param.clear_weights();
+ } else if (FLAGS_weights.size()) {
+ solver_param.clear_weights();
+ solver_param.add_weights(FLAGS_weights);
+ }
+
shared_ptr<caffe::Solver<float> >
solver(caffe::SolverRegistry<float>::CreateSolver(solver_param));
@@ -241,8 +234,6 @@ int train() {
if (FLAGS_snapshot.size()) {
LOG(INFO) << "Resuming from " << FLAGS_snapshot;
solver->Restore(FLAGS_snapshot.c_str());
- } else if (FLAGS_weights.size()) {
- CopyLayers(solver.get(), FLAGS_weights);
}
LOG(INFO) << "Starting Optimization";