From ff19d5f5c010dd8d6bfcf768b4fe27d0458f17df Mon Sep 17 00:00:00 2001 From: J Yegerlehner Date: Fri, 3 Apr 2015 16:11:23 -0500 Subject: Add signal handler and early exit/snapshot to Solver. Add signal handler and early exit/snapshot to Solver. Add signal handler and early exit/snapshot to Solver. Also check for exit and snapshot when testing. Skip running test after early exit. Fix more lint. Rebase on master. Finish rebase on master. Fixups per review comments. Redress review comments. Lint. Correct error message wording. --- src/caffe/solver.cpp | 70 ++++++++++++++++++++--- src/caffe/util/signal_handler.cpp | 115 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 177 insertions(+), 8 deletions(-) create mode 100644 src/caffe/util/signal_handler.cpp (limited to 'src') diff --git a/src/caffe/solver.cpp b/src/caffe/solver.cpp index 9348e11c..394ec3b3 100644 --- a/src/caffe/solver.cpp +++ b/src/caffe/solver.cpp @@ -17,15 +17,31 @@ namespace caffe { +template +void Solver::SetActionFunction(ActionCallback func) { + action_request_function_ = func; +} + +template +SolverAction::Enum Solver::GetRequestedAction() { + if (action_request_function_) { + // If the external request function has been set, call it. 
+ return action_request_function_(); + } + return SolverAction::NONE; +} + template Solver::Solver(const SolverParameter& param, const Solver* root_solver) - : net_(), callbacks_(), root_solver_(root_solver) { + : net_(), callbacks_(), root_solver_(root_solver), + requested_early_exit_(false) { Init(param); } template Solver::Solver(const string& param_file, const Solver* root_solver) - : net_(), callbacks_(), root_solver_(root_solver) { + : net_(), callbacks_(), root_solver_(root_solver), + requested_early_exit_(false) { SolverParameter param; ReadProtoFromTextFileOrDie(param_file, ¶m); Init(param); @@ -195,6 +211,10 @@ void Solver::Step(int iters) { && (iter_ > 0 || param_.test_initialization()) && Caffe::root_solver()) { TestAll(); + if (requested_early_exit_) { + // Break out of the while loop because stop was requested while testing. + break; + } } for (int i = 0; i < callbacks_.size(); ++i) { @@ -250,12 +270,20 @@ void Solver::Step(int iters) { // the number of times the weights have been updated. ++iter_; + SolverAction::Enum request = GetRequestedAction(); + // Save a snapshot if needed. - if (param_.snapshot() - && iter_ % param_.snapshot() == 0 - && Caffe::root_solver()) { + if ((param_.snapshot() + && iter_ % param_.snapshot() == 0 + && Caffe::root_solver()) || + (request == SolverAction::SNAPSHOT)) { Snapshot(); } + if (SolverAction::STOP == request) { + requested_early_exit_ = true; + // Break out of training loop. + break; + } } } @@ -265,6 +293,9 @@ void Solver::Solve(const char* resume_file) { LOG(INFO) << "Solving " << net_->name(); LOG(INFO) << "Learning Rate Policy: " << param_.lr_policy(); + // Initialize to false every time we start solving. 
+ requested_early_exit_ = false; + if (resume_file) { LOG(INFO) << "Restoring previous solver status from " << resume_file; Restore(resume_file); @@ -279,6 +310,10 @@ void Solver::Solve(const char* resume_file) { && (!param_.snapshot() || iter_ % param_.snapshot() != 0)) { Snapshot(); } + if (requested_early_exit_) { + LOG(INFO) << "Optimization stopped early."; + return; + } // After the optimization is done, run an additional train and test pass to // display the train and test loss/outputs if appropriate (based on the // display and test_interval settings, respectively). Unlike in the rest of @@ -296,10 +331,11 @@ void Solver::Solve(const char* resume_file) { LOG(INFO) << "Optimization Done."; } - template void Solver::TestAll() { - for (int test_net_id = 0; test_net_id < test_nets_.size(); ++test_net_id) { + for (int test_net_id = 0; + test_net_id < test_nets_.size() && !requested_early_exit_; + ++test_net_id) { Test(test_net_id); } } @@ -317,6 +353,21 @@ void Solver::Test(const int test_net_id) { const shared_ptr >& test_net = test_nets_[test_net_id]; Dtype loss = 0; for (int i = 0; i < param_.test_iter(test_net_id); ++i) { + SolverAction::Enum request = GetRequestedAction(); + // Check to see if stoppage of testing/training has been requested. + while (request != SolverAction::NONE) { + if (SolverAction::SNAPSHOT == request) { + Snapshot(); + } else if (SolverAction::STOP == request) { + requested_early_exit_ = true; + } + request = GetRequestedAction(); + } + if (requested_early_exit_) { + // break out of test loop. 
+ break; + } + Dtype iter_loss; const vector*>& result = test_net->Forward(bottom_vec, &iter_loss); @@ -341,6 +392,10 @@ void Solver::Test(const int test_net_id) { } } } + if (requested_early_exit_) { + LOG(INFO) << "Test interrupted."; + return; + } if (param_.test_compute_loss()) { loss /= param_.test_iter(test_net_id); LOG(INFO) << "Test loss: " << loss; @@ -361,7 +416,6 @@ void Solver::Test(const int test_net_id) { } } - template void Solver::Snapshot() { CHECK(Caffe::root_solver()); diff --git a/src/caffe/util/signal_handler.cpp b/src/caffe/util/signal_handler.cpp new file mode 100644 index 00000000..5d764ec5 --- /dev/null +++ b/src/caffe/util/signal_handler.cpp @@ -0,0 +1,115 @@ +#include +#include + +#include +#include + +#include "caffe/util/signal_handler.h" + +namespace { + static volatile sig_atomic_t got_sigint = false; + static volatile sig_atomic_t got_sighup = false; + static bool already_hooked_up = false; + + void handle_signal(int signal) { + switch (signal) { + case SIGHUP: + got_sighup = true; + break; + case SIGINT: + got_sigint = true; + break; + } + } + + void HookupHandler() { + if (already_hooked_up) { + LOG(FATAL) << "Tried to hookup signal handlers more than once."; + } + already_hooked_up = true; + + struct sigaction sa; + // Setup the handler + sa.sa_handler = &handle_signal; + // Restart the system call, if at all possible + sa.sa_flags = SA_RESTART; + // Block every signal during the handler + sigfillset(&sa.sa_mask); + // Intercept SIGHUP and SIGINT + if (sigaction(SIGHUP, &sa, NULL) == -1) { + LOG(FATAL) << "Cannot install SIGHUP handler."; + } + if (sigaction(SIGINT, &sa, NULL) == -1) { + LOG(FATAL) << "Cannot install SIGINT handler."; + } + } + + // Set the signal handlers to the default. 
+ void UnhookHandler() { + if (already_hooked_up) { + struct sigaction sa; + // Restore the default signal disposition + sa.sa_handler = SIG_DFL; + // Restart the system call, if at all possible + sa.sa_flags = SA_RESTART; + // Block every signal during the handler + sigfillset(&sa.sa_mask); + // Stop intercepting SIGHUP and SIGINT + if (sigaction(SIGHUP, &sa, NULL) == -1) { + LOG(FATAL) << "Cannot uninstall SIGHUP handler."; + } + if (sigaction(SIGINT, &sa, NULL) == -1) { + LOG(FATAL) << "Cannot uninstall SIGINT handler."; + } + + already_hooked_up = false; + } + } + + // Return true iff a SIGINT has been received since the last time this + // function was called. + bool GotSIGINT() { + bool result = got_sigint; + got_sigint = false; + return result; + } + + // Return true iff a SIGHUP has been received since the last time this + // function was called. + bool GotSIGHUP() { + bool result = got_sighup; + got_sighup = false; + return result; + } +} // namespace + +namespace caffe { + +SignalHandler::SignalHandler(SolverAction::Enum SIGINT_action, + SolverAction::Enum SIGHUP_action): + SIGINT_action_(SIGINT_action), + SIGHUP_action_(SIGHUP_action) { + HookupHandler(); +} + +SignalHandler::~SignalHandler() { + UnhookHandler(); +} + +SolverAction::Enum SignalHandler::CheckForSignals() const { + if (GotSIGHUP()) { + return SIGHUP_action_; + } + if (GotSIGINT()) { + return SIGINT_action_; + } + return SolverAction::NONE; +} + +// Return the function that the solver can use to find out if a snapshot or +// early exit is being requested. +ActionCallback SignalHandler::GetActionFunction() { + return boost::bind(&SignalHandler::CheckForSignals, this); +} + +} // namespace caffe -- cgit v1.2.3