summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJ Yegerlehner <jyegerlehner@yahoo.com>2015-04-03 16:11:23 -0500
committerJ Yegerlehner <jyegerlehner@yahoo.com>2015-08-22 12:51:55 -0500
commitff19d5f5c010dd8d6bfcf768b4fe27d0458f17df (patch)
tree721e4a100606b44441db4e563b027ed590c3a146
parentc6b9f580540f5aa16d05d8e283f9e0050dda2fb5 (diff)
downloadcaffeonacl-ff19d5f5c010dd8d6bfcf768b4fe27d0458f17df.tar.gz
caffeonacl-ff19d5f5c010dd8d6bfcf768b4fe27d0458f17df.tar.bz2
caffeonacl-ff19d5f5c010dd8d6bfcf768b4fe27d0458f17df.zip
Add signal handler and early exit/snapshot to Solver.
Add signal handler and early exit/snapshot to Solver. Add signal handler and early exit/snapshot to Solver. Also check for exit and snapshot when testing. Skip running test after early exit. Fix more lint. Rebase on master. Finish rebase on master. Fixups per review comments. Redress review comments. Lint. Correct error message wording.
-rw-r--r--include/caffe/solver.hpp37
-rw-r--r--include/caffe/util/signal_handler.h24
-rw-r--r--src/caffe/solver.cpp70
-rw-r--r--src/caffe/util/signal_handler.cpp115
-rw-r--r--tools/caffe.cpp32
5 files changed, 268 insertions, 10 deletions
diff --git a/include/caffe/solver.hpp b/include/caffe/solver.hpp
index ab12ef1b..aba3e036 100644
--- a/include/caffe/solver.hpp
+++ b/include/caffe/solver.hpp
@@ -1,6 +1,6 @@
#ifndef CAFFE_OPTIMIZATION_SOLVER_HPP_
#define CAFFE_OPTIMIZATION_SOLVER_HPP_
-
+#include <boost/function.hpp>
#include <string>
#include <vector>
@@ -9,6 +9,28 @@
namespace caffe {
/**
+ * @brief Enumeration of actions that a client of the Solver may request by
+ * implementing the Solver's action request function, which a
+ * a client may optionally provide in order to request early termination
+ * or saving a snapshot without exiting. In the executable caffe, this
+ * mechanism is used to allow the snapshot to be saved when stopping
+ * execution with a SIGINT (Ctrl-C).
+ */
+ namespace SolverAction {
+ enum Enum {
+ NONE = 0, // Take no special action.
+ STOP = 1, // Stop training. snapshot_after_train controls whether a
+ // snapshot is created.
+ SNAPSHOT = 2 // Take a snapshot, and keep training.
+ };
+ }
+
+/**
+ * @brief Type of a function that returns a Solver Action enumeration.
+ */
+typedef boost::function<SolverAction::Enum()> ActionCallback;
+
+/**
* @brief An interface for classes that perform optimization on Net%s.
*
* Requires implementation of ApplyUpdate to compute a parameter update
@@ -23,6 +45,12 @@ class Solver {
void Init(const SolverParameter& param);
void InitTrainNet();
void InitTestNets();
+
+ // Client of the Solver optionally may call this in order to set the function
+ // that the solver uses to see what action it should take (e.g. snapshot or
+ // exit training early).
+ void SetActionFunction(ActionCallback func);
+ SolverAction::Enum GetRequestedAction();
// The main entry of the solver function. In default, iter will be zero. Pass
// in a non-zero iter number to resume training for a pre-trained net.
virtual void Solve(const char* resume_file = NULL);
@@ -84,6 +112,13 @@ class Solver {
// in data parallelism
const Solver* const root_solver_;
+ // A function that can be set by a client of the Solver to provide indication
+ // that it wants a snapshot saved and/or to exit early.
+ ActionCallback action_request_function_;
+
+ // True iff a request to stop early was received.
+ bool requested_early_exit_;
+
DISABLE_COPY_AND_ASSIGN(Solver);
};
diff --git a/include/caffe/util/signal_handler.h b/include/caffe/util/signal_handler.h
new file mode 100644
index 00000000..fb84c65b
--- /dev/null
+++ b/include/caffe/util/signal_handler.h
@@ -0,0 +1,24 @@
+#ifndef INCLUDE_CAFFE_UTIL_SIGNAL_HANDLER_H_
+#define INCLUDE_CAFFE_UTIL_SIGNAL_HANDLER_H_
+
+#include "caffe/proto/caffe.pb.h"
+#include "caffe/solver.hpp"
+
+namespace caffe {
+
+class SignalHandler {
+ public:
+ // Contructor. Specify what action to take when a signal is received.
+ SignalHandler(SolverAction::Enum SIGINT_action,
+ SolverAction::Enum SIGHUP_action);
+ ~SignalHandler();
+ ActionCallback GetActionFunction();
+ private:
+ SolverAction::Enum CheckForSignals() const;
+ SolverAction::Enum SIGINT_action_;
+ SolverAction::Enum SIGHUP_action_;
+};
+
+} // namespace caffe
+
+#endif // INCLUDE_CAFFE_UTIL_SIGNAL_HANDLER_H_
diff --git a/src/caffe/solver.cpp b/src/caffe/solver.cpp
index 9348e11c..394ec3b3 100644
--- a/src/caffe/solver.cpp
+++ b/src/caffe/solver.cpp
@@ -17,15 +17,31 @@
namespace caffe {
+template<typename Dtype>
+void Solver<Dtype>::SetActionFunction(ActionCallback func) {
+ action_request_function_ = func;
+}
+
+template<typename Dtype>
+SolverAction::Enum Solver<Dtype>::GetRequestedAction() {
+ if (action_request_function_) {
+ // If the external request function has been set, call it.
+ return action_request_function_();
+ }
+ return SolverAction::NONE;
+}
+
template <typename Dtype>
Solver<Dtype>::Solver(const SolverParameter& param, const Solver* root_solver)
- : net_(), callbacks_(), root_solver_(root_solver) {
+ : net_(), callbacks_(), root_solver_(root_solver),
+ requested_early_exit_(false) {
Init(param);
}
template <typename Dtype>
Solver<Dtype>::Solver(const string& param_file, const Solver* root_solver)
- : net_(), callbacks_(), root_solver_(root_solver) {
+ : net_(), callbacks_(), root_solver_(root_solver),
+ requested_early_exit_(false) {
SolverParameter param;
ReadProtoFromTextFileOrDie(param_file, &param);
Init(param);
@@ -195,6 +211,10 @@ void Solver<Dtype>::Step(int iters) {
&& (iter_ > 0 || param_.test_initialization())
&& Caffe::root_solver()) {
TestAll();
+ if (requested_early_exit_) {
+ // Break out of the while loop because stop was requested while testing.
+ break;
+ }
}
for (int i = 0; i < callbacks_.size(); ++i) {
@@ -250,12 +270,20 @@ void Solver<Dtype>::Step(int iters) {
// the number of times the weights have been updated.
++iter_;
+ SolverAction::Enum request = GetRequestedAction();
+
// Save a snapshot if needed.
- if (param_.snapshot()
- && iter_ % param_.snapshot() == 0
- && Caffe::root_solver()) {
+ if ((param_.snapshot()
+ && iter_ % param_.snapshot() == 0
+ && Caffe::root_solver()) ||
+ (request == SolverAction::SNAPSHOT)) {
Snapshot();
}
+ if (SolverAction::STOP == request) {
+ requested_early_exit_ = true;
+ // Break out of training loop.
+ break;
+ }
}
}
@@ -265,6 +293,9 @@ void Solver<Dtype>::Solve(const char* resume_file) {
LOG(INFO) << "Solving " << net_->name();
LOG(INFO) << "Learning Rate Policy: " << param_.lr_policy();
+ // Initialize to false every time we start solving.
+ requested_early_exit_ = false;
+
if (resume_file) {
LOG(INFO) << "Restoring previous solver status from " << resume_file;
Restore(resume_file);
@@ -279,6 +310,10 @@ void Solver<Dtype>::Solve(const char* resume_file) {
&& (!param_.snapshot() || iter_ % param_.snapshot() != 0)) {
Snapshot();
}
+ if (requested_early_exit_) {
+ LOG(INFO) << "Optimization stopped early.";
+ return;
+ }
// After the optimization is done, run an additional train and test pass to
// display the train and test loss/outputs if appropriate (based on the
// display and test_interval settings, respectively). Unlike in the rest of
@@ -296,10 +331,11 @@ void Solver<Dtype>::Solve(const char* resume_file) {
LOG(INFO) << "Optimization Done.";
}
-
template <typename Dtype>
void Solver<Dtype>::TestAll() {
- for (int test_net_id = 0; test_net_id < test_nets_.size(); ++test_net_id) {
+ for (int test_net_id = 0;
+ test_net_id < test_nets_.size() && !requested_early_exit_;
+ ++test_net_id) {
Test(test_net_id);
}
}
@@ -317,6 +353,21 @@ void Solver<Dtype>::Test(const int test_net_id) {
const shared_ptr<Net<Dtype> >& test_net = test_nets_[test_net_id];
Dtype loss = 0;
for (int i = 0; i < param_.test_iter(test_net_id); ++i) {
+ SolverAction::Enum request = GetRequestedAction();
+ // Check to see if stoppage of testing/training has been requested.
+ while (request != SolverAction::NONE) {
+ if (SolverAction::SNAPSHOT == request) {
+ Snapshot();
+ } else if (SolverAction::STOP == request) {
+ requested_early_exit_ = true;
+ }
+ request = GetRequestedAction();
+ }
+ if (requested_early_exit_) {
+ // break out of test loop.
+ break;
+ }
+
Dtype iter_loss;
const vector<Blob<Dtype>*>& result =
test_net->Forward(bottom_vec, &iter_loss);
@@ -341,6 +392,10 @@ void Solver<Dtype>::Test(const int test_net_id) {
}
}
}
+ if (requested_early_exit_) {
+ LOG(INFO) << "Test interrupted.";
+ return;
+ }
if (param_.test_compute_loss()) {
loss /= param_.test_iter(test_net_id);
LOG(INFO) << "Test loss: " << loss;
@@ -361,7 +416,6 @@ void Solver<Dtype>::Test(const int test_net_id) {
}
}
-
template <typename Dtype>
void Solver<Dtype>::Snapshot() {
CHECK(Caffe::root_solver());
diff --git a/src/caffe/util/signal_handler.cpp b/src/caffe/util/signal_handler.cpp
new file mode 100644
index 00000000..5d764ec5
--- /dev/null
+++ b/src/caffe/util/signal_handler.cpp
@@ -0,0 +1,115 @@
+#include <boost/bind.hpp>
+#include <glog/logging.h>
+
+#include <signal.h>
+#include <csignal>
+
+#include "caffe/util/signal_handler.h"
+
+namespace {
+ static volatile sig_atomic_t got_sigint = false;
+ static volatile sig_atomic_t got_sighup = false;
+ static bool already_hooked_up = false;
+
+ void handle_signal(int signal) {
+ switch (signal) {
+ case SIGHUP:
+ got_sighup = true;
+ break;
+ case SIGINT:
+ got_sigint = true;
+ break;
+ }
+ }
+
+ void HookupHandler() {
+ if (already_hooked_up) {
+ LOG(FATAL) << "Tried to hookup signal handlers more than once.";
+ }
+ already_hooked_up = true;
+
+ struct sigaction sa;
+ // Setup the handler
+ sa.sa_handler = &handle_signal;
+ // Restart the system call, if at all possible
+ sa.sa_flags = SA_RESTART;
+ // Block every signal during the handler
+ sigfillset(&sa.sa_mask);
+ // Intercept SIGHUP and SIGINT
+ if (sigaction(SIGHUP, &sa, NULL) == -1) {
+ LOG(FATAL) << "Cannot install SIGHUP handler.";
+ }
+ if (sigaction(SIGINT, &sa, NULL) == -1) {
+ LOG(FATAL) << "Cannot install SIGINT handler.";
+ }
+ }
+
+ // Set the signal handlers to the default.
+ void UnhookHandler() {
+ if (already_hooked_up) {
+ struct sigaction sa;
+ // Setup the sighub handler
+ sa.sa_handler = SIG_DFL;
+ // Restart the system call, if at all possible
+ sa.sa_flags = SA_RESTART;
+ // Block every signal during the handler
+ sigfillset(&sa.sa_mask);
+ // Intercept SIGHUP and SIGINT
+ if (sigaction(SIGHUP, &sa, NULL) == -1) {
+ LOG(FATAL) << "Cannot uninstall SIGHUP handler.";
+ }
+ if (sigaction(SIGINT, &sa, NULL) == -1) {
+ LOG(FATAL) << "Cannot uninstall SIGINT handler.";
+ }
+
+ already_hooked_up = false;
+ }
+ }
+
+ // Return true iff a SIGINT has been received since the last time this
+ // function was called.
+ bool GotSIGINT() {
+ bool result = got_sigint;
+ got_sigint = false;
+ return result;
+ }
+
+ // Return true iff a SIGHUP has been received since the last time this
+ // function was called.
+ bool GotSIGHUP() {
+ bool result = got_sighup;
+ got_sighup = false;
+ return result;
+ }
+} // namespace
+
+namespace caffe {
+
+SignalHandler::SignalHandler(SolverAction::Enum SIGINT_action,
+ SolverAction::Enum SIGHUP_action):
+ SIGINT_action_(SIGINT_action),
+ SIGHUP_action_(SIGHUP_action) {
+ HookupHandler();
+}
+
+SignalHandler::~SignalHandler() {
+ UnhookHandler();
+}
+
+SolverAction::Enum SignalHandler::CheckForSignals() const {
+ if (GotSIGHUP()) {
+ return SIGHUP_action_;
+ }
+ if (GotSIGINT()) {
+ return SIGINT_action_;
+ }
+ return SolverAction::NONE;
+}
+
+// Return the function that the solver can use to find out if a snapshot or
+// early exit is being requested.
+ActionCallback SignalHandler::GetActionFunction() {
+ return boost::bind(&SignalHandler::CheckForSignals, this);
+}
+
+} // namespace caffe
diff --git a/tools/caffe.cpp b/tools/caffe.cpp
index 9f31b37a..ff63860a 100644
--- a/tools/caffe.cpp
+++ b/tools/caffe.cpp
@@ -12,6 +12,7 @@ namespace bp = boost::python;
#include "boost/algorithm/string.hpp"
#include "caffe/caffe.hpp"
+#include "caffe/util/signal_handler.h"
using caffe::Blob;
using caffe::Caffe;
@@ -39,6 +40,12 @@ DEFINE_string(weights, "",
"separated by ','. Cannot be set simultaneously with snapshot.");
DEFINE_int32(iterations, 50,
"The number of iterations to run.");
+DEFINE_string(sigint_effect, "stop",
+ "Optional; action to take when a SIGINT signal is received: "
+ "snapshot, stop or none.");
+DEFINE_string(sighup_effect, "snapshot",
+ "Optional; action to take when a SIGHUP signal is received: "
+ "snapshot, stop or none.");
// A simple registry for caffe commands.
typedef int (*BrewFunction)();
@@ -126,6 +133,22 @@ void CopyLayers(caffe::Solver<float>* solver, const std::string& model_list) {
}
}
+// Translate the signal effect the user specified on the command-line to the
+// corresponding enumeration.
+caffe::SolverAction::Enum GetRequestedAction(
+ const std::string& flag_value) {
+ if (flag_value == "stop") {
+ return caffe::SolverAction::STOP;
+ }
+ if (flag_value == "snapshot") {
+ return caffe::SolverAction::SNAPSHOT;
+ }
+ if (flag_value == "none") {
+ return caffe::SolverAction::NONE;
+ }
+ LOG(FATAL) << "Invalid signal effect \""<< flag_value << "\" was specified";
+}
+
// Train / Finetune a model.
int train() {
CHECK_GT(FLAGS_solver.size(), 0) << "Need a solver definition to train.";
@@ -165,7 +188,14 @@ int train() {
Caffe::set_solver_count(gpus.size());
}
- shared_ptr<Solver<float> > solver(caffe::GetSolver<float>(solver_param));
+ caffe::SignalHandler signal_handler(
+ GetRequestedAction(FLAGS_sigint_effect),
+ GetRequestedAction(FLAGS_sighup_effect));
+
+ shared_ptr<caffe::Solver<float> >
+ solver(caffe::GetSolver<float>(solver_param));
+
+ solver->SetActionFunction(signal_handler.GetActionFunction());
if (FLAGS_snapshot.size()) {
LOG(INFO) << "Resuming from " << FLAGS_snapshot;