summaryrefslogtreecommitdiff
path: root/src/caffe
diff options
context:
space:
mode:
authorRonghang Hu <huronghang@hotmail.com>2015-09-26 11:47:02 -0700
committerRonghang Hu <huronghang@hotmail.com>2015-10-16 22:32:33 -0700
commitc1f7fe1cffa4388886b735f49cd915fad905fca4 (patch)
tree53a3aa9ee7973670c6d13e932d036ecc91c141d6 /src/caffe
parent0eea815ad6fa3313888b6229499a237820258deb (diff)
downloadcaffeonacl-c1f7fe1cffa4388886b735f49cd915fad905fca4.tar.gz
caffeonacl-c1f7fe1cffa4388886b735f49cd915fad905fca4.tar.bz2
caffeonacl-c1f7fe1cffa4388886b735f49cd915fad905fca4.zip
Add automatic upgrade for solver type
Diffstat (limited to 'src/caffe')
-rw-r--r--src/caffe/solver.cpp2
-rw-r--r--src/caffe/test/test_upgrade_proto.cpp61
-rw-r--r--src/caffe/util/upgrade_proto.cpp74
3 files changed, 136 insertions, 1 deletions
diff --git a/src/caffe/solver.cpp b/src/caffe/solver.cpp
index 016a0288..d3bc7361 100644
--- a/src/caffe/solver.cpp
+++ b/src/caffe/solver.cpp
@@ -36,7 +36,7 @@ Solver<Dtype>::Solver(const string& param_file, const Solver* root_solver)
: net_(), callbacks_(), root_solver_(root_solver),
requested_early_exit_(false) {
SolverParameter param;
- ReadProtoFromTextFileOrDie(param_file, &param);
+ ReadSolverParamsFromTextFileOrDie(param_file, &param);
Init(param);
}
diff --git a/src/caffe/test/test_upgrade_proto.cpp b/src/caffe/test/test_upgrade_proto.cpp
index ee05b151..df9aeb62 100644
--- a/src/caffe/test/test_upgrade_proto.cpp
+++ b/src/caffe/test/test_upgrade_proto.cpp
@@ -2928,4 +2928,65 @@ TEST_F(NetUpgradeTest, TestUpgradeV1LayerType) {
}
}
#endif // USE_OPENCV
+
+class SolverTypeUpgradeTest : public ::testing::Test {
+ protected:
+ void RunSolverTypeUpgradeTest(
+ const string& input_param_string, const string& output_param_string) {
+ // Test upgrading old solver_type field (enum) to new type field (string)
+ SolverParameter input_param;
+ CHECK(google::protobuf::TextFormat::ParseFromString(
+ input_param_string, &input_param));
+ SolverParameter expected_output_param;
+ CHECK(google::protobuf::TextFormat::ParseFromString(
+ output_param_string, &expected_output_param));
+ SolverParameter actual_output_param = input_param;
+ UpgradeSolverType(&actual_output_param);
+ EXPECT_EQ(expected_output_param.DebugString(),
+ actual_output_param.DebugString());
+ }
+};
+
+TEST_F(SolverTypeUpgradeTest, TestSimple) {
+ const char* old_type_vec[6] = { "SGD", "ADAGRAD", "NESTEROV", "RMSPROP",
+ "ADADELTA", "ADAM" };
+ const char* new_type_vec[6] = { "SGD", "AdaGrad", "Nesterov", "RMSProp",
+ "AdaDelta", "Adam" };
+ for (int i = 0; i < 6; ++i) {
+ const string& input_proto =
+ "net: 'examples/mnist/lenet_train_test.prototxt' "
+ "test_iter: 100 "
+ "test_interval: 500 "
+ "base_lr: 0.01 "
+ "momentum: 0.0 "
+ "weight_decay: 0.0005 "
+ "lr_policy: 'inv' "
+ "gamma: 0.0001 "
+ "power: 0.75 "
+ "display: 100 "
+ "max_iter: 10000 "
+ "snapshot: 5000 "
+ "snapshot_prefix: 'examples/mnist/lenet_rmsprop' "
+ "solver_mode: GPU "
+ "solver_type: " + std::string(old_type_vec[i]) + " ";
+ const string& expected_output_proto =
+ "net: 'examples/mnist/lenet_train_test.prototxt' "
+ "test_iter: 100 "
+ "test_interval: 500 "
+ "base_lr: 0.01 "
+ "momentum: 0.0 "
+ "weight_decay: 0.0005 "
+ "lr_policy: 'inv' "
+ "gamma: 0.0001 "
+ "power: 0.75 "
+ "display: 100 "
+ "max_iter: 10000 "
+ "snapshot: 5000 "
+ "snapshot_prefix: 'examples/mnist/lenet_rmsprop' "
+ "solver_mode: GPU "
+ "type: '" + std::string(new_type_vec[i]) + "' ";
+ this->RunSolverTypeUpgradeTest(input_proto, expected_output_proto);
+ }
+}
+
} // NOLINT(readability/fn_size) // namespace caffe
diff --git a/src/caffe/util/upgrade_proto.cpp b/src/caffe/util/upgrade_proto.cpp
index 6eae9fec..ff3f8ffc 100644
--- a/src/caffe/util/upgrade_proto.cpp
+++ b/src/caffe/util/upgrade_proto.cpp
@@ -937,4 +937,78 @@ const char* UpgradeV1LayerType(const V1LayerParameter_LayerType type) {
}
}
+// Return true iff the solver contains any old solver_type specified as enums
+bool SolverNeedsTypeUpgrade(const SolverParameter& solver_param) {
+ if (solver_param.has_solver_type()) {
+ return true;
+ }
+ return false;
+}
+
+bool UpgradeSolverType(SolverParameter* solver_param) {
+ CHECK(!solver_param->has_solver_type() || !solver_param->has_type())
+ << "Failed to upgrade solver: old solver_type field (enum) and new type "
+ << "field (string) cannot be both specified in solver proto text.";
+ if (solver_param->has_solver_type()) {
+ string type;
+ switch (solver_param->solver_type()) {
+ case SolverParameter_SolverType_SGD:
+ type = "SGD";
+ break;
+ case SolverParameter_SolverType_NESTEROV:
+ type = "Nesterov";
+ break;
+ case SolverParameter_SolverType_ADAGRAD:
+ type = "AdaGrad";
+ break;
+ case SolverParameter_SolverType_RMSPROP:
+ type = "RMSProp";
+ break;
+ case SolverParameter_SolverType_ADADELTA:
+ type = "AdaDelta";
+ break;
+ case SolverParameter_SolverType_ADAM:
+ type = "Adam";
+ break;
+ default:
+ LOG(FATAL) << "Unknown SolverParameter solver_type: " << type;
+ }
+ solver_param->set_type(type);
+ solver_param->clear_solver_type();
+ } else {
+ LOG(ERROR) << "Warning: solver type already up to date. ";
+ return false;
+ }
+ return true;
+}
+
+// Check for deprecations and upgrade the SolverParameter as needed.
+bool UpgradeSolverAsNeeded(const string& param_file, SolverParameter* param) {
+ bool success = true;
+ // Try to upgrade old style solver_type enum fields into new string type
+ if (SolverNeedsTypeUpgrade(*param)) {
+ LOG(INFO) << "Attempting to upgrade input file specified using deprecated "
+ << "'solver_type' field (enum)': " << param_file;
+ if (!UpgradeSolverType(param)) {
+ success = false;
+ LOG(ERROR) << "Warning: had one or more problems upgrading "
+ << "SolverType (see above).";
+ } else {
+ LOG(INFO) << "Successfully upgraded file specified using deprecated "
+ << "'solver_type' field (enum) to 'type' field (string).";
+ LOG(WARNING) << "Note that future Caffe releases will only support "
+ << "'type' field (string) for a solver's type.";
+ }
+ }
+ return success;
+}
+
+// Read parameters from a file into a SolverParameter proto message.
+void ReadSolverParamsFromTextFileOrDie(const string& param_file,
+ SolverParameter* param) {
+ CHECK(ReadProtoFromTextFile(param_file, param))
+ << "Failed to parse SolverParameter file: " << param_file;
+ UpgradeSolverAsNeeded(param_file, param);
+}
+
} // namespace caffe