summaryrefslogtreecommitdiff
path: root/tools/train_net.cpp
diff options
context:
space:
mode:
authorEvan Shelhamer <shelhamer@imaginarynumber.net>2014-02-24 22:45:29 -0800
committerEvan Shelhamer <shelhamer@imaginarynumber.net>2014-02-26 12:37:44 -0800
commit4f1cdeb4ef35b9f85646a9b624188fd761644f1d (patch)
tree0727acbf45b45c7f3e05120868e62cb3a2faf0c5 /tools/train_net.cpp
parent650b7b16dbbf0ee7b3a7362fce0821315d608d54 (diff)
downloadcaffeonacl-4f1cdeb4ef35b9f85646a9b624188fd761644f1d.tar.gz
caffeonacl-4f1cdeb4ef35b9f85646a9b624188fd761644f1d.tar.bz2
caffeonacl-4f1cdeb4ef35b9f85646a9b624188fd761644f1d.zip
Make tools/ for core binaries, stow scripts/ in tools/extra
Collect core Caffe tools like train_net, device_query, etc. together in tools/ and include helper scripts under tools/extra.
Diffstat (limited to 'tools/train_net.cpp')
-rw-r--r--tools/train_net.cpp37
1 files changed, 37 insertions, 0 deletions
diff --git a/tools/train_net.cpp b/tools/train_net.cpp
new file mode 100644
index 00000000..ce62616b
--- /dev/null
+++ b/tools/train_net.cpp
@@ -0,0 +1,37 @@
+// Copyright 2013 Yangqing Jia
+//
+// This is a simple script that allows one to quickly train a network whose
+// parameters are specified by text format protocol buffers.
+// Usage:
+// train_net net_proto_file solver_proto_file [resume_point_file]
+
+#include <cuda_runtime.h>
+
+#include <cstring>
+
+#include "caffe/caffe.hpp"
+
+using namespace caffe;
+
+int main(int argc, char** argv) {
+ ::google::InitGoogleLogging(argv[0]);
+ if (argc < 2) {
+ LOG(ERROR) << "Usage: train_net solver_proto_file [resume_point_file]";
+ return 0;
+ }
+
+ SolverParameter solver_param;
+ ReadProtoFromTextFile(argv[1], &solver_param);
+
+ LOG(INFO) << "Starting Optimization";
+ SGDSolver<float> solver(solver_param);
+ if (argc == 3) {
+ LOG(INFO) << "Resuming from " << argv[2];
+ solver.Solve(argv[2]);
+ } else {
+ solver.Solve();
+ }
+ LOG(INFO) << "Optimization Done.";
+
+ return 0;
+}