summaryrefslogtreecommitdiff
path: root/tools/caffe.cpp
diff options
context:
space:
mode:
authorEvan Shelhamer <shelhamer@imaginarynumber.net>2014-08-06 23:22:13 -0700
committerEvan Shelhamer <shelhamer@imaginarynumber.net>2014-08-06 23:44:55 -0700
commit8634885d83dabfb38e8ca5f0d0d037335127cfba (patch)
treee9048e29573b7e82c79334e7a56deb1e06418c9d /tools/caffe.cpp
parent3655aa93952c470b2fc04ca622b03f7a956be9cb (diff)
downloadcaffeonacl-8634885d83dabfb38e8ca5f0d0d037335127cfba.tar.gz
caffeonacl-8634885d83dabfb38e8ca5f0d0d037335127cfba.tar.bz2
caffeonacl-8634885d83dabfb38e8ca5f0d0d037335127cfba.zip
rename caffe cli args and revise text
Diffstat (limited to 'tools/caffe.cpp')
-rw-r--r--tools/caffe.cpp52
1 files changed, 25 insertions, 27 deletions
diff --git a/tools/caffe.cpp b/tools/caffe.cpp
index 3bf18650..1320b7a5 100644
--- a/tools/caffe.cpp
+++ b/tools/caffe.cpp
@@ -17,23 +17,21 @@ using caffe::Timer;
using caffe::vector;
-// Used in device query
DEFINE_int32(device_id, 0,
- "[device_query,time] The device id to use.");
-// Used in training
-DEFINE_string(solver_proto_file, "",
- "[train] The protobuf containing the solver definition.");
-DEFINE_string(net_proto_file, "",
- "[time] The net proto file to use.");
-DEFINE_string(resume_point_file, "",
- "[train] (optional) The snapshot from which to resume training.");
-DEFINE_string(pretrained_net_file, "",
- "[train] (optional) A pretrained network to run finetune from. "
- "Cannot be set simultaneously with resume_point_file.");
-DEFINE_int32(run_iterations, 50,
- "[time] The number of iterations to run.");
+ "The GPU device ID to use.");
DEFINE_bool(gpu, false,
- "[time] Run in GPU mode when true.");
+ "Run in GPU mode when true.");
+DEFINE_string(solver, "",
+ "The solver definition protocol buffer text file.");
+DEFINE_string(model, "",
+ "The model definition protocol buffer text file..");
+DEFINE_string(snapshot, "",
+ "Optional; the snapshot solver state to resume training.");
+DEFINE_string(weights, "",
+ "Optional; the pretrained weights to initialize finetuning. "
+ "Cannot be set simultaneously with snapshot.");
+DEFINE_int32(iterations, 50,
+ "The number of iterations to run.");
// A simple registry for caffe commands.
typedef int (*BrewFunction)();
@@ -79,19 +77,19 @@ int device_query() {
RegisterBrewFunction(device_query);
int train() {
- CHECK_GT(FLAGS_solver_proto_file.size(), 0);
+ CHECK_GT(FLAGS_solver.size(), 0);
caffe::SolverParameter solver_param;
- caffe::ReadProtoFromTextFileOrDie(FLAGS_solver_proto_file, &solver_param);
+ caffe::ReadProtoFromTextFileOrDie(FLAGS_solver, &solver_param);
LOG(INFO) << "Starting Optimization";
caffe::SGDSolver<float> solver(solver_param);
- if (FLAGS_resume_point_file.size()) {
- LOG(INFO) << "Resuming from " << FLAGS_resume_point_file;
- solver.Solve(FLAGS_resume_point_file);
- } else if (FLAGS_pretrained_net_file.size()) {
- LOG(INFO) << "Finetuning from " << FLAGS_pretrained_net_file;
- solver.net()->CopyTrainedLayersFrom(FLAGS_pretrained_net_file);
+ if (FLAGS_snapshot.size()) {
+ LOG(INFO) << "Resuming from " << FLAGS_snapshot;
+ solver.Solve(FLAGS_snapshot);
+ } else if (FLAGS_weights.size()) {
+ LOG(INFO) << "Finetuning from " << FLAGS_weights;
+ solver.net()->CopyTrainedLayersFrom(FLAGS_weights);
solver.Solve();
} else {
solver.Solve();
@@ -113,7 +111,7 @@ int time() {
}
// Instantiate the caffe net.
Caffe::set_phase(Caffe::TRAIN);
- Net<float> caffe_net(FLAGS_net_proto_file);
+ Net<float> caffe_net(FLAGS_model);
// Do a clean forward and backward pass, so that memory allocation are done
// and future iterations will be more stable.
@@ -132,7 +130,7 @@ int time() {
const vector<vector<bool> >& bottom_need_backward =
caffe_net.bottom_need_backward();
LOG(INFO) << "*** Benchmark begins ***";
- LOG(INFO) << "Testing for " << FLAGS_run_iterations << " iterations.";
+ LOG(INFO) << "Testing for " << FLAGS_iterations << " iterations.";
Timer total_timer;
total_timer.Start();
Timer forward_timer;
@@ -141,7 +139,7 @@ int time() {
for (int i = 0; i < layers.size(); ++i) {
const caffe::string& layername = layers[i]->layer_param().name();
timer.Start();
- for (int j = 0; j < FLAGS_run_iterations; ++j) {
+ for (int j = 0; j < FLAGS_iterations; ++j) {
layers[i]->Forward(bottom_vecs[i], &top_vecs[i]);
}
LOG(INFO) << layername << "\tforward: " << timer.MilliSeconds() <<
@@ -154,7 +152,7 @@ int time() {
for (int i = layers.size() - 1; i >= 0; --i) {
const caffe::string& layername = layers[i]->layer_param().name();
timer.Start();
- for (int j = 0; j < FLAGS_run_iterations; ++j) {
+ for (int j = 0; j < FLAGS_iterations; ++j) {
layers[i]->Backward(top_vecs[i], bottom_need_backward[i],
&bottom_vecs[i]);
}