summaryrefslogtreecommitdiff
path: root/tools
diff options
context:
space:
mode:
authorJ Yegerlehner <jyegerlehner@yahoo.com>2014-11-19 17:30:46 -0600
committerEvan Shelhamer <shelhamer@imaginarynumber.net>2015-03-07 18:37:33 -0800
commita0087e49928005d041f1694460b4a43ff30722f8 (patch)
tree326ed4dec6aa4161c2e37efa8b6c5cd9c0b0df55 /tools
parentc3aee354d8831e6fc7e7633a41c5e403ef0bb6b8 (diff)
downloadcaffeonacl-a0087e49928005d041f1694460b4a43ff30722f8.tar.gz
caffeonacl-a0087e49928005d041f1694460b4a43ff30722f8.tar.bz2
caffeonacl-a0087e49928005d041f1694460b4a43ff30722f8.zip
Load weights from multiple caffemodels.
Diffstat (limited to 'tools')
-rw-r--r--tools/caffe.cpp17
1 files changed, 15 insertions, 2 deletions
diff --git a/tools/caffe.cpp b/tools/caffe.cpp
index f04e28a3..eb9e97f5 100644
--- a/tools/caffe.cpp
+++ b/tools/caffe.cpp
@@ -5,6 +5,7 @@
#include <string>
#include <vector>
+#include "boost/algorithm/string.hpp"
#include "caffe/caffe.hpp"
using caffe::Blob;
@@ -76,6 +77,19 @@ int device_query() {
}
RegisterBrewFunction(device_query);
+// Load the weights from the specified caffemodel(s) into the train and
+// test nets.
+void CopyLayers(caffe::Solver<float>* solver, const std::string& model_list) {
+ std::vector<std::string> model_names;
+ boost::split(model_names, model_list, boost::is_any_of(",") );
+ for (int i = 0; i < model_names.size(); ++i) {
+ LOG(INFO) << "Finetuning from " << model_names[i];
+ solver->net()->CopyTrainedLayersFrom(model_names[i]);
+ for (int j = 0; j < solver->test_nets().size(); ++j) {
+ solver->test_nets()[j]->CopyTrainedLayersFrom(model_names[i]);
+ }
+ }
+}
// Train / Finetune a model.
int train() {
@@ -112,8 +126,7 @@ int train() {
LOG(INFO) << "Resuming from " << FLAGS_snapshot;
solver->Solve(FLAGS_snapshot);
} else if (FLAGS_weights.size()) {
- LOG(INFO) << "Finetuning from " << FLAGS_weights;
- solver->net()->CopyTrainedLayersFrom(FLAGS_weights);
+ CopyLayers(&*solver, FLAGS_weights);
solver->Solve();
} else {
solver->Solve();