summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorJ Yegerlehner <jyegerlehner@yahoo.com>2015-04-03 16:11:23 -0500
committerJ Yegerlehner <jyegerlehner@yahoo.com>2015-08-22 12:51:55 -0500
commitff19d5f5c010dd8d6bfcf768b4fe27d0458f17df (patch)
tree721e4a100606b44441db4e563b027ed590c3a146 /src
parentc6b9f580540f5aa16d05d8e283f9e0050dda2fb5 (diff)
downloadcaffeonacl-ff19d5f5c010dd8d6bfcf768b4fe27d0458f17df.tar.gz
caffeonacl-ff19d5f5c010dd8d6bfcf768b4fe27d0458f17df.tar.bz2
caffeonacl-ff19d5f5c010dd8d6bfcf768b4fe27d0458f17df.zip
Add signal handler and early exit/snapshot to Solver.
Add signal handler and early exit/snapshot to Solver. Add signal handler and early exit/snapshot to Solver. Also check for exit and snapshot when testing. Skip running test after early exit. Fix more lint. Rebase on master. Finish rebase on master. Fixups per review comments. Redress review comments. Lint. Correct error message wording.
Diffstat (limited to 'src')
-rw-r--r--src/caffe/solver.cpp70
-rw-r--r--src/caffe/util/signal_handler.cpp115
2 files changed, 177 insertions, 8 deletions
diff --git a/src/caffe/solver.cpp b/src/caffe/solver.cpp
index 9348e11c..394ec3b3 100644
--- a/src/caffe/solver.cpp
+++ b/src/caffe/solver.cpp
@@ -17,15 +17,31 @@
namespace caffe {
+// Stores the callback that GetRequestedAction() invokes to find out whether
+// an action has been requested from outside the solver.
+template <typename Dtype>
+void Solver<Dtype>::SetActionFunction(ActionCallback func) {
+  this->action_request_function_ = func;
+}
+
+// Returns the action reported by the external request callback, or
+// SolverAction::NONE when no callback has been set.
+template <typename Dtype>
+SolverAction::Enum Solver<Dtype>::GetRequestedAction() {
+  if (!action_request_function_) {
+    // No request function was registered, so nothing can be requested.
+    return SolverAction::NONE;
+  }
+  return action_request_function_();
+}
+
template <typename Dtype>
Solver<Dtype>::Solver(const SolverParameter& param, const Solver* root_solver)
- : net_(), callbacks_(), root_solver_(root_solver) {
+ : net_(), callbacks_(), root_solver_(root_solver),
+ requested_early_exit_(false) {
Init(param);
}
template <typename Dtype>
Solver<Dtype>::Solver(const string& param_file, const Solver* root_solver)
- : net_(), callbacks_(), root_solver_(root_solver) {
+ : net_(), callbacks_(), root_solver_(root_solver),
+ requested_early_exit_(false) {
SolverParameter param;
ReadProtoFromTextFileOrDie(param_file, &param);
Init(param);
@@ -195,6 +211,10 @@ void Solver<Dtype>::Step(int iters) {
&& (iter_ > 0 || param_.test_initialization())
&& Caffe::root_solver()) {
TestAll();
+ if (requested_early_exit_) {
+ // Break out of the while loop because stop was requested while testing.
+ break;
+ }
}
for (int i = 0; i < callbacks_.size(); ++i) {
@@ -250,12 +270,20 @@ void Solver<Dtype>::Step(int iters) {
// the number of times the weights have been updated.
++iter_;
+ SolverAction::Enum request = GetRequestedAction();
+
// Save a snapshot if needed.
- if (param_.snapshot()
- && iter_ % param_.snapshot() == 0
- && Caffe::root_solver()) {
+ if ((param_.snapshot()
+ && iter_ % param_.snapshot() == 0
+ && Caffe::root_solver()) ||
+ (request == SolverAction::SNAPSHOT)) {
Snapshot();
}
+ if (SolverAction::STOP == request) {
+ requested_early_exit_ = true;
+ // Break out of training loop.
+ break;
+ }
}
}
@@ -265,6 +293,9 @@ void Solver<Dtype>::Solve(const char* resume_file) {
LOG(INFO) << "Solving " << net_->name();
LOG(INFO) << "Learning Rate Policy: " << param_.lr_policy();
+ // Initialize to false every time we start solving.
+ requested_early_exit_ = false;
+
if (resume_file) {
LOG(INFO) << "Restoring previous solver status from " << resume_file;
Restore(resume_file);
@@ -279,6 +310,10 @@ void Solver<Dtype>::Solve(const char* resume_file) {
&& (!param_.snapshot() || iter_ % param_.snapshot() != 0)) {
Snapshot();
}
+ if (requested_early_exit_) {
+ LOG(INFO) << "Optimization stopped early.";
+ return;
+ }
// After the optimization is done, run an additional train and test pass to
// display the train and test loss/outputs if appropriate (based on the
// display and test_interval settings, respectively). Unlike in the rest of
@@ -296,10 +331,11 @@ void Solver<Dtype>::Solve(const char* resume_file) {
LOG(INFO) << "Optimization Done.";
}
-
template <typename Dtype>
void Solver<Dtype>::TestAll() {
- for (int test_net_id = 0; test_net_id < test_nets_.size(); ++test_net_id) {
+ for (int test_net_id = 0;
+ test_net_id < test_nets_.size() && !requested_early_exit_;
+ ++test_net_id) {
Test(test_net_id);
}
}
@@ -317,6 +353,21 @@ void Solver<Dtype>::Test(const int test_net_id) {
const shared_ptr<Net<Dtype> >& test_net = test_nets_[test_net_id];
Dtype loss = 0;
for (int i = 0; i < param_.test_iter(test_net_id); ++i) {
+ SolverAction::Enum request = GetRequestedAction();
+ // Check to see if stoppage of testing/training has been requested.
+ while (request != SolverAction::NONE) {
+ if (SolverAction::SNAPSHOT == request) {
+ Snapshot();
+ } else if (SolverAction::STOP == request) {
+ requested_early_exit_ = true;
+ }
+ request = GetRequestedAction();
+ }
+ if (requested_early_exit_) {
+ // break out of test loop.
+ break;
+ }
+
Dtype iter_loss;
const vector<Blob<Dtype>*>& result =
test_net->Forward(bottom_vec, &iter_loss);
@@ -341,6 +392,10 @@ void Solver<Dtype>::Test(const int test_net_id) {
}
}
}
+ if (requested_early_exit_) {
+ LOG(INFO) << "Test interrupted.";
+ return;
+ }
if (param_.test_compute_loss()) {
loss /= param_.test_iter(test_net_id);
LOG(INFO) << "Test loss: " << loss;
@@ -361,7 +416,6 @@ void Solver<Dtype>::Test(const int test_net_id) {
}
}
-
template <typename Dtype>
void Solver<Dtype>::Snapshot() {
CHECK(Caffe::root_solver());
diff --git a/src/caffe/util/signal_handler.cpp b/src/caffe/util/signal_handler.cpp
new file mode 100644
index 00000000..5d764ec5
--- /dev/null
+++ b/src/caffe/util/signal_handler.cpp
@@ -0,0 +1,115 @@
+#include <boost/bind.hpp>
+#include <glog/logging.h>
+
+#include <signal.h>
+#include <csignal>
+
+#include "caffe/util/signal_handler.h"
+
+namespace {
+ static volatile sig_atomic_t got_sigint = false;
+ static volatile sig_atomic_t got_sighup = false;
+ static bool already_hooked_up = false;
+
+  // Signal handler: records the arrival of SIGHUP or SIGINT in the matching
+  // flag so it can be picked up later by GotSIGHUP()/GotSIGINT(). Any other
+  // signal number is ignored.
+  void handle_signal(int signal) {
+    if (signal == SIGHUP) {
+      got_sighup = true;
+    } else if (signal == SIGINT) {
+      got_sigint = true;
+    }
+  }
+
+  // Installs handle_signal as the handler for SIGHUP and SIGINT. Logs a fatal
+  // error if the handlers are already installed or if sigaction() fails.
+  void HookupHandler() {
+    if (already_hooked_up) {
+      LOG(FATAL) << "Tried to hookup signal handlers more than once.";
+    }
+    already_hooked_up = true;
+
+    struct sigaction action;
+    // Keep every signal blocked while the handler is running.
+    sigfillset(&action.sa_mask);
+    // Ask for interrupted system calls to be restarted where possible.
+    action.sa_flags = SA_RESTART;
+    action.sa_handler = &handle_signal;
+
+    const int hup_status = sigaction(SIGHUP, &action, NULL);
+    if (hup_status == -1) {
+      LOG(FATAL) << "Cannot install SIGHUP handler.";
+    }
+    const int int_status = sigaction(SIGINT, &action, NULL);
+    if (int_status == -1) {
+      LOG(FATAL) << "Cannot install SIGINT handler.";
+    }
+  }
+
+  // Resets SIGHUP and SIGINT to their default dispositions. Does nothing
+  // unless the handlers are currently installed.
+  void UnhookHandler() {
+    if (!already_hooked_up) {
+      return;
+    }
+
+    struct sigaction defaults;
+    // SIG_DFL hands the signal back to its default action.
+    defaults.sa_handler = SIG_DFL;
+    // Flags and mask mirror the ones used when the handler was installed.
+    defaults.sa_flags = SA_RESTART;
+    sigfillset(&defaults.sa_mask);
+
+    if (sigaction(SIGHUP, &defaults, NULL) == -1) {
+      LOG(FATAL) << "Cannot uninstall SIGHUP handler.";
+    }
+    if (sigaction(SIGINT, &defaults, NULL) == -1) {
+      LOG(FATAL) << "Cannot uninstall SIGINT handler.";
+    }
+
+    already_hooked_up = false;
+  }
+
+  // Return true iff a SIGINT has been received since the last time this
+  // function reported one. The flag is reset only after it has been observed
+  // set: resetting it unconditionally would discard a SIGINT delivered
+  // between the read and the reset, and that request would never be seen.
+  bool GotSIGINT() {
+    if (!got_sigint) {
+      return false;
+    }
+    got_sigint = false;
+    return true;
+  }
+
+  // Return true iff a SIGHUP has been received since the last time this
+  // function reported one. The flag is reset only after it has been observed
+  // set: resetting it unconditionally would discard a SIGHUP delivered
+  // between the read and the reset, and that request would never be seen.
+  bool GotSIGHUP() {
+    if (!got_sighup) {
+      return false;
+    }
+    got_sighup = false;
+    return true;
+  }
+} // namespace
+
+namespace caffe {
+
+// Remembers which solver action to report for SIGINT and for SIGHUP, then
+// installs the signal handlers.
+SignalHandler::SignalHandler(SolverAction::Enum SIGINT_action,
+                             SolverAction::Enum SIGHUP_action)
+    : SIGINT_action_(SIGINT_action), SIGHUP_action_(SIGHUP_action) {
+  HookupHandler();
+}
+
+// Puts SIGHUP and SIGINT back to their default dispositions.
+SignalHandler::~SignalHandler() {
+  UnhookHandler();
+}
+
+// Translates a pending signal into the solver action configured for it.
+// SIGHUP is examined first; when it is pending, the SIGINT flag is not
+// consulted and stays available for a later call. Returns SolverAction::NONE
+// when neither signal is pending.
+SolverAction::Enum SignalHandler::CheckForSignals() const {
+  return GotSIGHUP() ? SIGHUP_action_
+       : GotSIGINT() ? SIGINT_action_
+       : SolverAction::NONE;
+}
+
+// Builds the callback the solver polls to learn whether a snapshot or an
+// early exit has been requested. The callback is bound to this object.
+ActionCallback SignalHandler::GetActionFunction() {
+  ActionCallback poll_signals =
+      boost::bind(&SignalHandler::CheckForSignals, this);
+  return poll_signals;
+}
+
+} // namespace caffe