diff options
author | qipeng <pengrobertqi@163.com> | 2014-07-20 08:57:33 -0700 |
---|---|---|
committer | Jeff Donahue <jeff.donahue@gmail.com> | 2014-09-01 11:33:41 -0700 |
commit | ed8b1da57fbbadb611d98372671fafd77d863234 (patch) | |
tree | 9f4fbf84c478493947fc5859ea3212e91a2dcf3b | |
parent | 8a9c268bd53767365fa0760c167bdcd0158a56f3 (diff) | |
download | caffeonacl-ed8b1da57fbbadb611d98372671fafd77d863234.tar.gz caffeonacl-ed8b1da57fbbadb611d98372671fafd77d863234.tar.bz2 caffeonacl-ed8b1da57fbbadb611d98372671fafd77d863234.zip |
converted pointers to shared_ptr
-rw-r--r-- | include/caffe/solver.hpp | 3 | ||||
-rw-r--r-- | tools/train_net.cpp | 30 |
2 files changed, 31 insertions, 2 deletions
diff --git a/include/caffe/solver.hpp b/include/caffe/solver.hpp index 03c65580..9d5481cc 100644 --- a/include/caffe/solver.hpp +++ b/include/caffe/solver.hpp @@ -120,8 +120,11 @@ Solver<Dtype>* GetSolver(const SolverParameter& param) { default: LOG(FATAL) << "Unknown SolverType: " << type; } + return (Solver<Dtype>*) NULL; } +template Solver<float>* GetSolver(const SolverParameter& param); +template Solver<double>* GetSolver(const SolverParameter& param); } // namespace caffe diff --git a/tools/train_net.cpp b/tools/train_net.cpp index 11767591..2a2a522d 100644 --- a/tools/train_net.cpp +++ b/tools/train_net.cpp @@ -1,10 +1,36 @@ +// Copyright 2014 BVLC and contributors. +// +// This is a simple script that allows one to quickly train a network whose +// parameters are specified by text format protocol buffers. +// Usage: +// train_net net_proto_file solver_proto_file [resume_point_file] + +#include <cstring> + #include "caffe/caffe.hpp" using namespace caffe; // NOLINT(build/namespaces) int main(int argc, char** argv) { + ::google::InitGoogleLogging(argv[0]); + if (argc < 2 || argc > 3) { + LOG(ERROR) << "Usage: train_net solver_proto_file [resume_point_file]"; + return 1; + } + + SolverParameter solver_param; + ReadProtoFromTextFileOrDie(argv[1], &solver_param); + + LOG(INFO) << "Starting Optimization"; + shared_ptr<Solver<float> > solver = + (shared_ptr<Solver<float> >) GetSolver<float>(solver_param); + if (argc == 3) { + LOG(INFO) << "Resuming from " << argv[2]; + solver->Solve(argv[2]); + } else { + solver->Solve(); + } + LOG(INFO) << "Optimization Done."; - LOG(FATAL) << "Deprecated. Use caffe train --solver=... " - "[--snapshot=...] instead."; return 0; } |