summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--include/caffe/solver.hpp37
-rw-r--r--include/caffe/util/signal_handler.h24
-rw-r--r--src/caffe/solver.cpp70
-rw-r--r--src/caffe/util/signal_handler.cpp115
-rw-r--r--tools/caffe.cpp32
5 files changed, 268 insertions, 10 deletions
diff --git a/include/caffe/solver.hpp b/include/caffe/solver.hpp
index ab12ef1b..aba3e036 100644
--- a/include/caffe/solver.hpp
+++ b/include/caffe/solver.hpp
@@ -1,6 +1,6 @@
#ifndef CAFFE_OPTIMIZATION_SOLVER_HPP_
#define CAFFE_OPTIMIZATION_SOLVER_HPP_
-
+#include <boost/function.hpp>
#include <string>
#include <vector>
@@ -9,6 +9,28 @@
namespace caffe {
/**
+ * @brief Enumeration of actions that a client of the Solver may request by
+ * implementing the Solver's action request function, which
+ * a client may optionally provide in order to request early termination
+ * or saving a snapshot without exiting. In the executable caffe, this
+ * mechanism is used to allow the snapshot to be saved when stopping
+ * execution with a SIGINT (Ctrl-C).
+ */
+ // The enum is wrapped in its own namespace, so its values are written as
+ // SolverAction::NONE, SolverAction::STOP and SolverAction::SNAPSHOT.
+ namespace SolverAction {
+ enum Enum {
+ NONE = 0, // Take no special action.
+ STOP = 1, // Stop training. snapshot_after_train controls whether a
+ // snapshot is created.
+ SNAPSHOT = 2 // Take a snapshot, and keep training.
+ };
+ }  // namespace SolverAction
+
+/**
+ * @brief Type of a function that returns a Solver Action enumeration.
+ *
+ * The callable takes no arguments and returns a SolverAction::Enum value.
+ */
+typedef boost::function<SolverAction::Enum()> ActionCallback;
+
+/**
* @brief An interface for classes that perform optimization on Net%s.
*
* Requires implementation of ApplyUpdate to compute a parameter update
@@ -23,6 +45,12 @@ class Solver {
void Init(const SolverParameter& param);
void InitTrainNet();
void InitTestNets();
+
+ // Client of the Solver optionally may call this in order to set the function
+ // that the solver uses to see what action it should take (e.g. snapshot or
+ // exit training early).
+ void SetActionFunction(ActionCallback func);
+ SolverAction::Enum GetRequestedAction();
// The main entry of the solver function. In default, iter will be zero. Pass
// in a non-zero iter number to resume training for a pre-trained net.
virtual void Solve(const char* resume_file = NULL);
@@ -84,6 +112,13 @@ class Solver {
// in data parallelism
const Solver* const root_solver_;
+ // A function that can be set by a client of the Solver to provide indication
+ // that it wants a snapshot saved and/or to exit early.
+ ActionCallback action_request_function_;
+
+ // True iff a request to stop early was received.
+ bool requested_early_exit_;
+
DISABLE_COPY_AND_ASSIGN(Solver);
};
diff --git a/include/caffe/util/signal_handler.h b/include/caffe/util/signal_handler.h
new file mode 100644
index 00000000..fb84c65b
--- /dev/null
+++ b/include/caffe/util/signal_handler.h
@@ -0,0 +1,24 @@
+#ifndef INCLUDE_CAFFE_UTIL_SIGNAL_HANDLER_H_
+#define INCLUDE_CAFFE_UTIL_SIGNAL_HANDLER_H_
+
+#include "caffe/proto/caffe.pb.h"
+#include "caffe/solver.hpp"
+
+namespace caffe {
+
+// Installs process-wide SIGINT and SIGHUP handlers for as long as the object
+// lives and translates received signals into SolverAction requests.
+// The implementation aborts (LOG(FATAL)) when the handlers are hooked up while
+// they are already installed, so only one instance may exist at a time.
+class SignalHandler {
+ public:
+ // Constructor. Specify what action to take when a signal is received.
+ // Also installs the SIGINT and SIGHUP handlers.
+ SignalHandler(SolverAction::Enum SIGINT_action,
+ SolverAction::Enum SIGHUP_action);
+ // Resets SIGINT and SIGHUP to their default disposition.
+ ~SignalHandler();
+ // Returns a callback, bound to this object, that reports the action
+ // configured for a signal received since the previous poll. The callback
+ // must not be invoked after this object has been destroyed.
+ ActionCallback GetActionFunction();
+ private:
+ // Polls (and clears) the pending-signal flags; SIGHUP is checked before
+ // SIGINT. Returns SolverAction::NONE when nothing is pending.
+ SolverAction::Enum CheckForSignals() const;
+ // Action reported when a SIGINT was received.
+ SolverAction::Enum SIGINT_action_;
+ // Action reported when a SIGHUP was received.
+ SolverAction::Enum SIGHUP_action_;
+};
+
+} // namespace caffe
+
+#endif // INCLUDE_CAFFE_UTIL_SIGNAL_HANDLER_H_
diff --git a/src/caffe/solver.cpp b/src/caffe/solver.cpp
index 9348e11c..394ec3b3 100644
--- a/src/caffe/solver.cpp
+++ b/src/caffe/solver.cpp
@@ -17,15 +17,31 @@
namespace caffe {
+// Stores (replacing any previous value) the callback that
+// GetRequestedAction() invokes to learn whether a snapshot or an early stop
+// is being requested.
+template<typename Dtype>
+void Solver<Dtype>::SetActionFunction(ActionCallback func) {
+ action_request_function_ = func;
+}
+
+// Asks the client-supplied callback, if any, which action it requests.
+// Returns SolverAction::NONE when no callback has been set.
+template<typename Dtype>
+SolverAction::Enum Solver<Dtype>::GetRequestedAction() {
+  if (!action_request_function_) {
+    // No external request function has been set: nothing is requested.
+    return SolverAction::NONE;
+  }
+  return action_request_function_();
+}
+
+// Builds a solver from an in-memory SolverParameter. The early-exit flag is
+// cleared in the initializer list, before Init() runs.
template <typename Dtype>
Solver<Dtype>::Solver(const SolverParameter& param, const Solver* root_solver)
- : net_(), callbacks_(), root_solver_(root_solver) {
+ : net_(), callbacks_(), root_solver_(root_solver),
+ requested_early_exit_(false) {
Init(param);
}
template <typename Dtype>
Solver<Dtype>::Solver(const string& param_file, const Solver* root_solver)
- : net_(), callbacks_(), root_solver_(root_solver) {
+ : net_(), callbacks_(), root_solver_(root_solver),
+ requested_early_exit_(false) {
SolverParameter param;
ReadProtoFromTextFileOrDie(param_file, &param);
Init(param);
@@ -195,6 +211,10 @@ void Solver<Dtype>::Step(int iters) {
&& (iter_ > 0 || param_.test_initialization())
&& Caffe::root_solver()) {
TestAll();
+ if (requested_early_exit_) {
+ // Break out of the while loop because stop was requested while testing.
+ break;
+ }
}
for (int i = 0; i < callbacks_.size(); ++i) {
@@ -250,12 +270,20 @@ void Solver<Dtype>::Step(int iters) {
// the number of times the weights have been updated.
++iter_;
+ SolverAction::Enum request = GetRequestedAction();
+
// Save a snapshot if needed.
- if (param_.snapshot()
- && iter_ % param_.snapshot() == 0
- && Caffe::root_solver()) {
+ if ((param_.snapshot()
+ && iter_ % param_.snapshot() == 0
+ && Caffe::root_solver()) ||
+ (request == SolverAction::SNAPSHOT)) {
Snapshot();
}
+ if (SolverAction::STOP == request) {
+ requested_early_exit_ = true;
+ // Break out of training loop.
+ break;
+ }
}
}
@@ -265,6 +293,9 @@ void Solver<Dtype>::Solve(const char* resume_file) {
LOG(INFO) << "Solving " << net_->name();
LOG(INFO) << "Learning Rate Policy: " << param_.lr_policy();
+ // Initialize to false every time we start solving.
+ requested_early_exit_ = false;
+
if (resume_file) {
LOG(INFO) << "Restoring previous solver status from " << resume_file;
Restore(resume_file);
@@ -279,6 +310,10 @@ void Solver<Dtype>::Solve(const char* resume_file) {
&& (!param_.snapshot() || iter_ % param_.snapshot() != 0)) {
Snapshot();
}
+ if (requested_early_exit_) {
+ LOG(INFO) << "Optimization stopped early.";
+ return;
+ }
// After the optimization is done, run an additional train and test pass to
// display the train and test loss/outputs if appropriate (based on the
// display and test_interval settings, respectively). Unlike in the rest of
@@ -296,10 +331,11 @@ void Solver<Dtype>::Solve(const char* resume_file) {
LOG(INFO) << "Optimization Done.";
}
-
+// Runs Test() on each test net in index order. The loop condition re-checks
+// requested_early_exit_ before every net, so no further net is started once
+// the flag is set.
template <typename Dtype>
void Solver<Dtype>::TestAll() {
- for (int test_net_id = 0; test_net_id < test_nets_.size(); ++test_net_id) {
+ for (int test_net_id = 0;
+ test_net_id < test_nets_.size() && !requested_early_exit_;
+ ++test_net_id) {
Test(test_net_id);
}
}
@@ -317,6 +353,21 @@ void Solver<Dtype>::Test(const int test_net_id) {
const shared_ptr<Net<Dtype> >& test_net = test_nets_[test_net_id];
Dtype loss = 0;
for (int i = 0; i < param_.test_iter(test_net_id); ++i) {
+ SolverAction::Enum request = GetRequestedAction();
+ // Check to see if stoppage of testing/training has been requested.
+ while (request != SolverAction::NONE) {
+ if (SolverAction::SNAPSHOT == request) {
+ Snapshot();
+ } else if (SolverAction::STOP == request) {
+ requested_early_exit_ = true;
+ }
+ request = GetRequestedAction();
+ }
+ if (requested_early_exit_) {
+ // break out of test loop.
+ break;
+ }
+
Dtype iter_loss;
const vector<Blob<Dtype>*>& result =
test_net->Forward(bottom_vec, &iter_loss);
@@ -341,6 +392,10 @@ void Solver<Dtype>::Test(const int test_net_id) {
}
}
}
+ if (requested_early_exit_) {
+ LOG(INFO) << "Test interrupted.";
+ return;
+ }
if (param_.test_compute_loss()) {
loss /= param_.test_iter(test_net_id);
LOG(INFO) << "Test loss: " << loss;
@@ -361,7 +416,6 @@ void Solver<Dtype>::Test(const int test_net_id) {
}
}
-
template <typename Dtype>
void Solver<Dtype>::Snapshot() {
CHECK(Caffe::root_solver());
diff --git a/src/caffe/util/signal_handler.cpp b/src/caffe/util/signal_handler.cpp
new file mode 100644
index 00000000..5d764ec5
--- /dev/null
+++ b/src/caffe/util/signal_handler.cpp
@@ -0,0 +1,115 @@
+#include <boost/bind.hpp>
+#include <glog/logging.h>
+
+#include <signal.h>
+#include <csignal>
+
+#include "caffe/util/signal_handler.h"
+
+namespace {
+ static volatile sig_atomic_t got_sigint = false;
+ static volatile sig_atomic_t got_sighup = false;
+ static bool already_hooked_up = false;
+
+ // Signal handler: records which signal arrived by setting the matching
+ // sig_atomic_t flag. Any signal other than SIGHUP or SIGINT is ignored.
+ void handle_signal(int signal) {
+   if (signal == SIGHUP) {
+     got_sighup = true;
+   } else if (signal == SIGINT) {
+     got_sigint = true;
+   }
+ }
+
+ // Installs handle_signal() as the handler for SIGHUP and SIGINT.
+ // Aborts via LOG(FATAL) if the handlers are already installed or if
+ // sigaction() reports a failure.
+ void HookupHandler() {
+   if (already_hooked_up) {
+     LOG(FATAL) << "Tried to hookup signal handlers more than once.";
+   }
+   already_hooked_up = true;
+
+   // Value-initialize so that members of struct sigaction that are not set
+   // explicitly below do not hold indeterminate values.
+   struct sigaction sa = {};
+   // Setup the handler
+   sa.sa_handler = &handle_signal;
+   // Restart the system call, if at all possible
+   sa.sa_flags = SA_RESTART;
+   // Block every signal during the handler
+   sigfillset(&sa.sa_mask);
+   // Intercept SIGHUP and SIGINT
+   if (sigaction(SIGHUP, &sa, NULL) == -1) {
+     LOG(FATAL) << "Cannot install SIGHUP handler.";
+   }
+   if (sigaction(SIGINT, &sa, NULL) == -1) {
+     LOG(FATAL) << "Cannot install SIGINT handler.";
+   }
+ }
+
+ // Set the signal handlers to the default.
+ // Does nothing unless the handlers are currently installed; aborts via
+ // LOG(FATAL) if sigaction() reports a failure.
+ void UnhookHandler() {
+   if (already_hooked_up) {
+     // Value-initialize so that members of struct sigaction that are not set
+     // explicitly below do not hold indeterminate values.
+     struct sigaction sa = {};
+     // Request the default disposition
+     sa.sa_handler = SIG_DFL;
+     // Same flags and mask as used when the handlers were installed
+     sa.sa_flags = SA_RESTART;
+     sigfillset(&sa.sa_mask);
+     // Stop intercepting SIGHUP and SIGINT
+     if (sigaction(SIGHUP, &sa, NULL) == -1) {
+       LOG(FATAL) << "Cannot uninstall SIGHUP handler.";
+     }
+     if (sigaction(SIGINT, &sa, NULL) == -1) {
+       LOG(FATAL) << "Cannot uninstall SIGINT handler.";
+     }
+
+     already_hooked_up = false;
+   }
+ }
+
+ // Return true iff a SIGINT has been received since the last time this
+ // function was called.
+ // NOTE: the flag is read and then reset in two separate steps; a SIGINT
+ // delivered between them is overwritten by the reset and is not reported.
+ bool GotSIGINT() {
+ bool result = got_sigint;
+ got_sigint = false;
+ return result;
+ }
+
+ // Return true iff a SIGHUP has been received since the last time this
+ // function was called.
+ // NOTE: the flag is read and then reset in two separate steps; a SIGHUP
+ // delivered between them is overwritten by the reset and is not reported.
+ bool GotSIGHUP() {
+ bool result = got_sighup;
+ got_sighup = false;
+ return result;
+ }
+} // namespace
+
+namespace caffe {
+
+// Stores the action to report for each signal, then installs the SIGINT and
+// SIGHUP handlers. HookupHandler() aborts if they are already installed.
+SignalHandler::SignalHandler(SolverAction::Enum SIGINT_action,
+ SolverAction::Enum SIGHUP_action):
+ SIGINT_action_(SIGINT_action),
+ SIGHUP_action_(SIGHUP_action) {
+ HookupHandler();
+}
+
+// Resets SIGINT and SIGHUP to their default disposition.
+SignalHandler::~SignalHandler() {
+ UnhookHandler();
+}
+
+// Reports the action configured for a pending signal, or SolverAction::NONE
+// when none is pending. SIGHUP takes precedence: when it is pending, the
+// SIGINT flag is not consulted and stays set for a later call.
+SolverAction::Enum SignalHandler::CheckForSignals() const {
+  SolverAction::Enum action = SolverAction::NONE;
+  if (GotSIGHUP()) {
+    action = SIGHUP_action_;
+  } else if (GotSIGINT()) {
+    action = SIGINT_action_;
+  }
+  return action;
+}
+
+// Return the function that the solver can use to find out if a snapshot or
+// early exit is being requested.
+// The returned callback is bound to this object through a raw pointer, so it
+// must not be invoked after this SignalHandler has been destroyed.
+ActionCallback SignalHandler::GetActionFunction() {
+ return boost::bind(&SignalHandler::CheckForSignals, this);
+}
+
+} // namespace caffe
diff --git a/tools/caffe.cpp b/tools/caffe.cpp
index 9f31b37a..ff63860a 100644
--- a/tools/caffe.cpp
+++ b/tools/caffe.cpp
@@ -12,6 +12,7 @@ namespace bp = boost::python;
#include "boost/algorithm/string.hpp"
#include "caffe/caffe.hpp"
+#include "caffe/util/signal_handler.h"
using caffe::Blob;
using caffe::Caffe;
@@ -39,6 +40,12 @@ DEFINE_string(weights, "",
"separated by ','. Cannot be set simultaneously with snapshot.");
DEFINE_int32(iterations, 50,
"The number of iterations to run.");
+DEFINE_string(sigint_effect, "stop",
+ "Optional; action to take when a SIGINT signal is received: "
+ "snapshot, stop or none.");
+DEFINE_string(sighup_effect, "snapshot",
+ "Optional; action to take when a SIGHUP signal is received: "
+ "snapshot, stop or none.");
// A simple registry for caffe commands.
typedef int (*BrewFunction)();
@@ -126,6 +133,22 @@ void CopyLayers(caffe::Solver<float>* solver, const std::string& model_list) {
}
}
+// Translate the signal effect the user specified on the command-line to the
+// corresponding enumeration.
+// Accepts "stop", "snapshot" or "none"; any other value is reported through
+// LOG(FATAL).
+caffe::SolverAction::Enum GetRequestedAction(
+    const std::string& flag_value) {
+  if (flag_value == "stop") {
+    return caffe::SolverAction::STOP;
+  }
+  if (flag_value == "snapshot") {
+    return caffe::SolverAction::SNAPSHOT;
+  }
+  if (flag_value == "none") {
+    return caffe::SolverAction::NONE;
+  }
+  LOG(FATAL) << "Invalid signal effect \"" << flag_value << "\" was specified";
+  // LOG(FATAL) is expected to terminate the program; the explicit return
+  // makes sure no path falls off the end of this non-void function.
+  return caffe::SolverAction::NONE;
+}
+
// Train / Finetune a model.
int train() {
CHECK_GT(FLAGS_solver.size(), 0) << "Need a solver definition to train.";
@@ -165,7 +188,14 @@ int train() {
Caffe::set_solver_count(gpus.size());
}
- shared_ptr<Solver<float> > solver(caffe::GetSolver<float>(solver_param));
+ caffe::SignalHandler signal_handler(
+ GetRequestedAction(FLAGS_sigint_effect),
+ GetRequestedAction(FLAGS_sighup_effect));
+
+ shared_ptr<caffe::Solver<float> >
+ solver(caffe::GetSolver<float>(solver_param));
+
+ solver->SetActionFunction(signal_handler.GetActionFunction());
if (FLAGS_snapshot.size()) {
LOG(INFO) << "Resuming from " << FLAGS_snapshot;