diff options
author | Ronghang Hu <huronghang@hotmail.com> | 2015-09-26 11:47:02 -0700 |
---|---|---|
committer | Ronghang Hu <huronghang@hotmail.com> | 2015-10-16 22:32:33 -0700 |
commit | c1f7fe1cffa4388886b735f49cd915fad905fca4 (patch) | |
tree | 53a3aa9ee7973670c6d13e932d036ecc91c141d6 /src/caffe | |
parent | 0eea815ad6fa3313888b6229499a237820258deb (diff) | |
download | caffeonacl-c1f7fe1cffa4388886b735f49cd915fad905fca4.tar.gz caffeonacl-c1f7fe1cffa4388886b735f49cd915fad905fca4.tar.bz2 caffeonacl-c1f7fe1cffa4388886b735f49cd915fad905fca4.zip |
Add automatic upgrade for solver type
Diffstat (limited to 'src/caffe')
-rw-r--r-- | src/caffe/solver.cpp | 2 | ||||
-rw-r--r-- | src/caffe/test/test_upgrade_proto.cpp | 61 | ||||
-rw-r--r-- | src/caffe/util/upgrade_proto.cpp | 74 |
3 files changed, 136 insertions, 1 deletions
diff --git a/src/caffe/solver.cpp b/src/caffe/solver.cpp index 016a0288..d3bc7361 100644 --- a/src/caffe/solver.cpp +++ b/src/caffe/solver.cpp @@ -36,7 +36,7 @@ Solver<Dtype>::Solver(const string& param_file, const Solver* root_solver) : net_(), callbacks_(), root_solver_(root_solver), requested_early_exit_(false) { SolverParameter param; - ReadProtoFromTextFileOrDie(param_file, ¶m); + ReadSolverParamsFromTextFileOrDie(param_file, ¶m); Init(param); } diff --git a/src/caffe/test/test_upgrade_proto.cpp b/src/caffe/test/test_upgrade_proto.cpp index ee05b151..df9aeb62 100644 --- a/src/caffe/test/test_upgrade_proto.cpp +++ b/src/caffe/test/test_upgrade_proto.cpp @@ -2928,4 +2928,65 @@ TEST_F(NetUpgradeTest, TestUpgradeV1LayerType) { } } #endif // USE_OPENCV + +class SolverTypeUpgradeTest : public ::testing::Test { + protected: + void RunSolverTypeUpgradeTest( + const string& input_param_string, const string& output_param_string) { + // Test upgrading old solver_type field (enum) to new type field (string) + SolverParameter input_param; + CHECK(google::protobuf::TextFormat::ParseFromString( + input_param_string, &input_param)); + SolverParameter expected_output_param; + CHECK(google::protobuf::TextFormat::ParseFromString( + output_param_string, &expected_output_param)); + SolverParameter actual_output_param = input_param; + UpgradeSolverType(&actual_output_param); + EXPECT_EQ(expected_output_param.DebugString(), + actual_output_param.DebugString()); + } +}; + +TEST_F(SolverTypeUpgradeTest, TestSimple) { + const char* old_type_vec[6] = { "SGD", "ADAGRAD", "NESTEROV", "RMSPROP", + "ADADELTA", "ADAM" }; + const char* new_type_vec[6] = { "SGD", "AdaGrad", "Nesterov", "RMSProp", + "AdaDelta", "Adam" }; + for (int i = 0; i < 6; ++i) { + const string& input_proto = + "net: 'examples/mnist/lenet_train_test.prototxt' " + "test_iter: 100 " + "test_interval: 500 " + "base_lr: 0.01 " + "momentum: 0.0 " + "weight_decay: 0.0005 " + "lr_policy: 'inv' " + "gamma: 0.0001 " + "power: 0.75 " + "display: 100 " + "max_iter: 10000 " + "snapshot: 5000 " + "snapshot_prefix: 'examples/mnist/lenet_rmsprop' " + "solver_mode: GPU " + "solver_type: " + std::string(old_type_vec[i]) + " "; + const string& expected_output_proto = + "net: 'examples/mnist/lenet_train_test.prototxt' " + "test_iter: 100 " + "test_interval: 500 " + "base_lr: 0.01 " + "momentum: 0.0 " + "weight_decay: 0.0005 " + "lr_policy: 'inv' " + "gamma: 0.0001 " + "power: 0.75 " + "display: 100 " + "max_iter: 10000 " + "snapshot: 5000 " + "snapshot_prefix: 'examples/mnist/lenet_rmsprop' " + "solver_mode: GPU " + "type: '" + std::string(new_type_vec[i]) + "' "; + this->RunSolverTypeUpgradeTest(input_proto, expected_output_proto); + } +} + } // NOLINT(readability/fn_size) // namespace caffe diff --git a/src/caffe/util/upgrade_proto.cpp b/src/caffe/util/upgrade_proto.cpp index 6eae9fec..ff3f8ffc 100644 --- a/src/caffe/util/upgrade_proto.cpp +++ b/src/caffe/util/upgrade_proto.cpp @@ -937,4 +937,78 @@ const char* UpgradeV1LayerType(const V1LayerParameter_LayerType type) { } } +// Return true iff the solver contains any old solver_type specified as enums +bool SolverNeedsTypeUpgrade(const SolverParameter& solver_param) { + if (solver_param.has_solver_type()) { + return true; + } + return false; +} + +bool UpgradeSolverType(SolverParameter* solver_param) { + CHECK(!solver_param->has_solver_type() || !solver_param->has_type()) + << "Failed to upgrade solver: old solver_type field (enum) and new type " + << "field (string) cannot be both specified in solver proto text."; + if (solver_param->has_solver_type()) { + string type; + switch (solver_param->solver_type()) { + case SolverParameter_SolverType_SGD: + type = "SGD"; + break; + case SolverParameter_SolverType_NESTEROV: + type = "Nesterov"; + break; + case SolverParameter_SolverType_ADAGRAD: + type = "AdaGrad"; + break; + case SolverParameter_SolverType_RMSPROP: + type = "RMSProp"; + break; + case SolverParameter_SolverType_ADADELTA: + type = "AdaDelta"; + break; + case SolverParameter_SolverType_ADAM: + type = "Adam"; + break; + default: + LOG(FATAL) << "Unknown SolverParameter solver_type: " << type; + } + solver_param->set_type(type); + solver_param->clear_solver_type(); + } else { + LOG(ERROR) << "Warning: solver type already up to date. "; + return false; + } + return true; +} + +// Check for deprecations and upgrade the SolverParameter as needed. +bool UpgradeSolverAsNeeded(const string& param_file, SolverParameter* param) { + bool success = true; + // Try to upgrade old style solver_type enum fields into new string type + if (SolverNeedsTypeUpgrade(*param)) { + LOG(INFO) << "Attempting to upgrade input file specified using deprecated " + << "'solver_type' field (enum)': " << param_file; + if (!UpgradeSolverType(param)) { + success = false; + LOG(ERROR) << "Warning: had one or more problems upgrading " + << "SolverType (see above)."; + } else { + LOG(INFO) << "Successfully upgraded file specified using deprecated " + << "'solver_type' field (enum) to 'type' field (string)."; + LOG(WARNING) << "Note that future Caffe releases will only support " + << "'type' field (string) for a solver's type."; + } + } + return success; +} + +// Read parameters from a file into a SolverParameter proto message. +void ReadSolverParamsFromTextFileOrDie(const string& param_file, + SolverParameter* param) { + CHECK(ReadProtoFromTextFile(param_file, param)) + << "Failed to parse SolverParameter file: " << param_file; + UpgradeSolverAsNeeded(param_file, param); +} + } // namespace caffe |