summaryrefslogtreecommitdiff
path: root/include
diff options
context:
space:
mode:
authorRonghang Hu <huronghang@hotmail.com>2015-09-24 19:40:45 -0700
committerRonghang Hu <huronghang@hotmail.com>2015-10-16 22:32:32 -0700
commit0eea815ad6fa3313888b6229499a237820258deb (patch)
treeb76664a58993a0b0658a01f24dd3d0f8898ceac1 /include
parentb822a702d19d4fbebbc91198a991f91c34e60650 (diff)
downloadcaffeonacl-0eea815ad6fa3313888b6229499a237820258deb.tar.gz
caffeonacl-0eea815ad6fa3313888b6229499a237820258deb.tar.bz2
caffeonacl-0eea815ad6fa3313888b6229499a237820258deb.zip
Change solver type to string and provide solver registry
Diffstat (limited to 'include')
-rw-r--r--include/caffe/caffe.hpp1
-rw-r--r--include/caffe/sgd_solvers.hpp6
-rw-r--r--include/caffe/solver.hpp9
-rw-r--r--include/caffe/solver_factory.hpp137
4 files changed, 149 insertions, 4 deletions
diff --git a/include/caffe/caffe.hpp b/include/caffe/caffe.hpp
index 68a5e1d1..bd772830 100644
--- a/include/caffe/caffe.hpp
+++ b/include/caffe/caffe.hpp
@@ -13,6 +13,7 @@
#include "caffe/parallel.hpp"
#include "caffe/proto/caffe.pb.h"
#include "caffe/solver.hpp"
+#include "caffe/solver_factory.hpp"
#include "caffe/util/benchmark.hpp"
#include "caffe/util/io.hpp"
#include "caffe/vision_layers.hpp"
diff --git a/include/caffe/sgd_solvers.hpp b/include/caffe/sgd_solvers.hpp
index 6bf1d70c..1fc52d87 100644
--- a/include/caffe/sgd_solvers.hpp
+++ b/include/caffe/sgd_solvers.hpp
@@ -19,6 +19,7 @@ class SGDSolver : public Solver<Dtype> {
: Solver<Dtype>(param) { PreSolve(); }
explicit SGDSolver(const string& param_file)
: Solver<Dtype>(param_file) { PreSolve(); }
+ virtual inline const char* type() const { return "SGD"; }
const vector<shared_ptr<Blob<Dtype> > >& history() { return history_; }
@@ -51,6 +52,7 @@ class NesterovSolver : public SGDSolver<Dtype> {
: SGDSolver<Dtype>(param) {}
explicit NesterovSolver(const string& param_file)
: SGDSolver<Dtype>(param_file) {}
+ virtual inline const char* type() const { return "Nesterov"; }
protected:
virtual void ComputeUpdateValue(int param_id, Dtype rate);
@@ -65,6 +67,7 @@ class AdaGradSolver : public SGDSolver<Dtype> {
: SGDSolver<Dtype>(param) { constructor_sanity_check(); }
explicit AdaGradSolver(const string& param_file)
: SGDSolver<Dtype>(param_file) { constructor_sanity_check(); }
+ virtual inline const char* type() const { return "AdaGrad"; }
protected:
virtual void ComputeUpdateValue(int param_id, Dtype rate);
@@ -84,6 +87,7 @@ class RMSPropSolver : public SGDSolver<Dtype> {
: SGDSolver<Dtype>(param) { constructor_sanity_check(); }
explicit RMSPropSolver(const string& param_file)
: SGDSolver<Dtype>(param_file) { constructor_sanity_check(); }
+ virtual inline const char* type() const { return "RMSProp"; }
protected:
virtual void ComputeUpdateValue(int param_id, Dtype rate);
@@ -106,6 +110,7 @@ class AdaDeltaSolver : public SGDSolver<Dtype> {
: SGDSolver<Dtype>(param) { AdaDeltaPreSolve(); }
explicit AdaDeltaSolver(const string& param_file)
: SGDSolver<Dtype>(param_file) { AdaDeltaPreSolve(); }
+ virtual inline const char* type() const { return "AdaDelta"; }
protected:
void AdaDeltaPreSolve();
@@ -129,6 +134,7 @@ class AdamSolver : public SGDSolver<Dtype> {
: SGDSolver<Dtype>(param) { AdamPreSolve();}
explicit AdamSolver(const string& param_file)
: SGDSolver<Dtype>(param_file) { AdamPreSolve(); }
+ virtual inline const char* type() const { return "Adam"; }
protected:
void AdamPreSolve();
diff --git a/include/caffe/solver.hpp b/include/caffe/solver.hpp
index a045ccf2..298a68f3 100644
--- a/include/caffe/solver.hpp
+++ b/include/caffe/solver.hpp
@@ -5,6 +5,7 @@
#include <vector>
#include "caffe/net.hpp"
+#include "caffe/solver_factory.hpp"
namespace caffe {
@@ -83,6 +84,10 @@ class Solver {
}
void CheckSnapshotWritePermissions();
+ /**
+ * @brief Returns the solver type.
+ */
+ virtual inline const char* type() const { return ""; }
protected:
// Make and apply the update value for the current iteration.
@@ -148,10 +153,6 @@ class WorkerSolver : public Solver<Dtype> {
}
};
-// The solver factory function
-template <typename Dtype>
-Solver<Dtype>* GetSolver(const SolverParameter& param);
-
} // namespace caffe
#endif // CAFFE_SOLVER_HPP_
diff --git a/include/caffe/solver_factory.hpp b/include/caffe/solver_factory.hpp
new file mode 100644
index 00000000..cfff721a
--- /dev/null
+++ b/include/caffe/solver_factory.hpp
@@ -0,0 +1,137 @@
+/**
+ * @brief A solver factory that allows one to register solvers, similar to
+ * layer factory. During runtime, registered solvers could be called by passing
+ * a SolverParameter protobuffer to the CreateSolver function:
+ *
+ * SolverRegistry<Dtype>::CreateSolver(param);
+ *
+ * There are two ways to register a solver. Assuming that we have a solver like:
+ *
+ * template <typename Dtype>
+ * class MyAwesomeSolver : public Solver<Dtype> {
+ * // your implementations
+ * };
+ *
+ * and its type is its C++ class name, but without the "Solver" at the end
+ * ("MyAwesomeSolver" -> "MyAwesome").
+ *
+ * If the solver is going to be created simply by its constructor, in your c++
+ * file, add the following line:
+ *
+ * REGISTER_SOLVER_CLASS(MyAwesome);
+ *
+ * Or, if the solver is going to be created by another creator function, in the
+ * format of:
+ *
+ * template <typename Dtype>
+ * Solver<Dtype*> GetMyAwesomeSolver(const SolverParameter& param) {
+ * // your implementation
+ * }
+ *
+ * then you can register the creator function instead, like
+ *
+ * REGISTER_SOLVER_CREATOR(MyAwesome, GetMyAwesomeSolver)
+ *
+ * Note that each solver type should only be registered once.
+ */
+
+#ifndef CAFFE_SOLVER_FACTORY_H_
+#define CAFFE_SOLVER_FACTORY_H_
+
+#include <map>
+#include <string>
+#include <vector>
+
+#include "caffe/common.hpp"
+#include "caffe/proto/caffe.pb.h"
+
+namespace caffe {
+
+template <typename Dtype>
+class Solver;
+
+template <typename Dtype>
+class SolverRegistry {
+ public:
+ typedef Solver<Dtype>* (*Creator)(const SolverParameter&);
+ typedef std::map<string, Creator> CreatorRegistry;
+
+ static CreatorRegistry& Registry() {
+ static CreatorRegistry* g_registry_ = new CreatorRegistry();
+ return *g_registry_;
+ }
+
+ // Adds a creator.
+ static void AddCreator(const string& type, Creator creator) {
+ CreatorRegistry& registry = Registry();
+ CHECK_EQ(registry.count(type), 0)
+ << "Solver type " << type << " already registered.";
+ registry[type] = creator;
+ }
+
+ // Get a solver using a SolverParameter.
+ static Solver<Dtype>* CreateSolver(const SolverParameter& param) {
+ const string& type = param.type();
+ CreatorRegistry& registry = Registry();
+ CHECK_EQ(registry.count(type), 1) << "Unknown solver type: " << type
+ << " (known types: " << SolverTypeListString() << ")";
+ return registry[type](param);
+ }
+
+ static vector<string> SolverTypeList() {
+ CreatorRegistry& registry = Registry();
+ vector<string> solver_types;
+ for (typename CreatorRegistry::iterator iter = registry.begin();
+ iter != registry.end(); ++iter) {
+ solver_types.push_back(iter->first);
+ }
+ return solver_types;
+ }
+
+ private:
+ // Solver registry should never be instantiated - everything is done with its
+ // static variables.
+ SolverRegistry() {}
+
+ static string SolverTypeListString() {
+ vector<string> solver_types = SolverTypeList();
+ string solver_types_str;
+ for (vector<string>::iterator iter = solver_types.begin();
+ iter != solver_types.end(); ++iter) {
+ if (iter != solver_types.begin()) {
+ solver_types_str += ", ";
+ }
+ solver_types_str += *iter;
+ }
+ return solver_types_str;
+ }
+};
+
+
+template <typename Dtype>
+class SolverRegisterer {
+ public:
+ SolverRegisterer(const string& type,
+ Solver<Dtype>* (*creator)(const SolverParameter&)) {
+ // LOG(INFO) << "Registering solver type: " << type;
+ SolverRegistry<Dtype>::AddCreator(type, creator);
+ }
+};
+
+
+#define REGISTER_SOLVER_CREATOR(type, creator) \
+ static SolverRegisterer<float> g_creator_f_##type(#type, creator<float>); \
+ static SolverRegisterer<double> g_creator_d_##type(#type, creator<double>) \
+
+#define REGISTER_SOLVER_CLASS(type) \
+ template <typename Dtype> \
+ Solver<Dtype>* Creator_##type##Solver( \
+ const SolverParameter& param) \
+ { \
+ return new type##Solver<Dtype>(param); \
+ } \
+ REGISTER_SOLVER_CREATOR(type, Creator_##type##Solver)
+
+} // namespace caffe
+
+#endif // CAFFE_SOLVER_FACTORY_H_