summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRonghang Hu <huronghang@hotmail.com>2015-09-26 11:47:02 -0700
committerRonghang Hu <huronghang@hotmail.com>2015-10-16 22:32:33 -0700
commitc1f7fe1cffa4388886b735f49cd915fad905fca4 (patch)
tree53a3aa9ee7973670c6d13e932d036ecc91c141d6
parent0eea815ad6fa3313888b6229499a237820258deb (diff)
downloadcaffeonacl-c1f7fe1cffa4388886b735f49cd915fad905fca4.tar.gz
caffeonacl-c1f7fe1cffa4388886b735f49cd915fad905fca4.tar.bz2
caffeonacl-c1f7fe1cffa4388886b735f49cd915fad905fca4.zip
Add automatic upgrade for solver type
-rw-r--r--include/caffe/caffe.hpp1
-rw-r--r--include/caffe/util/upgrade_proto.hpp12
-rw-r--r--matlab/+caffe/private/caffe_.cpp5
-rw-r--r--python/caffe/_caffe.cpp4
-rw-r--r--src/caffe/solver.cpp2
-rw-r--r--src/caffe/test/test_upgrade_proto.cpp61
-rw-r--r--src/caffe/util/upgrade_proto.cpp74
-rw-r--r--tools/caffe.cpp2
-rw-r--r--tools/upgrade_solver_proto_text.cpp50
9 files changed, 206 insertions, 5 deletions
diff --git a/include/caffe/caffe.hpp b/include/caffe/caffe.hpp
index bd772830..a339efba 100644
--- a/include/caffe/caffe.hpp
+++ b/include/caffe/caffe.hpp
@@ -16,6 +16,7 @@
#include "caffe/solver_factory.hpp"
#include "caffe/util/benchmark.hpp"
#include "caffe/util/io.hpp"
+#include "caffe/util/upgrade_proto.hpp"
#include "caffe/vision_layers.hpp"
#endif // CAFFE_CAFFE_HPP_
diff --git a/include/caffe/util/upgrade_proto.hpp b/include/caffe/util/upgrade_proto.hpp
index 6a141843..c94bb3ca 100644
--- a/include/caffe/util/upgrade_proto.hpp
+++ b/include/caffe/util/upgrade_proto.hpp
@@ -59,6 +59,18 @@ bool UpgradeV1LayerParameter(const V1LayerParameter& v1_layer_param,
const char* UpgradeV1LayerType(const V1LayerParameter_LayerType type);
+// Return true iff the solver contains any old solver_type specified as enums
+bool SolverNeedsTypeUpgrade(const SolverParameter& solver_param);
+
+bool UpgradeSolverType(SolverParameter* solver_param);
+
+// Check for deprecations and upgrade the SolverParameter as needed.
+bool UpgradeSolverAsNeeded(const string& param_file, SolverParameter* param);
+
+// Read parameters from a file into a SolverParameter proto message.
+void ReadSolverParamsFromTextFileOrDie(const string& param_file,
+ SolverParameter* param);
+
} // namespace caffe
#endif // CAFFE_UTIL_UPGRADE_PROTO_H_
diff --git a/matlab/+caffe/private/caffe_.cpp b/matlab/+caffe/private/caffe_.cpp
index 7883f79e..1641e14b 100644
--- a/matlab/+caffe/private/caffe_.cpp
+++ b/matlab/+caffe/private/caffe_.cpp
@@ -188,7 +188,10 @@ static void get_solver(MEX_ARGS) {
"Usage: caffe_('get_solver', solver_file)");
char* solver_file = mxArrayToString(prhs[0]);
mxCHECK_FILE_EXIST(solver_file);
- shared_ptr<Solver<float> > solver(new caffe::SGDSolver<float>(solver_file));
+ SolverParameter solver_param;
+ ReadSolverParamsFromTextFileOrDie(solver_file, &solver_param);
+ shared_ptr<Solver<float> > solver(
+ SolverRegistry<float>::CreateSolver(solver_param));
solvers_.push_back(solver);
plhs[0] = ptr_to_handle<Solver<float> >(solver.get());
mxFree(solver_file);
diff --git a/python/caffe/_caffe.cpp b/python/caffe/_caffe.cpp
index 0e38dee7..8687dd87 100644
--- a/python/caffe/_caffe.cpp
+++ b/python/caffe/_caffe.cpp
@@ -134,8 +134,8 @@ void Net_SetInputArrays(Net<Dtype>* net, bp::object data_obj,
Solver<Dtype>* GetSolverFromFile(const string& filename) {
SolverParameter param;
- ReadProtoFromTextFileOrDie(filename, &param);
- return GetSolver<Dtype>(param);
+ ReadSolverParamsFromTextFileOrDie(filename, &param);
+ return SolverRegistry<Dtype>::CreateSolver(param);
}
struct NdarrayConverterGenerator {
diff --git a/src/caffe/solver.cpp b/src/caffe/solver.cpp
index 016a0288..d3bc7361 100644
--- a/src/caffe/solver.cpp
+++ b/src/caffe/solver.cpp
@@ -36,7 +36,7 @@ Solver<Dtype>::Solver(const string& param_file, const Solver* root_solver)
: net_(), callbacks_(), root_solver_(root_solver),
requested_early_exit_(false) {
SolverParameter param;
- ReadProtoFromTextFileOrDie(param_file, &param);
+ ReadSolverParamsFromTextFileOrDie(param_file, &param);
Init(param);
}
diff --git a/src/caffe/test/test_upgrade_proto.cpp b/src/caffe/test/test_upgrade_proto.cpp
index ee05b151..df9aeb62 100644
--- a/src/caffe/test/test_upgrade_proto.cpp
+++ b/src/caffe/test/test_upgrade_proto.cpp
@@ -2928,4 +2928,65 @@ TEST_F(NetUpgradeTest, TestUpgradeV1LayerType) {
}
}
#endif // USE_OPENCV
+
+class SolverTypeUpgradeTest : public ::testing::Test {
+ protected:
+ void RunSolverTypeUpgradeTest(
+ const string& input_param_string, const string& output_param_string) {
+ // Test upgrading old solver_type field (enum) to new type field (string)
+ SolverParameter input_param;
+ CHECK(google::protobuf::TextFormat::ParseFromString(
+ input_param_string, &input_param));
+ SolverParameter expected_output_param;
+ CHECK(google::protobuf::TextFormat::ParseFromString(
+ output_param_string, &expected_output_param));
+ SolverParameter actual_output_param = input_param;
+ UpgradeSolverType(&actual_output_param);
+ EXPECT_EQ(expected_output_param.DebugString(),
+ actual_output_param.DebugString());
+ }
+};
+
+TEST_F(SolverTypeUpgradeTest, TestSimple) {
+ const char* old_type_vec[6] = { "SGD", "ADAGRAD", "NESTEROV", "RMSPROP",
+ "ADADELTA", "ADAM" };
+ const char* new_type_vec[6] = { "SGD", "AdaGrad", "Nesterov", "RMSProp",
+ "AdaDelta", "Adam" };
+ for (int i = 0; i < 6; ++i) {
+ const string& input_proto =
+ "net: 'examples/mnist/lenet_train_test.prototxt' "
+ "test_iter: 100 "
+ "test_interval: 500 "
+ "base_lr: 0.01 "
+ "momentum: 0.0 "
+ "weight_decay: 0.0005 "
+ "lr_policy: 'inv' "
+ "gamma: 0.0001 "
+ "power: 0.75 "
+ "display: 100 "
+ "max_iter: 10000 "
+ "snapshot: 5000 "
+ "snapshot_prefix: 'examples/mnist/lenet_rmsprop' "
+ "solver_mode: GPU "
+ "solver_type: " + std::string(old_type_vec[i]) + " ";
+ const string& expected_output_proto =
+ "net: 'examples/mnist/lenet_train_test.prototxt' "
+ "test_iter: 100 "
+ "test_interval: 500 "
+ "base_lr: 0.01 "
+ "momentum: 0.0 "
+ "weight_decay: 0.0005 "
+ "lr_policy: 'inv' "
+ "gamma: 0.0001 "
+ "power: 0.75 "
+ "display: 100 "
+ "max_iter: 10000 "
+ "snapshot: 5000 "
+ "snapshot_prefix: 'examples/mnist/lenet_rmsprop' "
+ "solver_mode: GPU "
+ "type: '" + std::string(new_type_vec[i]) + "' ";
+ this->RunSolverTypeUpgradeTest(input_proto, expected_output_proto);
+ }
+}
+
} // NOLINT(readability/fn_size) // namespace caffe
diff --git a/src/caffe/util/upgrade_proto.cpp b/src/caffe/util/upgrade_proto.cpp
index 6eae9fec..ff3f8ffc 100644
--- a/src/caffe/util/upgrade_proto.cpp
+++ b/src/caffe/util/upgrade_proto.cpp
@@ -937,4 +937,78 @@ const char* UpgradeV1LayerType(const V1LayerParameter_LayerType type) {
}
}
+// Return true iff the solver contains any old solver_type specified as enums
+bool SolverNeedsTypeUpgrade(const SolverParameter& solver_param) {
+ if (solver_param.has_solver_type()) {
+ return true;
+ }
+ return false;
+}
+
+bool UpgradeSolverType(SolverParameter* solver_param) {
+ CHECK(!solver_param->has_solver_type() || !solver_param->has_type())
+ << "Failed to upgrade solver: old solver_type field (enum) and new type "
+ << "field (string) cannot be both specified in solver proto text.";
+ if (solver_param->has_solver_type()) {
+ string type;
+ switch (solver_param->solver_type()) {
+ case SolverParameter_SolverType_SGD:
+ type = "SGD";
+ break;
+ case SolverParameter_SolverType_NESTEROV:
+ type = "Nesterov";
+ break;
+ case SolverParameter_SolverType_ADAGRAD:
+ type = "AdaGrad";
+ break;
+ case SolverParameter_SolverType_RMSPROP:
+ type = "RMSProp";
+ break;
+ case SolverParameter_SolverType_ADADELTA:
+ type = "AdaDelta";
+ break;
+ case SolverParameter_SolverType_ADAM:
+ type = "Adam";
+ break;
+ default:
+ LOG(FATAL) << "Unknown SolverParameter solver_type: " << type;
+ }
+ solver_param->set_type(type);
+ solver_param->clear_solver_type();
+ } else {
+ LOG(ERROR) << "Warning: solver type already up to date. ";
+ return false;
+ }
+ return true;
+}
+
+// Check for deprecations and upgrade the SolverParameter as needed.
+bool UpgradeSolverAsNeeded(const string& param_file, SolverParameter* param) {
+ bool success = true;
+ // Try to upgrade old style solver_type enum fields into new string type
+ if (SolverNeedsTypeUpgrade(*param)) {
+ LOG(INFO) << "Attempting to upgrade input file specified using deprecated "
+ << "'solver_type' field (enum)': " << param_file;
+ if (!UpgradeSolverType(param)) {
+ success = false;
+ LOG(ERROR) << "Warning: had one or more problems upgrading "
+ << "SolverType (see above).";
+ } else {
+ LOG(INFO) << "Successfully upgraded file specified using deprecated "
+ << "'solver_type' field (enum) to 'type' field (string).";
+ LOG(WARNING) << "Note that future Caffe releases will only support "
+ << "'type' field (string) for a solver's type.";
+ }
+ }
+ return success;
+}
+
+// Read parameters from a file into a SolverParameter proto message.
+void ReadSolverParamsFromTextFileOrDie(const string& param_file,
+ SolverParameter* param) {
+ CHECK(ReadProtoFromTextFile(param_file, param))
+ << "Failed to parse SolverParameter file: " << param_file;
+ UpgradeSolverAsNeeded(param_file, param);
+}
+
} // namespace caffe
diff --git a/tools/caffe.cpp b/tools/caffe.cpp
index 1cb6ad89..305cfc36 100644
--- a/tools/caffe.cpp
+++ b/tools/caffe.cpp
@@ -157,7 +157,7 @@ int train() {
"but not both.";
caffe::SolverParameter solver_param;
- caffe::ReadProtoFromTextFileOrDie(FLAGS_solver, &solver_param);
+ caffe::ReadSolverParamsFromTextFileOrDie(FLAGS_solver, &solver_param);
// If the gpus flag is not provided, allow the mode and device to be set
// in the solver prototxt.
diff --git a/tools/upgrade_solver_proto_text.cpp b/tools/upgrade_solver_proto_text.cpp
new file mode 100644
index 00000000..7130232a
--- /dev/null
+++ b/tools/upgrade_solver_proto_text.cpp
@@ -0,0 +1,50 @@
+// This is a script to upgrade old solver prototxts to the new format.
+// Usage:
+// upgrade_solver_proto_text old_solver_proto_file_in solver_proto_file_out
+
+#include <cstring>
+#include <fstream> // NOLINT(readability/streams)
+#include <iostream> // NOLINT(readability/streams)
+#include <string>
+
+#include "caffe/caffe.hpp"
+#include "caffe/util/io.hpp"
+#include "caffe/util/upgrade_proto.hpp"
+
+using std::ofstream;
+
+using namespace caffe; // NOLINT(build/namespaces)
+
+int main(int argc, char** argv) {
+ ::google::InitGoogleLogging(argv[0]);
+ if (argc != 3) {
+ LOG(ERROR) << "Usage: upgrade_solver_proto_text "
+ << "old_solver_proto_file_in solver_proto_file_out";
+ return 1;
+ }
+
+ SolverParameter solver_param;
+ string input_filename(argv[1]);
+ if (!ReadProtoFromTextFile(input_filename, &solver_param)) {
+ LOG(ERROR) << "Failed to parse input text file as SolverParameter: "
+ << input_filename;
+ return 2;
+ }
+ bool need_upgrade = SolverNeedsTypeUpgrade(solver_param);
+ bool success = true;
+ if (need_upgrade) {
+ success = UpgradeSolverAsNeeded(input_filename, &solver_param);
+ if (!success) {
+ LOG(ERROR) << "Encountered error(s) while upgrading prototxt; "
+ << "see details above.";
+ }
+ } else {
+ LOG(ERROR) << "File already in latest proto format: " << input_filename;
+ }
+
+ // Save new format prototxt.
+ WriteProtoToTextFile(solver_param, argv[2]);
+
+ LOG(ERROR) << "Wrote upgraded SolverParameter text proto to " << argv[2];
+ return !success;
+}