author      Przemysław Dolata <snowball91b@gmail.com>   2018-02-12 09:35:27 +0100
committer   GitHub <noreply@github.com>                 2018-02-12 09:35:27 +0100
commit      a44c444ee4ae0e7c0aa77118213d34bb26e9f2e6 (patch)
tree        e703df252c868b504f2180e46ff8ddae9d53d88d
parent      87e151281d853afdb281e2249620cf839bb932d1 (diff)
parent      6fa4c62dcca954b7f8ae26e7f7314e235dd6a3b4 (diff)
Merge pull request #6123 from IlyaOvodov/master
"weights" added to solver parameters, "snapshot_prefix" field default initialization
-rw-r--r--  src/caffe/proto/caffe.proto            18
-rw-r--r--  src/caffe/solver.cpp                   21
-rw-r--r--  src/caffe/test/test_upgrade_proto.cpp   4
-rw-r--r--  src/caffe/util/upgrade_proto.cpp       21
-rw-r--r--  tools/caffe.cpp                        23
5 files changed, 69 insertions, 18 deletions
diff --git a/src/caffe/proto/caffe.proto b/src/caffe/proto/caffe.proto
index c96966b5..22764abc 100644
--- a/src/caffe/proto/caffe.proto
+++ b/src/caffe/proto/caffe.proto
@@ -98,7 +98,7 @@ message NetParameter {
// NOTE
// Update the next available ID when you add a new SolverParameter field.
//
-// SolverParameter next available ID: 42 (last added: layer_wise_reduce)
+// SolverParameter next available ID: 43 (last added: weights)
message SolverParameter {
//////////////////////////////////////////////////////////////////////////////
// Specifying the train and test networks
@@ -186,7 +186,11 @@ message SolverParameter {
optional float clip_gradients = 35 [default = -1];
optional int32 snapshot = 14 [default = 0]; // The snapshot interval
- optional string snapshot_prefix = 15; // The prefix for the snapshot.
+ // The prefix for the snapshot.
+ // If not set, it is replaced by the prototxt file path without the extension.
+ // If set to a directory, it is augmented with the prototxt file name
+ // without the extension.
+ optional string snapshot_prefix = 15;
// whether to snapshot diff in the results or not. Snapshotting diff will help
// debugging but the final protocol buffer size will be much larger.
optional bool snapshot_diff = 16 [default = false];
@@ -241,6 +245,16 @@ message SolverParameter {
// Overlap compute and communication for data parallel training
optional bool layer_wise_reduce = 41 [default = true];
+
+ // Path to caffemodel file(s) with pretrained weights to initialize finetuning.
+ // The same as the --weights parameter of the "caffe train" command line.
+ // If the command-line --weights parameter is specified, it takes priority
+ // and overrides the value(s) given here.
+ // If the --snapshot command-line parameter is specified, these are ignored.
+ // Several model files can be listed either in a single weights parameter,
+ // separated by ',' (as on the command line), or in separate repeated
+ // weights parameters.
+ repeated string weights = 42;
}
// A message that stores the solver snapshots
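For illustration, a minimal sketch (the .caffemodel names below are made up and not part of this change) of parsing a solver definition that uses the new repeated weights field while leaving snapshot_prefix unset for the upgrade step described further below:

// Minimal sketch, assuming Caffe is built with the updated caffe.proto.
#include <string>

#include <google/protobuf/text_format.h>

#include "caffe/proto/caffe.pb.h"

int main() {
  const std::string solver_text =
      "net: 'examples/mnist/lenet_train_test.prototxt' "
      "weights: 'pretrained_a.caffemodel' "   // first repeated entry (made up)
      "weights: 'pretrained_b.caffemodel' ";  // second repeated entry (made up)
  caffe::SolverParameter param;
  if (!google::protobuf::TextFormat::ParseFromString(solver_text, &param)) {
    return 1;
  }
  // Two pretrained models are listed; snapshot_prefix stays unset, so the
  // upgrade step in upgrade_proto.cpp would derive it from the prototxt path.
  return (param.weights_size() == 2 && !param.has_snapshot_prefix()) ? 0 : 1;
}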
diff --git a/src/caffe/solver.cpp b/src/caffe/solver.cpp
index 04426937..d229acff 100644
--- a/src/caffe/solver.cpp
+++ b/src/caffe/solver.cpp
@@ -3,6 +3,7 @@
#include <string>
#include <vector>
+#include "boost/algorithm/string.hpp"
#include "caffe/solver.hpp"
#include "caffe/util/format.hpp"
#include "caffe/util/hdf5.hpp"
@@ -59,6 +60,20 @@ void Solver<Dtype>::Init(const SolverParameter& param) {
current_step_ = 0;
}
+// Load weights from the caffemodel(s) specified in the "weights" solver
+// parameter into the train and test nets.
+template <typename Dtype>
+void LoadNetWeights(shared_ptr<Net<Dtype> > net,
+ const std::string& model_list) {
+ std::vector<std::string> model_names;
+ boost::split(model_names, model_list, boost::is_any_of(","));
+ for (int i = 0; i < model_names.size(); ++i) {
+ boost::trim(model_names[i]);
+ LOG(INFO) << "Finetuning from " << model_names[i];
+ net->CopyTrainedLayersFrom(model_names[i]);
+ }
+}
+
template <typename Dtype>
void Solver<Dtype>::InitTrainNet() {
const int num_train_nets = param_.has_net() + param_.has_net_param() +
@@ -98,6 +113,9 @@ void Solver<Dtype>::InitTrainNet() {
net_state.MergeFrom(param_.train_state());
net_param.mutable_state()->CopyFrom(net_state);
net_.reset(new Net<Dtype>(net_param));
+ for (int w_idx = 0; w_idx < param_.weights_size(); ++w_idx) {
+ LoadNetWeights(net_, param_.weights(w_idx));
+ }
}
template <typename Dtype>
@@ -173,6 +191,9 @@ void Solver<Dtype>::InitTestNets() {
<< "Creating test net (#" << i << ") specified by " << sources[i];
test_nets_[i].reset(new Net<Dtype>(net_params[i]));
test_nets_[i]->set_debug_info(param_.debug_info());
+ for (int w_idx = 0; w_idx < param_.weights_size(); ++w_idx) {
+ LoadNetWeights(test_nets_[i], param_.weights(w_idx));
+ }
}
}
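The new LoadNetWeights helper accepts several caffemodel paths inside a single weights entry, separated by commas. A self-contained sketch of the boost string handling it relies on, with made-up file names and no actual network loading:

#include <iostream>
#include <string>
#include <vector>

#include <boost/algorithm/string.hpp>

int main() {
  // One "weights" entry may hold several comma-separated caffemodel paths.
  std::string model_list = "a.caffemodel, b.caffemodel ,c.caffemodel";
  std::vector<std::string> model_names;
  boost::split(model_names, model_list, boost::is_any_of(","));
  for (size_t i = 0; i < model_names.size(); ++i) {
    boost::trim(model_names[i]);           // drop stray spaces around each name
    std::cout << model_names[i] << "\n";   // a.caffemodel / b.caffemodel / c.caffemodel
  }
  return 0;
}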
diff --git a/src/caffe/test/test_upgrade_proto.cpp b/src/caffe/test/test_upgrade_proto.cpp
index 9dcc2aa5..769112eb 100644
--- a/src/caffe/test/test_upgrade_proto.cpp
+++ b/src/caffe/test/test_upgrade_proto.cpp
@@ -2952,6 +2952,8 @@ TEST_F(SolverTypeUpgradeTest, TestSimple) {
for (int i = 0; i < 6; ++i) {
const string& input_proto =
"net: 'examples/mnist/lenet_train_test.prototxt' "
+ "weights: 'examples/mnist/lenet_train_test1.caffemodel' "
+ "weights: 'examples/mnist/lenet_train_test2.caffemodel' "
"test_iter: 100 "
"test_interval: 500 "
"base_lr: 0.01 "
@@ -2968,6 +2970,8 @@ TEST_F(SolverTypeUpgradeTest, TestSimple) {
"solver_type: " + std::string(old_type_vec[i]) + " ";
const string& expected_output_proto =
"net: 'examples/mnist/lenet_train_test.prototxt' "
+ "weights: 'examples/mnist/lenet_train_test1.caffemodel' "
+ "weights: 'examples/mnist/lenet_train_test2.caffemodel' "
"test_iter: 100 "
"test_interval: 500 "
"base_lr: 0.01 "
diff --git a/src/caffe/util/upgrade_proto.cpp b/src/caffe/util/upgrade_proto.cpp
index 94771c8c..ad40b73d 100644
--- a/src/caffe/util/upgrade_proto.cpp
+++ b/src/caffe/util/upgrade_proto.cpp
@@ -2,6 +2,8 @@
#include <google/protobuf/io/zero_copy_stream_impl.h>
#include <google/protobuf/text_format.h>
+#include <boost/filesystem.hpp>
+
#include <map>
#include <string>
@@ -1095,12 +1097,31 @@ bool UpgradeSolverAsNeeded(const string& param_file, SolverParameter* param) {
return success;
}
+// Replaces the snapshot_prefix of a SolverParameter if it is not specified
+// or is set to a directory.
+void UpgradeSnapshotPrefixProperty(const string& param_file,
+ SolverParameter* param) {
+ using boost::filesystem::path;
+ using boost::filesystem::is_directory;
+ if (!param->has_snapshot_prefix()) {
+ param->set_snapshot_prefix(path(param_file).replace_extension().string());
+ LOG(INFO) << "snapshot_prefix was not specified and is set to "
+ + param->snapshot_prefix();
+ } else if (is_directory(param->snapshot_prefix())) {
+ param->set_snapshot_prefix((path(param->snapshot_prefix()) /
+ path(param_file).stem()).string());
+ LOG(INFO) << "snapshot_prefix was a directory and is replaced to "
+ + param->snapshot_prefix();
+ }
+}
+
// Read parameters from a file into a SolverParameter proto message.
void ReadSolverParamsFromTextFileOrDie(const string& param_file,
SolverParameter* param) {
CHECK(ReadProtoFromTextFile(param_file, param))
<< "Failed to parse SolverParameter file: " << param_file;
UpgradeSolverAsNeeded(param_file, param);
+ UpgradeSnapshotPrefixProperty(param_file, param);
}
} // namespace caffe
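UpgradeSnapshotPrefixProperty rewrites the prefix in exactly two cases: when it is missing and when it names a directory. A sketch of the underlying boost::filesystem calls with made-up paths (the real code additionally checks is_directory(), which requires the directory to exist):

#include <iostream>

#include <boost/filesystem.hpp>

int main() {
  using boost::filesystem::path;
  const path param_file("examples/mnist/lenet_solver.prototxt");

  // Case 1: snapshot_prefix unset -> solver file path with its extension dropped.
  std::cout << path(param_file).replace_extension().string() << "\n";
  // prints: examples/mnist/lenet_solver

  // Case 2: snapshot_prefix names a directory -> append the solver file stem.
  const path prefix_dir("snapshots");
  std::cout << (prefix_dir / param_file.stem()).string() << "\n";
  // prints: snapshots/lenet_solver
  return 0;
}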
diff --git a/tools/caffe.cpp b/tools/caffe.cpp
index 3587d8aa..389cfb8a 100644
--- a/tools/caffe.cpp
+++ b/tools/caffe.cpp
@@ -146,20 +146,6 @@ int device_query() {
}
RegisterBrewFunction(device_query);
-// Load the weights from the specified caffemodel(s) into the train and
-// test nets.
-void CopyLayers(caffe::Solver<float>* solver, const std::string& model_list) {
- std::vector<std::string> model_names;
- boost::split(model_names, model_list, boost::is_any_of(",") );
- for (int i = 0; i < model_names.size(); ++i) {
- LOG(INFO) << "Finetuning from " << model_names[i];
- solver->net()->CopyTrainedLayersFrom(model_names[i]);
- for (int j = 0; j < solver->test_nets().size(); ++j) {
- solver->test_nets()[j]->CopyTrainedLayersFrom(model_names[i]);
- }
- }
-}
-
// Translate the signal effect the user specified on the command-line to the
// corresponding enumeration.
caffe::SolverAction::Enum GetRequestedAction(
@@ -233,6 +219,13 @@ int train() {
GetRequestedAction(FLAGS_sigint_effect),
GetRequestedAction(FLAGS_sighup_effect));
+ if (FLAGS_snapshot.size()) {
+ solver_param.clear_weights();
+ } else if (FLAGS_weights.size()) {
+ solver_param.clear_weights();
+ solver_param.add_weights(FLAGS_weights);
+ }
+
shared_ptr<caffe::Solver<float> >
solver(caffe::SolverRegistry<float>::CreateSolver(solver_param));
@@ -241,8 +234,6 @@ int train() {
if (FLAGS_snapshot.size()) {
LOG(INFO) << "Resuming from " << FLAGS_snapshot;
solver->Restore(FLAGS_snapshot.c_str());
- } else if (FLAGS_weights.size()) {
- CopyLayers(solver.get(), FLAGS_weights);
}
LOG(INFO) << "Starting Optimization";