#ifndef CAFFE_OPTIMIZATION_SOLVER_HPP_
#define CAFFE_OPTIMIZATION_SOLVER_HPP_

#include <boost/function.hpp>
#include <string>
#include <vector>

#include "caffe/net.hpp"

namespace caffe {

/**
 * @brief Enumeration of actions that a client of the Solver may request by
 * implementing the Solver's action request function, which a
 * client may optionally provide in order to request early termination
 * or saving a snapshot without exiting. In the executable caffe, this
 * mechanism is used to allow the snapshot to be saved when stopping
 * execution with a SIGINT (Ctrl-C).
 */
namespace SolverAction {
  enum Enum {
    NONE = 0,  // Take no special action.
    STOP = 1,  // Stop training. snapshot_after_train controls whether a
               // snapshot is created.
    SNAPSHOT = 2  // Take a snapshot, and keep training.
  };
}

/**
 * @brief Type of a function that returns a Solver Action enumeration.
 */
typedef boost::function<SolverAction::Enum()> ActionCallback;
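
// Illustrative sketch (not part of the original header): a client can install
// an action callback so that training stops cleanly on Ctrl-C. The flag and
// handler names below are hypothetical.
//
//   static volatile sig_atomic_t got_sigint = 0;
//   void HandleSigint(int) { got_sigint = 1; }  // registered via signal()
//
//   SolverAction::Enum RequestedAction() {
//     return got_sigint ? SolverAction::STOP : SolverAction::NONE;
//   }
//   ...
//   solver->SetActionFunction(RequestedAction);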

/**
 * @brief An interface for classes that perform optimization on Net%s.
 *
 * Requires implementation of ApplyUpdate to compute a parameter update
 * given the current state of the Net parameters.
 */
template <typename Dtype>
class Solver {
 public:
  explicit Solver(const SolverParameter& param,
      const Solver* root_solver = NULL);
  explicit Solver(const string& param_file, const Solver* root_solver = NULL);
  void Init(const SolverParameter& param);
  void InitTrainNet();
  void InitTestNets();

  // Client of the Solver optionally may call this in order to set the function
  // that the solver uses to see what action it should take (e.g. snapshot or
  // exit training early).
  void SetActionFunction(ActionCallback func);
  SolverAction::Enum GetRequestedAction();

  // The main entry of the solver function. By default, iter will be zero. Pass
  // in a non-zero iter number to resume training for a pre-trained net.
  virtual void Solve(const char* resume_file = NULL);
  inline void Solve(const string resume_file) { Solve(resume_file.c_str()); }
  void Step(int iters);

  // The Restore method simply dispatches to one of the
  // RestoreSolverStateFrom___ protected methods. You should implement these
  // methods to restore the state from the appropriate snapshot type.
  void Restore(const char* resume_file);
  virtual ~Solver() {}
  inline const SolverParameter& param() const { return param_; }
  inline shared_ptr<Net<Dtype> > net() { return net_; }
  inline const vector<shared_ptr<Net<Dtype> > >& test_nets() {
    return test_nets_;
  }
  int iter() { return iter_; }

  // Invoked at specific points during an iteration
  class Callback {
   protected:
    virtual void on_start() = 0;
    virtual void on_gradients_ready() = 0;

    template <typename T>
    friend class Solver;
  };
  const vector<Callback*>& callbacks() const { return callbacks_; }
  void add_callback(Callback* value) { callbacks_.push_back(value); }

 protected:
  // Make and apply the update value for the current iteration.
  virtual void ApplyUpdate() = 0;

  // The Solver::Snapshot function implements the basic snapshotting utility
  // that stores the learned net. You should implement the SnapshotSolverState()
  // function that produces a SolverState protocol buffer that needs to be
  // written to disk together with the learned net.
  void Snapshot();
  string SnapshotFilename(const string extension);
  string SnapshotToBinaryProto();
  string SnapshotToHDF5();

  // The test routine
  void TestAll();
  void Test(const int test_net_id = 0);

  virtual void SnapshotSolverState(const string& model_filename) = 0;
  virtual void RestoreSolverStateFromHDF5(const string& state_file) = 0;
  virtual void RestoreSolverStateFromBinaryProto(const string& state_file) = 0;
  void DisplayOutputBlobs(const int net_id);

  SolverParameter param_;
  int iter_;
  int current_step_;
  shared_ptr<Net<Dtype> > net_;
  vector<shared_ptr<Net<Dtype> > > test_nets_;
  vector<Callback*> callbacks_;

  // The root solver that holds root nets (actually containing shared layers)
  // in data parallelism.
  const Solver* const root_solver_;

  // A function that can be set by a client of the Solver to provide indication
  // that it wants a snapshot saved and/or to exit early.
  ActionCallback action_request_function_;

  // True iff a request to stop early was received.
  bool requested_early_exit_;

  DISABLE_COPY_AND_ASSIGN(Solver);
};

/**
 * @brief Solver that only computes gradients, used as worker
 *        for multi-GPU training.
 */
template <typename Dtype>
class WorkerSolver : public Solver<Dtype> {
 public:
  explicit WorkerSolver(const SolverParameter& param,
      const Solver<Dtype>* root_solver = NULL)
      : Solver<Dtype>(param, root_solver) {}

 protected:
  void ApplyUpdate() {}
  void SnapshotSolverState(const string& model_filename) {
    LOG(FATAL) << "Should not be called on worker solver.";
  }
  void RestoreSolverStateFromBinaryProto(const string& state_file) {
    LOG(FATAL) << "Should not be called on worker solver.";
  }
  void RestoreSolverStateFromHDF5(const string& state_file) {
    LOG(FATAL) << "Should not be called on worker solver.";
  }
};
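
// Illustrative note (not in the original header): ComputeUpdateValue in the
// SGDSolver declared below implements SGD with momentum. With learning rate
// alpha and momentum mu, the update is
//
//   V_{t+1} = mu * V_t - alpha * grad L(W_t)
//   W_{t+1} = W_t + V_{t+1}
//
// where history_ stores V_t between iterations.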

/**
 * @brief Optimizes the parameters of a Net using
 *        stochastic gradient descent (SGD) with momentum.
 */
template <typename Dtype>
class SGDSolver : public Solver<Dtype> {
 public:
  explicit SGDSolver(const SolverParameter& param)
      : Solver<Dtype>(param) { PreSolve(); }
  explicit SGDSolver(const string& param_file)
      : Solver<Dtype>(param_file) { PreSolve(); }

  const vector<shared_ptr<Blob<Dtype> > >& history() { return history_; }

 protected:
  void PreSolve();
  Dtype GetLearningRate();
  virtual void ApplyUpdate();
  virtual void Normalize(int param_id);
  virtual void Regularize(int param_id);
  virtual void ComputeUpdateValue(int param_id, Dtype rate);
  virtual void ClipGradients();
  virtual void SnapshotSolverState(const string& model_filename);
  virtual void SnapshotSolverStateToBinaryProto(const string& model_filename);
  virtual void SnapshotSolverStateToHDF5(const string& model_filename);
  virtual void RestoreSolverStateFromHDF5(const string& state_file);
  virtual void RestoreSolverStateFromBinaryProto(const string& state_file);
  // history maintains the historical momentum data.
  // update maintains update related data and is not needed in snapshots.
  // temp maintains other information that might be needed in computation
  //   of gradients/updates and is not needed in snapshots.
  vector<shared_ptr<Blob<Dtype> > > history_, update_, temp_;

  DISABLE_COPY_AND_ASSIGN(SGDSolver);
};

template <typename Dtype>
class NesterovSolver : public SGDSolver<Dtype> {
 public:
  explicit NesterovSolver(const SolverParameter& param)
      : SGDSolver<Dtype>(param) {}
  explicit NesterovSolver(const string& param_file)
      : SGDSolver<Dtype>(param_file) {}

 protected:
  virtual void ComputeUpdateValue(int param_id, Dtype rate);

  DISABLE_COPY_AND_ASSIGN(NesterovSolver);
};

template <typename Dtype>
class AdaGradSolver : public SGDSolver<Dtype> {
 public:
  explicit AdaGradSolver(const SolverParameter& param)
      : SGDSolver<Dtype>(param) { constructor_sanity_check(); }
  explicit AdaGradSolver(const string& param_file)
      : SGDSolver<Dtype>(param_file) { constructor_sanity_check(); }

 protected:
  virtual void ComputeUpdateValue(int param_id, Dtype rate);
  void constructor_sanity_check() {
    CHECK_EQ(0, this->param_.momentum())
        << "Momentum cannot be used with AdaGrad.";
  }

  DISABLE_COPY_AND_ASSIGN(AdaGradSolver);
};

template <typename Dtype>
class RMSPropSolver : public SGDSolver<Dtype> {
 public:
  explicit RMSPropSolver(const SolverParameter& param)
      : SGDSolver<Dtype>(param) { constructor_sanity_check(); }
  explicit RMSPropSolver(const string& param_file)
      : SGDSolver<Dtype>(param_file) { constructor_sanity_check(); }

 protected:
  virtual void ComputeUpdateValue(int param_id, Dtype rate);
  void constructor_sanity_check() {
    CHECK_EQ(0, this->param_.momentum())
        << "Momentum cannot be used with RMSProp.";
    CHECK_GE(this->param_.rms_decay(), 0)
        << "rms_decay should lie between 0 and 1.";
    CHECK_LT(this->param_.rms_decay(), 1)
        << "rms_decay should lie between 0 and 1.";
  }

  DISABLE_COPY_AND_ASSIGN(RMSPropSolver);
};

template <typename Dtype>
class AdaDeltaSolver : public SGDSolver<Dtype> {
 public:
  explicit AdaDeltaSolver(const SolverParameter& param)
      : SGDSolver<Dtype>(param) { AdaDeltaPreSolve(); }
  explicit AdaDeltaSolver(const string& param_file)
      : SGDSolver<Dtype>(param_file) { AdaDeltaPreSolve(); }

 protected:
  void AdaDeltaPreSolve();
  virtual void ComputeUpdateValue(int param_id, Dtype rate);

  DISABLE_COPY_AND_ASSIGN(AdaDeltaSolver);
};

/**
 * @brief AdamSolver, an algorithm for first-order gradient-based optimization
 *        of stochastic objective functions, based on adaptive estimates of
 *        lower-order moments. Described in [1].
 *
 * [1] D. P. Kingma and J. L. Ba, "ADAM: A Method for Stochastic Optimization."
 *     arXiv preprint arXiv:1412.6980v8 (2014).
 */
template <typename Dtype>
class AdamSolver : public SGDSolver<Dtype> {
 public:
  explicit AdamSolver(const SolverParameter& param)
      : SGDSolver<Dtype>(param) { AdamPreSolve(); }
  explicit AdamSolver(const string& param_file)
      : SGDSolver<Dtype>(param_file) { AdamPreSolve(); }

 protected:
  void AdamPreSolve();
  virtual void ComputeUpdateValue(int param_id, Dtype rate);

  DISABLE_COPY_AND_ASSIGN(AdamSolver);
};

template <typename Dtype>
Solver<Dtype>* GetSolver(const SolverParameter& param) {
  SolverParameter_SolverType type = param.solver_type();

  switch (type) {
  case SolverParameter_SolverType_SGD:
    return new SGDSolver<Dtype>(param);
  case SolverParameter_SolverType_NESTEROV:
    return new NesterovSolver<Dtype>(param);
  case SolverParameter_SolverType_ADAGRAD:
    return new AdaGradSolver<Dtype>(param);
  case SolverParameter_SolverType_RMSPROP:
    return new RMSPropSolver<Dtype>(param);
  case SolverParameter_SolverType_ADADELTA:
    return new AdaDeltaSolver<Dtype>(param);
  case SolverParameter_SolverType_ADAM:
    return new AdamSolver<Dtype>(param);
  default:
    LOG(FATAL) << "Unknown SolverType: " << type;
  }
  return (Solver<Dtype>*) NULL;
}

}  // namespace caffe

#endif  // CAFFE_OPTIMIZATION_SOLVER_HPP_
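
// Usage sketch (illustrative; not part of this header): constructing and
// running a solver through the GetSolver factory above. The prototxt path is
// hypothetical; ReadProtoFromTextFileOrDie comes from caffe/util/io.hpp.
//
//   caffe::SolverParameter solver_param;
//   caffe::ReadProtoFromTextFileOrDie("solver.prototxt", &solver_param);
//   boost::shared_ptr<caffe::Solver<float> >
//       solver(caffe::GetSolver<float>(solver_param));
//   solver->Solve();  // pass a .solverstate path to resume from a snapshot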