diff options
Diffstat (limited to 'include')
-rw-r--r-- | include/caffe/solver.hpp | 37 | ||||
-rw-r--r-- | include/caffe/util/signal_handler.h | 24 |
2 files changed, 60 insertions, 1 deletions
diff --git a/include/caffe/solver.hpp b/include/caffe/solver.hpp index ab12ef1b..aba3e036 100644 --- a/include/caffe/solver.hpp +++ b/include/caffe/solver.hpp @@ -1,6 +1,6 @@ #ifndef CAFFE_OPTIMIZATION_SOLVER_HPP_ #define CAFFE_OPTIMIZATION_SOLVER_HPP_ - +#include <boost/function.hpp> #include <string> #include <vector> @@ -9,6 +9,28 @@ namespace caffe { /** + * @brief Enumeration of actions that a client of the Solver may request by + * implementing the Solver's action request function, which a + * a client may optionally provide in order to request early termination + * or saving a snapshot without exiting. In the executable caffe, this + * mechanism is used to allow the snapshot to be saved when stopping + * execution with a SIGINT (Ctrl-C). + */ + namespace SolverAction { + enum Enum { + NONE = 0, // Take no special action. + STOP = 1, // Stop training. snapshot_after_train controls whether a + // snapshot is created. + SNAPSHOT = 2 // Take a snapshot, and keep training. + }; + } + +/** + * @brief Type of a function that returns a Solver Action enumeration. + */ +typedef boost::function<SolverAction::Enum()> ActionCallback; + +/** * @brief An interface for classes that perform optimization on Net%s. * * Requires implementation of ApplyUpdate to compute a parameter update @@ -23,6 +45,12 @@ class Solver { void Init(const SolverParameter& param); void InitTrainNet(); void InitTestNets(); + + // Client of the Solver optionally may call this in order to set the function + // that the solver uses to see what action it should take (e.g. snapshot or + // exit training early). + void SetActionFunction(ActionCallback func); + SolverAction::Enum GetRequestedAction(); // The main entry of the solver function. In default, iter will be zero. Pass // in a non-zero iter number to resume training for a pre-trained net. virtual void Solve(const char* resume_file = NULL); @@ -84,6 +112,13 @@ class Solver { // in data parallelism const Solver* const root_solver_; + // A function that can be set by a client of the Solver to provide indication + // that it wants a snapshot saved and/or to exit early. + ActionCallback action_request_function_; + + // True iff a request to stop early was received. + bool requested_early_exit_; + DISABLE_COPY_AND_ASSIGN(Solver); }; diff --git a/include/caffe/util/signal_handler.h b/include/caffe/util/signal_handler.h new file mode 100644 index 00000000..fb84c65b --- /dev/null +++ b/include/caffe/util/signal_handler.h @@ -0,0 +1,24 @@ +#ifndef INCLUDE_CAFFE_UTIL_SIGNAL_HANDLER_H_ +#define INCLUDE_CAFFE_UTIL_SIGNAL_HANDLER_H_ + +#include "caffe/proto/caffe.pb.h" +#include "caffe/solver.hpp" + +namespace caffe { + +class SignalHandler { + public: + // Contructor. Specify what action to take when a signal is received. + SignalHandler(SolverAction::Enum SIGINT_action, + SolverAction::Enum SIGHUP_action); + ~SignalHandler(); + ActionCallback GetActionFunction(); + private: + SolverAction::Enum CheckForSignals() const; + SolverAction::Enum SIGINT_action_; + SolverAction::Enum SIGHUP_action_; +}; + +} // namespace caffe + +#endif // INCLUDE_CAFFE_UTIL_SIGNAL_HANDLER_H_ |