summaryrefslogtreecommitdiff
path: root/include/caffe/solver.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'include/caffe/solver.hpp')
-rw-r--r--include/caffe/solver.hpp37
1 files changed, 36 insertions, 1 deletions
diff --git a/include/caffe/solver.hpp b/include/caffe/solver.hpp
index ab12ef1b..aba3e036 100644
--- a/include/caffe/solver.hpp
+++ b/include/caffe/solver.hpp
@@ -1,6 +1,6 @@
#ifndef CAFFE_OPTIMIZATION_SOLVER_HPP_
#define CAFFE_OPTIMIZATION_SOLVER_HPP_
-
+#include <boost/function.hpp>
#include <string>
#include <vector>
@@ -9,6 +9,28 @@
namespace caffe {
/**
+ * @brief Enumeration of actions that a client of the Solver may request by
+ * implementing the Solver's action request function, which a
+ * a client may optionally provide in order to request early termination
+ * or saving a snapshot without exiting. In the executable caffe, this
+ * mechanism is used to allow the snapshot to be saved when stopping
+ * execution with a SIGINT (Ctrl-C).
+ */
+ namespace SolverAction {
+ enum Enum {
+ NONE = 0, // Take no special action.
+ STOP = 1, // Stop training. snapshot_after_train controls whether a
+ // snapshot is created.
+ SNAPSHOT = 2 // Take a snapshot, and keep training.
+ };
+ }
+
+/**
+ * @brief Type of a function that returns a Solver Action enumeration.
+ */
+typedef boost::function<SolverAction::Enum()> ActionCallback;
+
+/**
* @brief An interface for classes that perform optimization on Net%s.
*
* Requires implementation of ApplyUpdate to compute a parameter update
@@ -23,6 +45,12 @@ class Solver {
void Init(const SolverParameter& param);
void InitTrainNet();
void InitTestNets();
+
+ // Client of the Solver optionally may call this in order to set the function
+ // that the solver uses to see what action it should take (e.g. snapshot or
+ // exit training early).
+ void SetActionFunction(ActionCallback func);
+ SolverAction::Enum GetRequestedAction();
// The main entry of the solver function. In default, iter will be zero. Pass
// in a non-zero iter number to resume training for a pre-trained net.
virtual void Solve(const char* resume_file = NULL);
@@ -84,6 +112,13 @@ class Solver {
// in data parallelism
const Solver* const root_solver_;
+ // A function that can be set by a client of the Solver to provide indication
+ // that it wants a snapshot saved and/or to exit early.
+ ActionCallback action_request_function_;
+
+ // True iff a request to stop early was received.
+ bool requested_early_exit_;
+
DISABLE_COPY_AND_ASSIGN(Solver);
};