diff options
author | Ronghang Hu <huronghang@hotmail.com> | 2015-09-24 19:40:45 -0700 |
---|---|---|
committer | Ronghang Hu <huronghang@hotmail.com> | 2015-10-16 22:32:32 -0700 |
commit | 0eea815ad6fa3313888b6229499a237820258deb (patch) | |
tree | b76664a58993a0b0658a01f24dd3d0f8898ceac1 | |
parent | b822a702d19d4fbebbc91198a991f91c34e60650 (diff) | |
download | caffeonacl-0eea815ad6fa3313888b6229499a237820258deb.tar.gz caffeonacl-0eea815ad6fa3313888b6229499a237820258deb.tar.bz2 caffeonacl-0eea815ad6fa3313888b6229499a237820258deb.zip |
Change solver type to string and provide solver registry
-rw-r--r-- | include/caffe/caffe.hpp | 1 | ||||
-rw-r--r-- | include/caffe/sgd_solvers.hpp | 6 | ||||
-rw-r--r-- | include/caffe/solver.hpp | 9 | ||||
-rw-r--r-- | include/caffe/solver_factory.hpp | 137 | ||||
-rw-r--r-- | src/caffe/proto/caffe.proto | 27 | ||||
-rw-r--r-- | src/caffe/solver_factory.cpp | 32 | ||||
-rw-r--r-- | src/caffe/solvers/adadelta_solver.cpp | 1 | ||||
-rw-r--r-- | src/caffe/solvers/adagrad_solver.cpp | 1 | ||||
-rw-r--r-- | src/caffe/solvers/adam_solver.cpp | 1 | ||||
-rw-r--r-- | src/caffe/solvers/nesterov_solver.cpp | 1 | ||||
-rw-r--r-- | src/caffe/solvers/rmsprop_solver.cpp | 1 | ||||
-rw-r--r-- | src/caffe/solvers/sgd_solver.cpp | 1 | ||||
-rw-r--r-- | src/caffe/test/test_gradient_based_solver.cpp | 54 | ||||
-rw-r--r-- | src/caffe/test/test_solver_factory.cpp | 50 | ||||
-rw-r--r-- | tools/caffe.cpp | 2 |
15 files changed, 233 insertions, 91 deletions
diff --git a/include/caffe/caffe.hpp b/include/caffe/caffe.hpp index 68a5e1d1..bd772830 100644 --- a/include/caffe/caffe.hpp +++ b/include/caffe/caffe.hpp @@ -13,6 +13,7 @@ #include "caffe/parallel.hpp" #include "caffe/proto/caffe.pb.h" #include "caffe/solver.hpp" +#include "caffe/solver_factory.hpp" #include "caffe/util/benchmark.hpp" #include "caffe/util/io.hpp" #include "caffe/vision_layers.hpp" diff --git a/include/caffe/sgd_solvers.hpp b/include/caffe/sgd_solvers.hpp index 6bf1d70c..1fc52d87 100644 --- a/include/caffe/sgd_solvers.hpp +++ b/include/caffe/sgd_solvers.hpp @@ -19,6 +19,7 @@ class SGDSolver : public Solver<Dtype> { : Solver<Dtype>(param) { PreSolve(); } explicit SGDSolver(const string& param_file) : Solver<Dtype>(param_file) { PreSolve(); } + virtual inline const char* type() const { return "SGD"; } const vector<shared_ptr<Blob<Dtype> > >& history() { return history_; } @@ -51,6 +52,7 @@ class NesterovSolver : public SGDSolver<Dtype> { : SGDSolver<Dtype>(param) {} explicit NesterovSolver(const string& param_file) : SGDSolver<Dtype>(param_file) {} + virtual inline const char* type() const { return "Nesterov"; } protected: virtual void ComputeUpdateValue(int param_id, Dtype rate); @@ -65,6 +67,7 @@ class AdaGradSolver : public SGDSolver<Dtype> { : SGDSolver<Dtype>(param) { constructor_sanity_check(); } explicit AdaGradSolver(const string& param_file) : SGDSolver<Dtype>(param_file) { constructor_sanity_check(); } + virtual inline const char* type() const { return "AdaGrad"; } protected: virtual void ComputeUpdateValue(int param_id, Dtype rate); @@ -84,6 +87,7 @@ class RMSPropSolver : public SGDSolver<Dtype> { : SGDSolver<Dtype>(param) { constructor_sanity_check(); } explicit RMSPropSolver(const string& param_file) : SGDSolver<Dtype>(param_file) { constructor_sanity_check(); } + virtual inline const char* type() const { return "RMSProp"; } protected: virtual void ComputeUpdateValue(int param_id, Dtype rate); @@ -106,6 +110,7 @@ class AdaDeltaSolver : public SGDSolver<Dtype> { : SGDSolver<Dtype>(param) { AdaDeltaPreSolve(); } explicit AdaDeltaSolver(const string& param_file) : SGDSolver<Dtype>(param_file) { AdaDeltaPreSolve(); } + virtual inline const char* type() const { return "AdaDelta"; } protected: void AdaDeltaPreSolve(); @@ -129,6 +134,7 @@ class AdamSolver : public SGDSolver<Dtype> { : SGDSolver<Dtype>(param) { AdamPreSolve();} explicit AdamSolver(const string& param_file) : SGDSolver<Dtype>(param_file) { AdamPreSolve(); } + virtual inline const char* type() const { return "Adam"; } protected: void AdamPreSolve(); diff --git a/include/caffe/solver.hpp b/include/caffe/solver.hpp index a045ccf2..298a68f3 100644 --- a/include/caffe/solver.hpp +++ b/include/caffe/solver.hpp @@ -5,6 +5,7 @@ #include <vector> #include "caffe/net.hpp" +#include "caffe/solver_factory.hpp" namespace caffe { @@ -83,6 +84,10 @@ class Solver { } void CheckSnapshotWritePermissions(); + /** + * @brief Returns the solver type. + */ + virtual inline const char* type() const { return ""; } protected: // Make and apply the update value for the current iteration. @@ -148,10 +153,6 @@ class WorkerSolver : public Solver<Dtype> { } }; -// The solver factory function -template <typename Dtype> -Solver<Dtype>* GetSolver(const SolverParameter& param); - } // namespace caffe #endif // CAFFE_SOLVER_HPP_ diff --git a/include/caffe/solver_factory.hpp b/include/caffe/solver_factory.hpp new file mode 100644 index 00000000..cfff721a --- /dev/null +++ b/include/caffe/solver_factory.hpp @@ -0,0 +1,137 @@ +/** + * @brief A solver factory that allows one to register solvers, similar to + * layer factory. During runtime, registered solvers could be called by passing + * a SolverParameter protobuffer to the CreateSolver function: + * + * SolverRegistry<Dtype>::CreateSolver(param); + * + * There are two ways to register a solver. Assuming that we have a solver like: + * + * template <typename Dtype> + * class MyAwesomeSolver : public Solver<Dtype> { + * // your implementations + * }; + * + * and its type is its C++ class name, but without the "Solver" at the end + * ("MyAwesomeSolver" -> "MyAwesome"). + * + * If the solver is going to be created simply by its constructor, in your c++ + * file, add the following line: + * + * REGISTER_SOLVER_CLASS(MyAwesome); + * + * Or, if the solver is going to be created by another creator function, in the + * format of: + * + * template <typename Dtype> + * Solver<Dtype*> GetMyAwesomeSolver(const SolverParameter& param) { + * // your implementation + * } + * + * then you can register the creator function instead, like + * + * REGISTER_SOLVER_CREATOR(MyAwesome, GetMyAwesomeSolver) + * + * Note that each solver type should only be registered once. + */ + +#ifndef CAFFE_SOLVER_FACTORY_H_ +#define CAFFE_SOLVER_FACTORY_H_ + +#include <map> +#include <string> +#include <vector> + +#include "caffe/common.hpp" +#include "caffe/proto/caffe.pb.h" + +namespace caffe { + +template <typename Dtype> +class Solver; + +template <typename Dtype> +class SolverRegistry { + public: + typedef Solver<Dtype>* (*Creator)(const SolverParameter&); + typedef std::map<string, Creator> CreatorRegistry; + + static CreatorRegistry& Registry() { + static CreatorRegistry* g_registry_ = new CreatorRegistry(); + return *g_registry_; + } + + // Adds a creator. + static void AddCreator(const string& type, Creator creator) { + CreatorRegistry& registry = Registry(); + CHECK_EQ(registry.count(type), 0) + << "Solver type " << type << " already registered."; + registry[type] = creator; + } + + // Get a solver using a SolverParameter. + static Solver<Dtype>* CreateSolver(const SolverParameter& param) { + const string& type = param.type(); + CreatorRegistry& registry = Registry(); + CHECK_EQ(registry.count(type), 1) << "Unknown solver type: " << type + << " (known types: " << SolverTypeListString() << ")"; + return registry[type](param); + } + + static vector<string> SolverTypeList() { + CreatorRegistry& registry = Registry(); + vector<string> solver_types; + for (typename CreatorRegistry::iterator iter = registry.begin(); + iter != registry.end(); ++iter) { + solver_types.push_back(iter->first); + } + return solver_types; + } + + private: + // Solver registry should never be instantiated - everything is done with its + // static variables. + SolverRegistry() {} + + static string SolverTypeListString() { + vector<string> solver_types = SolverTypeList(); + string solver_types_str; + for (vector<string>::iterator iter = solver_types.begin(); + iter != solver_types.end(); ++iter) { + if (iter != solver_types.begin()) { + solver_types_str += ", "; + } + solver_types_str += *iter; + } + return solver_types_str; + } +}; + + +template <typename Dtype> +class SolverRegisterer { + public: + SolverRegisterer(const string& type, + Solver<Dtype>* (*creator)(const SolverParameter&)) { + // LOG(INFO) << "Registering solver type: " << type; + SolverRegistry<Dtype>::AddCreator(type, creator); + } +}; + + +#define REGISTER_SOLVER_CREATOR(type, creator) \ + static SolverRegisterer<float> g_creator_f_##type(#type, creator<float>); \ + static SolverRegisterer<double> g_creator_d_##type(#type, creator<double>) \ + +#define REGISTER_SOLVER_CLASS(type) \ + template <typename Dtype> \ + Solver<Dtype>* Creator_##type##Solver( \ + const SolverParameter& param) \ + { \ + return new type##Solver<Dtype>(param); \ + } \ + REGISTER_SOLVER_CREATOR(type, Creator_##type##Solver) + +} // namespace caffe + +#endif // CAFFE_SOLVER_FACTORY_H_ diff --git a/src/caffe/proto/caffe.proto b/src/caffe/proto/caffe.proto index 4794991f..76c869c1 100644 --- a/src/caffe/proto/caffe.proto +++ b/src/caffe/proto/caffe.proto @@ -98,7 +98,7 @@ message NetParameter { // NOTE // Update the next available ID when you add a new SolverParameter field. // -// SolverParameter next available ID: 40 (last added: momentum2) +// SolverParameter next available ID: 41 (last added: type) message SolverParameter { ////////////////////////////////////////////////////////////////////////////// // Specifying the train and test networks @@ -209,16 +209,9 @@ message SolverParameter { // (and by default) initialize using a seed derived from the system clock. optional int64 random_seed = 20 [default = -1]; - // Solver type - enum SolverType { - SGD = 0; - NESTEROV = 1; - ADAGRAD = 2; - RMSPROP = 3; - ADADELTA = 4; - ADAM = 5; - } - optional SolverType solver_type = 30 [default = SGD]; + // type of the solver + optional string type = 40 [default = "SGD"]; + // numerical stability for RMSProp, AdaGrad and AdaDelta and Adam optional float delta = 31 [default = 1e-8]; // parameters for the Adam solver @@ -234,6 +227,18 @@ message SolverParameter { // If false, don't save a snapshot after training finishes. optional bool snapshot_after_train = 28 [default = true]; + + // DEPRECATED: old solver enum types, use string instead + enum SolverType { + SGD = 0; + NESTEROV = 1; + ADAGRAD = 2; + RMSPROP = 3; + ADADELTA = 4; + ADAM = 5; + } + // DEPRECATED: use type instead of solver_type + optional SolverType solver_type = 30 [default = SGD]; } // A message that stores the solver snapshots diff --git a/src/caffe/solver_factory.cpp b/src/caffe/solver_factory.cpp deleted file mode 100644 index f78fab28..00000000 --- a/src/caffe/solver_factory.cpp +++ /dev/null @@ -1,32 +0,0 @@ -#include "caffe/solver.hpp" -#include "caffe/sgd_solvers.hpp" - -namespace caffe { - -template <typename Dtype> -Solver<Dtype>* GetSolver(const SolverParameter& param) { - SolverParameter_SolverType type = param.solver_type(); - - switch (type) { - case SolverParameter_SolverType_SGD: - return new SGDSolver<Dtype>(param); - case SolverParameter_SolverType_NESTEROV: - return new NesterovSolver<Dtype>(param); - case SolverParameter_SolverType_ADAGRAD: - return new AdaGradSolver<Dtype>(param); - case SolverParameter_SolverType_RMSPROP: - return new RMSPropSolver<Dtype>(param); - case SolverParameter_SolverType_ADADELTA: - return new AdaDeltaSolver<Dtype>(param); - case SolverParameter_SolverType_ADAM: - return new AdamSolver<Dtype>(param); - default: - LOG(FATAL) << "Unknown SolverType: " << type; - } - return (Solver<Dtype>*) NULL; -} - -template Solver<float>* GetSolver(const SolverParameter& param); -template Solver<double>* GetSolver(const SolverParameter& param); - -} // namespace caffe diff --git a/src/caffe/solvers/adadelta_solver.cpp b/src/caffe/solvers/adadelta_solver.cpp index 45cd4eb2..a37899eb 100644 --- a/src/caffe/solvers/adadelta_solver.cpp +++ b/src/caffe/solvers/adadelta_solver.cpp @@ -151,5 +151,6 @@ void AdaDeltaSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) { } INSTANTIATE_CLASS(AdaDeltaSolver); +REGISTER_SOLVER_CLASS(AdaDelta); } // namespace caffe diff --git a/src/caffe/solvers/adagrad_solver.cpp b/src/caffe/solvers/adagrad_solver.cpp index 627d816a..5e406326 100644 --- a/src/caffe/solvers/adagrad_solver.cpp +++ b/src/caffe/solvers/adagrad_solver.cpp @@ -84,5 +84,6 @@ void AdaGradSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) { } INSTANTIATE_CLASS(AdaGradSolver); +REGISTER_SOLVER_CLASS(AdaGrad); } // namespace caffe diff --git a/src/caffe/solvers/adam_solver.cpp b/src/caffe/solvers/adam_solver.cpp index 8c334f66..cb0fbfe2 100644 --- a/src/caffe/solvers/adam_solver.cpp +++ b/src/caffe/solvers/adam_solver.cpp @@ -108,5 +108,6 @@ void AdamSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) { } INSTANTIATE_CLASS(AdamSolver); +REGISTER_SOLVER_CLASS(Adam); } // namespace caffe diff --git a/src/caffe/solvers/nesterov_solver.cpp b/src/caffe/solvers/nesterov_solver.cpp index 8135ee2c..34bf01eb 100644 --- a/src/caffe/solvers/nesterov_solver.cpp +++ b/src/caffe/solvers/nesterov_solver.cpp @@ -66,5 +66,6 @@ void NesterovSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) { } INSTANTIATE_CLASS(NesterovSolver); +REGISTER_SOLVER_CLASS(Nesterov); } // namespace caffe diff --git a/src/caffe/solvers/rmsprop_solver.cpp b/src/caffe/solvers/rmsprop_solver.cpp index 96d1b3dd..c6247676 100644 --- a/src/caffe/solvers/rmsprop_solver.cpp +++ b/src/caffe/solvers/rmsprop_solver.cpp @@ -80,5 +80,6 @@ void RMSPropSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) { } INSTANTIATE_CLASS(RMSPropSolver); +REGISTER_SOLVER_CLASS(RMSProp); } // namespace caffe diff --git a/src/caffe/solvers/sgd_solver.cpp b/src/caffe/solvers/sgd_solver.cpp index 89ef5ec4..32bf19b1 100644 --- a/src/caffe/solvers/sgd_solver.cpp +++ b/src/caffe/solvers/sgd_solver.cpp @@ -343,5 +343,6 @@ void SGDSolver<Dtype>::RestoreSolverStateFromHDF5(const string& state_file) { } INSTANTIATE_CLASS(SGDSolver); +REGISTER_SOLVER_CLASS(SGD); } // namespace caffe diff --git a/src/caffe/test/test_gradient_based_solver.cpp b/src/caffe/test/test_gradient_based_solver.cpp index 1767ad3f..84c6747f 100644 --- a/src/caffe/test/test_gradient_based_solver.cpp +++ b/src/caffe/test/test_gradient_based_solver.cpp @@ -47,7 +47,6 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> { // Test data: check out generate_sample_data.py in the same directory. string* input_file_; - virtual SolverParameter_SolverType solver_type() = 0; virtual void InitSolver(const SolverParameter& param) = 0; virtual void InitSolverFromProtoString(const string& proto) { @@ -290,8 +289,8 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> { ((i == D) ? bias.cpu_data()[0] : weights.cpu_data()[i]); // Finally, compute update. const vector<shared_ptr<Blob<Dtype> > >& history = solver_->history(); - if (solver_type() != SolverParameter_SolverType_ADADELTA - && solver_type() != SolverParameter_SolverType_ADAM) { + if (solver_->type() != string("AdaDelta") + && solver_->type() != string("Adam")) { ASSERT_EQ(2, history.size()); // 1 blob for weights, 1 for bias } else { ASSERT_EQ(4, history.size()); // additional blobs for update history @@ -300,26 +299,19 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> { const Dtype history_value = (i == D) ? history[1]->cpu_data()[0] : history[0]->cpu_data()[i]; const Dtype temp = momentum * history_value; - switch (solver_type()) { - case SolverParameter_SolverType_SGD: + if (solver_->type() == string("SGD")) { update_value += temp; - break; - case SolverParameter_SolverType_NESTEROV: + } else if (solver_->type() == string("Nesterov")) { update_value += temp; // step back then over-step update_value = (1 + momentum) * update_value - temp; - break; - case SolverParameter_SolverType_ADAGRAD: + } else if (solver_->type() == string("AdaGrad")) { update_value /= std::sqrt(history_value + grad * grad) + delta_; - break; - case SolverParameter_SolverType_RMSPROP: { + } else if (solver_->type() == string("RMSProp")) { const Dtype rms_decay = 0.95; update_value /= std::sqrt(rms_decay*history_value + grad * grad * (1 - rms_decay)) + delta_; - } - break; - case SolverParameter_SolverType_ADADELTA: - { + } else if (solver_->type() == string("AdaDelta")) { const Dtype update_history_value = (i == D) ? history[1 + num_param_blobs]->cpu_data()[0] : history[0 + num_param_blobs]->cpu_data()[i]; @@ -330,9 +322,7 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> { // not actually needed, just here for illustrative purposes // const Dtype weighted_update_average = // momentum * update_history_value + (1 - momentum) * (update_value); - break; - } - case SolverParameter_SolverType_ADAM: { + } else if (solver_->type() == string("Adam")) { const Dtype momentum2 = 0.999; const Dtype m = history_value; const Dtype v = (i == D) ? @@ -344,10 +334,8 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> { std::sqrt(Dtype(1) - pow(momentum2, num_iters)) / (Dtype(1.) - pow(momentum, num_iters)); update_value = alpha_t * val_m / (std::sqrt(val_v) + delta_); - break; - } - default: - LOG(FATAL) << "Unknown solver type: " << solver_type(); + } else { + LOG(FATAL) << "Unknown solver type: " << solver_->type(); } if (i == D) { updated_bias.mutable_cpu_diff()[0] = update_value; @@ -392,7 +380,7 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> { EXPECT_NEAR(expected_updated_bias, solver_updated_bias, error_margin); // Check the solver's history -- should contain the previous update value. - if (solver_type() == SolverParameter_SolverType_SGD) { + if (solver_->type() == string("SGD")) { const vector<shared_ptr<Blob<Dtype> > >& history = solver_->history(); ASSERT_EQ(2, history.size()); for (int i = 0; i < D; ++i) { @@ -581,10 +569,6 @@ class SGDSolverTest : public GradientBasedSolverTest<TypeParam> { virtual void InitSolver(const SolverParameter& param) { this->solver_.reset(new SGDSolver<Dtype>(param)); } - - virtual SolverParameter_SolverType solver_type() { - return SolverParameter_SolverType_SGD; - } }; TYPED_TEST_CASE(SGDSolverTest, TestDtypesAndDevices); @@ -721,9 +705,6 @@ class AdaGradSolverTest : public GradientBasedSolverTest<TypeParam> { virtual void InitSolver(const SolverParameter& param) { this->solver_.reset(new AdaGradSolver<Dtype>(param)); } - virtual SolverParameter_SolverType solver_type() { - return SolverParameter_SolverType_ADAGRAD; - } }; TYPED_TEST_CASE(AdaGradSolverTest, TestDtypesAndDevices); @@ -824,9 +805,6 @@ class NesterovSolverTest : public GradientBasedSolverTest<TypeParam> { virtual void InitSolver(const SolverParameter& param) { this->solver_.reset(new NesterovSolver<Dtype>(param)); } - virtual SolverParameter_SolverType solver_type() { - return SolverParameter_SolverType_NESTEROV; - } }; TYPED_TEST_CASE(NesterovSolverTest, TestDtypesAndDevices); @@ -960,10 +938,6 @@ class AdaDeltaSolverTest : public GradientBasedSolverTest<TypeParam> { virtual void InitSolver(const SolverParameter& param) { this->solver_.reset(new AdaDeltaSolver<Dtype>(param)); } - - virtual SolverParameter_SolverType solver_type() { - return SolverParameter_SolverType_ADADELTA; - } }; TYPED_TEST_CASE(AdaDeltaSolverTest, TestDtypesAndDevices); @@ -1098,9 +1072,6 @@ class AdamSolverTest : public GradientBasedSolverTest<TypeParam> { new_param.set_momentum2(momentum2); this->solver_.reset(new AdamSolver<Dtype>(new_param)); } - virtual SolverParameter_SolverType solver_type() { - return SolverParameter_SolverType_ADAM; - } }; TYPED_TEST_CASE(AdamSolverTest, TestDtypesAndDevices); @@ -1201,9 +1172,6 @@ class RMSPropSolverTest : public GradientBasedSolverTest<TypeParam> { new_param.set_rms_decay(rms_decay); this->solver_.reset(new RMSPropSolver<Dtype>(new_param)); } - virtual SolverParameter_SolverType solver_type() { - return SolverParameter_SolverType_RMSPROP; - } }; TYPED_TEST_CASE(RMSPropSolverTest, TestDtypesAndDevices); diff --git a/src/caffe/test/test_solver_factory.cpp b/src/caffe/test/test_solver_factory.cpp new file mode 100644 index 00000000..eef5290f --- /dev/null +++ b/src/caffe/test/test_solver_factory.cpp @@ -0,0 +1,50 @@ +#include <map> +#include <string> + +#include "boost/scoped_ptr.hpp" +#include "google/protobuf/text_format.h" +#include "gtest/gtest.h" + +#include "caffe/common.hpp" +#include "caffe/solver.hpp" +#include "caffe/solver_factory.hpp" + +#include "caffe/test/test_caffe_main.hpp" + +namespace caffe { + +template <typename TypeParam> +class SolverFactoryTest : public MultiDeviceTest<TypeParam> { + protected: + SolverParameter simple_solver_param() { + const string solver_proto = + "train_net_param { " + " layer { " + " name: 'data' type: 'DummyData' top: 'data' " + " dummy_data_param { shape { dim: 1 } } " + " } " + "} "; + SolverParameter solver_param; + CHECK(google::protobuf::TextFormat::ParseFromString( + solver_proto, &solver_param)); + return solver_param; + } +}; + +TYPED_TEST_CASE(SolverFactoryTest, TestDtypesAndDevices); + +TYPED_TEST(SolverFactoryTest, TestCreateSolver) { + typedef typename TypeParam::Dtype Dtype; + typename SolverRegistry<Dtype>::CreatorRegistry& registry = + SolverRegistry<Dtype>::Registry(); + shared_ptr<Solver<Dtype> > solver; + SolverParameter solver_param = this->simple_solver_param(); + for (typename SolverRegistry<Dtype>::CreatorRegistry::iterator iter = + registry.begin(); iter != registry.end(); ++iter) { + solver_param.set_type(iter->first); + solver.reset(SolverRegistry<Dtype>::CreateSolver(solver_param)); + EXPECT_EQ(iter->first, solver->type()); + } +} + +} // namespace caffe diff --git a/tools/caffe.cpp b/tools/caffe.cpp index e3f684b5..1cb6ad89 100644 --- a/tools/caffe.cpp +++ b/tools/caffe.cpp @@ -194,7 +194,7 @@ int train() { GetRequestedAction(FLAGS_sighup_effect)); shared_ptr<caffe::Solver<float> > - solver(caffe::GetSolver<float>(solver_param)); + solver(caffe::SolverRegistry<float>::CreateSolver(solver_param)); solver->SetActionFunction(signal_handler.GetActionFunction()); |