summaryrefslogtreecommitdiff
path: root/include
diff options
context:
space:
mode:
Diffstat (limited to 'include')
-rw-r--r--include/caffe/solver.hpp37
-rw-r--r--include/caffe/util/signal_handler.h24
2 files changed, 60 insertions, 1 deletions
diff --git a/include/caffe/solver.hpp b/include/caffe/solver.hpp
index ab12ef1b..aba3e036 100644
--- a/include/caffe/solver.hpp
+++ b/include/caffe/solver.hpp
@@ -1,6 +1,6 @@
#ifndef CAFFE_OPTIMIZATION_SOLVER_HPP_
#define CAFFE_OPTIMIZATION_SOLVER_HPP_
-
+#include <boost/function.hpp>
#include <string>
#include <vector>
@@ -9,6 +9,28 @@
namespace caffe {
/**
+ * @brief Enumeration of actions that a client of the Solver may request by
+ * implementing the Solver's action request function, which a
+ * a client may optionally provide in order to request early termination
+ * or saving a snapshot without exiting. In the executable caffe, this
+ * mechanism is used to allow the snapshot to be saved when stopping
+ * execution with a SIGINT (Ctrl-C).
+ */
+ namespace SolverAction {
+ enum Enum {
+ NONE = 0, // Take no special action.
+ STOP = 1, // Stop training. snapshot_after_train controls whether a
+ // snapshot is created.
+ SNAPSHOT = 2 // Take a snapshot, and keep training.
+ };
+ }
+
+/**
+ * @brief Type of a function that returns a Solver Action enumeration.
+ */
+typedef boost::function<SolverAction::Enum()> ActionCallback;
+
+/**
* @brief An interface for classes that perform optimization on Net%s.
*
* Requires implementation of ApplyUpdate to compute a parameter update
@@ -23,6 +45,12 @@ class Solver {
void Init(const SolverParameter& param);
void InitTrainNet();
void InitTestNets();
+
+ // Client of the Solver optionally may call this in order to set the function
+ // that the solver uses to see what action it should take (e.g. snapshot or
+ // exit training early).
+ void SetActionFunction(ActionCallback func);
+ SolverAction::Enum GetRequestedAction();
// The main entry of the solver function. In default, iter will be zero. Pass
// in a non-zero iter number to resume training for a pre-trained net.
virtual void Solve(const char* resume_file = NULL);
@@ -84,6 +112,13 @@ class Solver {
// in data parallelism
const Solver* const root_solver_;
+ // A function that can be set by a client of the Solver to provide indication
+ // that it wants a snapshot saved and/or to exit early.
+ ActionCallback action_request_function_;
+
+ // True iff a request to stop early was received.
+ bool requested_early_exit_;
+
DISABLE_COPY_AND_ASSIGN(Solver);
};
diff --git a/include/caffe/util/signal_handler.h b/include/caffe/util/signal_handler.h
new file mode 100644
index 00000000..fb84c65b
--- /dev/null
+++ b/include/caffe/util/signal_handler.h
@@ -0,0 +1,24 @@
+#ifndef INCLUDE_CAFFE_UTIL_SIGNAL_HANDLER_H_
+#define INCLUDE_CAFFE_UTIL_SIGNAL_HANDLER_H_
+
+#include "caffe/proto/caffe.pb.h"
+#include "caffe/solver.hpp"
+
+namespace caffe {
+
+class SignalHandler {
+ public:
+ // Contructor. Specify what action to take when a signal is received.
+ SignalHandler(SolverAction::Enum SIGINT_action,
+ SolverAction::Enum SIGHUP_action);
+ ~SignalHandler();
+ ActionCallback GetActionFunction();
+ private:
+ SolverAction::Enum CheckForSignals() const;
+ SolverAction::Enum SIGINT_action_;
+ SolverAction::Enum SIGHUP_action_;
+};
+
+} // namespace caffe
+
+#endif // INCLUDE_CAFFE_UTIL_SIGNAL_HANDLER_H_