summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--include/caffe/solver.hpp3
-rw-r--r--tools/train_net.cpp30
2 files changed, 31 insertions, 2 deletions
diff --git a/include/caffe/solver.hpp b/include/caffe/solver.hpp
index 03c65580..9d5481cc 100644
--- a/include/caffe/solver.hpp
+++ b/include/caffe/solver.hpp
@@ -120,8 +120,11 @@ Solver<Dtype>* GetSolver(const SolverParameter& param) {
default:
LOG(FATAL) << "Unknown SolverType: " << type;
}
+ return (Solver<Dtype>*) NULL;
}
+template Solver<float>* GetSolver(const SolverParameter& param);
+template Solver<double>* GetSolver(const SolverParameter& param);
} // namespace caffe
diff --git a/tools/train_net.cpp b/tools/train_net.cpp
index 11767591..2a2a522d 100644
--- a/tools/train_net.cpp
+++ b/tools/train_net.cpp
@@ -1,10 +1,36 @@
+// Copyright 2014 BVLC and contributors.
+//
+// This is a simple script that allows one to quickly train a network whose
+// parameters are specified by text format protocol buffers.
+// Usage:
+// train_net net_proto_file solver_proto_file [resume_point_file]
+
+#include <cstring>
+
#include "caffe/caffe.hpp"
using namespace caffe; // NOLINT(build/namespaces)
int main(int argc, char** argv) {
+ ::google::InitGoogleLogging(argv[0]);
+ if (argc < 2 || argc > 3) {
+ LOG(ERROR) << "Usage: train_net solver_proto_file [resume_point_file]";
+ return 1;
+ }
+
+ SolverParameter solver_param;
+ ReadProtoFromTextFileOrDie(argv[1], &solver_param);
+
+ LOG(INFO) << "Starting Optimization";
+ shared_ptr<Solver<float> > solver =
+ (shared_ptr<Solver<float> >) GetSolver<float>(solver_param);
+ if (argc == 3) {
+ LOG(INFO) << "Resuming from " << argv[2];
+ solver->Solve(argv[2]);
+ } else {
+ solver->Solve();
+ }
+ LOG(INFO) << "Optimization Done.";
- LOG(FATAL) << "Deprecated. Use caffe train --solver=... "
- "[--snapshot=...] instead.";
return 0;
}