diff options
author | Przemysław Dolata <snowball91b@gmail.com> | 2018-02-12 09:35:27 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-02-12 09:35:27 +0100 |
commit | a44c444ee4ae0e7c0aa77118213d34bb26e9f2e6 (patch) | |
tree | e703df252c868b504f2180e46ff8ddae9d53d88d | |
parent | 87e151281d853afdb281e2249620cf839bb932d1 (diff) | |
parent | 6fa4c62dcca954b7f8ae26e7f7314e235dd6a3b4 (diff) | |
download | caffe-a44c444ee4ae0e7c0aa77118213d34bb26e9f2e6.tar.gz caffe-a44c444ee4ae0e7c0aa77118213d34bb26e9f2e6.tar.bz2 caffe-a44c444ee4ae0e7c0aa77118213d34bb26e9f2e6.zip |
Merge pull request #6123 from IlyaOvodov/master
"weights" added to solver parameters, "snapshot_prefix" field default initialization
-rw-r--r-- | src/caffe/proto/caffe.proto | 18 | ||||
-rw-r--r-- | src/caffe/solver.cpp | 21 | ||||
-rw-r--r-- | src/caffe/test/test_upgrade_proto.cpp | 4 | ||||
-rw-r--r-- | src/caffe/util/upgrade_proto.cpp | 21 | ||||
-rw-r--r-- | tools/caffe.cpp | 23 |
5 files changed, 69 insertions, 18 deletions
diff --git a/src/caffe/proto/caffe.proto b/src/caffe/proto/caffe.proto index c96966b5..22764abc 100644 --- a/src/caffe/proto/caffe.proto +++ b/src/caffe/proto/caffe.proto @@ -98,7 +98,7 @@ message NetParameter { // NOTE // Update the next available ID when you add a new SolverParameter field. // -// SolverParameter next available ID: 42 (last added: layer_wise_reduce) +// SolverParameter next available ID: 43 (last added: weights) message SolverParameter { ////////////////////////////////////////////////////////////////////////////// // Specifying the train and test networks @@ -186,7 +186,11 @@ message SolverParameter { optional float clip_gradients = 35 [default = -1]; optional int32 snapshot = 14 [default = 0]; // The snapshot interval - optional string snapshot_prefix = 15; // The prefix for the snapshot. + // The prefix for the snapshot. + // If not set then is replaced by prototxt file path without extention. + // If is set to directory then is augmented by prototxt file name + // without extention. + optional string snapshot_prefix = 15; // whether to snapshot diff in the results or not. Snapshotting diff will help // debugging but the final protocol buffer size will be much larger. optional bool snapshot_diff = 16 [default = false]; @@ -241,6 +245,16 @@ message SolverParameter { // Overlap compute and communication for data parallel training optional bool layer_wise_reduce = 41 [default = true]; + + // Path to caffemodel file(s) with pretrained weights to initialize finetuning. + // Tha same as command line --weights parameter for caffe train command. + // If command line --weights parameter if specified, it has higher priority + // and owerwrites this one(s). + // If --snapshot command line parameter is specified, this one(s) are ignored. + // If several model files are expected, they can be listed in a one + // weights parameter separated by ',' (like in a command string) or + // in repeated weights parameters separately. + repeated string weights = 42; } // A message that stores the solver snapshots diff --git a/src/caffe/solver.cpp b/src/caffe/solver.cpp index 04426937..d229acff 100644 --- a/src/caffe/solver.cpp +++ b/src/caffe/solver.cpp @@ -3,6 +3,7 @@ #include <string> #include <vector> +#include "boost/algorithm/string.hpp" #include "caffe/solver.hpp" #include "caffe/util/format.hpp" #include "caffe/util/hdf5.hpp" @@ -59,6 +60,20 @@ void Solver<Dtype>::Init(const SolverParameter& param) { current_step_ = 0; } +// Load weights from the caffemodel(s) specified in "weights" solver parameter +// into the train and test nets. +template <typename Dtype> +void LoadNetWeights(shared_ptr<Net<Dtype> > net, + const std::string& model_list) { + std::vector<std::string> model_names; + boost::split(model_names, model_list, boost::is_any_of(",")); + for (int i = 0; i < model_names.size(); ++i) { + boost::trim(model_names[i]); + LOG(INFO) << "Finetuning from " << model_names[i]; + net->CopyTrainedLayersFrom(model_names[i]); + } +} + template <typename Dtype> void Solver<Dtype>::InitTrainNet() { const int num_train_nets = param_.has_net() + param_.has_net_param() + @@ -98,6 +113,9 @@ void Solver<Dtype>::InitTrainNet() { net_state.MergeFrom(param_.train_state()); net_param.mutable_state()->CopyFrom(net_state); net_.reset(new Net<Dtype>(net_param)); + for (int w_idx = 0; w_idx < param_.weights_size(); ++w_idx) { + LoadNetWeights(net_, param_.weights(w_idx)); + } } template <typename Dtype> @@ -173,6 +191,9 @@ void Solver<Dtype>::InitTestNets() { << "Creating test net (#" << i << ") specified by " << sources[i]; test_nets_[i].reset(new Net<Dtype>(net_params[i])); test_nets_[i]->set_debug_info(param_.debug_info()); + for (int w_idx = 0; w_idx < param_.weights_size(); ++w_idx) { + LoadNetWeights(test_nets_[i], param_.weights(w_idx)); + } } } diff --git a/src/caffe/test/test_upgrade_proto.cpp b/src/caffe/test/test_upgrade_proto.cpp index 9dcc2aa5..769112eb 100644 --- a/src/caffe/test/test_upgrade_proto.cpp +++ b/src/caffe/test/test_upgrade_proto.cpp @@ -2952,6 +2952,8 @@ TEST_F(SolverTypeUpgradeTest, TestSimple) { for (int i = 0; i < 6; ++i) { const string& input_proto = "net: 'examples/mnist/lenet_train_test.prototxt' " + "weights: 'examples/mnist/lenet_train_test1.caffemodel' " + "weights: 'examples/mnist/lenet_train_test2.caffemodel' " "test_iter: 100 " "test_interval: 500 " "base_lr: 0.01 " @@ -2968,6 +2970,8 @@ TEST_F(SolverTypeUpgradeTest, TestSimple) { "solver_type: " + std::string(old_type_vec[i]) + " "; const string& expected_output_proto = "net: 'examples/mnist/lenet_train_test.prototxt' " + "weights: 'examples/mnist/lenet_train_test1.caffemodel' " + "weights: 'examples/mnist/lenet_train_test2.caffemodel' " "test_iter: 100 " "test_interval: 500 " "base_lr: 0.01 " diff --git a/src/caffe/util/upgrade_proto.cpp b/src/caffe/util/upgrade_proto.cpp index 94771c8c..ad40b73d 100644 --- a/src/caffe/util/upgrade_proto.cpp +++ b/src/caffe/util/upgrade_proto.cpp @@ -2,6 +2,8 @@ #include <google/protobuf/io/zero_copy_stream_impl.h> #include <google/protobuf/text_format.h> +#include <boost/filesystem.hpp> + #include <map> #include <string> @@ -1095,12 +1097,31 @@ bool UpgradeSolverAsNeeded(const string& param_file, SolverParameter* param) { return success; } +// Replaces snapshot_prefix of SolverParameter if it is not specified +// or is set to directory +void UpgradeSnapshotPrefixProperty(const string& param_file, + SolverParameter* param) { + using boost::filesystem::path; + using boost::filesystem::is_directory; + if (!param->has_snapshot_prefix()) { + param->set_snapshot_prefix(path(param_file).replace_extension().string()); + LOG(INFO) << "snapshot_prefix was not specified and is set to " + + param->snapshot_prefix(); + } else if (is_directory(param->snapshot_prefix())) { + param->set_snapshot_prefix((path(param->snapshot_prefix()) / + path(param_file).stem()).string()); + LOG(INFO) << "snapshot_prefix was a directory and is replaced to " + + param->snapshot_prefix(); + } +} + // Read parameters from a file into a SolverParameter proto message. void ReadSolverParamsFromTextFileOrDie(const string& param_file, SolverParameter* param) { CHECK(ReadProtoFromTextFile(param_file, param)) << "Failed to parse SolverParameter file: " << param_file; UpgradeSolverAsNeeded(param_file, param); + UpgradeSnapshotPrefixProperty(param_file, param); } } // namespace caffe diff --git a/tools/caffe.cpp b/tools/caffe.cpp index 3587d8aa..389cfb8a 100644 --- a/tools/caffe.cpp +++ b/tools/caffe.cpp @@ -146,20 +146,6 @@ int device_query() { } RegisterBrewFunction(device_query); -// Load the weights from the specified caffemodel(s) into the train and -// test nets. -void CopyLayers(caffe::Solver<float>* solver, const std::string& model_list) { - std::vector<std::string> model_names; - boost::split(model_names, model_list, boost::is_any_of(",") ); - for (int i = 0; i < model_names.size(); ++i) { - LOG(INFO) << "Finetuning from " << model_names[i]; - solver->net()->CopyTrainedLayersFrom(model_names[i]); - for (int j = 0; j < solver->test_nets().size(); ++j) { - solver->test_nets()[j]->CopyTrainedLayersFrom(model_names[i]); - } - } -} - // Translate the signal effect the user specified on the command-line to the // corresponding enumeration. caffe::SolverAction::Enum GetRequestedAction( @@ -233,6 +219,13 @@ int train() { GetRequestedAction(FLAGS_sigint_effect), GetRequestedAction(FLAGS_sighup_effect)); + if (FLAGS_snapshot.size()) { + solver_param.clear_weights(); + } else if (FLAGS_weights.size()) { + solver_param.clear_weights(); + solver_param.add_weights(FLAGS_weights); + } + shared_ptr<caffe::Solver<float> > solver(caffe::SolverRegistry<float>::CreateSolver(solver_param)); @@ -241,8 +234,6 @@ int train() { if (FLAGS_snapshot.size()) { LOG(INFO) << "Resuming from " << FLAGS_snapshot; solver->Restore(FLAGS_snapshot.c_str()); - } else if (FLAGS_weights.size()) { - CopyLayers(solver.get(), FLAGS_weights); } LOG(INFO) << "Starting Optimization"; |