diff options
author | Ronghang Hu <huronghang@hotmail.com> | 2015-09-26 11:47:02 -0700 |
---|---|---|
committer | Ronghang Hu <huronghang@hotmail.com> | 2015-10-16 22:32:33 -0700 |
commit | c1f7fe1cffa4388886b735f49cd915fad905fca4 (patch) | |
tree | 53a3aa9ee7973670c6d13e932d036ecc91c141d6 | |
parent | 0eea815ad6fa3313888b6229499a237820258deb (diff) | |
download | caffeonacl-c1f7fe1cffa4388886b735f49cd915fad905fca4.tar.gz caffeonacl-c1f7fe1cffa4388886b735f49cd915fad905fca4.tar.bz2 caffeonacl-c1f7fe1cffa4388886b735f49cd915fad905fca4.zip |
Add automatic upgrade for solver type
-rw-r--r-- | include/caffe/caffe.hpp | 1 | ||||
-rw-r--r-- | include/caffe/util/upgrade_proto.hpp | 12 | ||||
-rw-r--r-- | matlab/+caffe/private/caffe_.cpp | 5 | ||||
-rw-r--r-- | python/caffe/_caffe.cpp | 4 | ||||
-rw-r--r-- | src/caffe/solver.cpp | 2 | ||||
-rw-r--r-- | src/caffe/test/test_upgrade_proto.cpp | 61 | ||||
-rw-r--r-- | src/caffe/util/upgrade_proto.cpp | 74 | ||||
-rw-r--r-- | tools/caffe.cpp | 2 | ||||
-rw-r--r-- | tools/upgrade_solver_proto_text.cpp | 50 |
9 files changed, 206 insertions, 5 deletions
diff --git a/include/caffe/caffe.hpp b/include/caffe/caffe.hpp index bd772830..a339efba 100644 --- a/include/caffe/caffe.hpp +++ b/include/caffe/caffe.hpp @@ -16,6 +16,7 @@ #include "caffe/solver_factory.hpp" #include "caffe/util/benchmark.hpp" #include "caffe/util/io.hpp" +#include "caffe/util/upgrade_proto.hpp" #include "caffe/vision_layers.hpp" #endif // CAFFE_CAFFE_HPP_ diff --git a/include/caffe/util/upgrade_proto.hpp b/include/caffe/util/upgrade_proto.hpp index 6a141843..c94bb3ca 100644 --- a/include/caffe/util/upgrade_proto.hpp +++ b/include/caffe/util/upgrade_proto.hpp @@ -59,6 +59,18 @@ bool UpgradeV1LayerParameter(const V1LayerParameter& v1_layer_param, const char* UpgradeV1LayerType(const V1LayerParameter_LayerType type); +// Return true iff the solver contains any old solver_type specified as enums +bool SolverNeedsTypeUpgrade(const SolverParameter& solver_param); + +bool UpgradeSolverType(SolverParameter* solver_param); + +// Check for deprecations and upgrade the SolverParameter as needed. +bool UpgradeSolverAsNeeded(const string& param_file, SolverParameter* param); + +// Read parameters from a file into a SolverParameter proto message. 
+void ReadSolverParamsFromTextFileOrDie(const string& param_file, + SolverParameter* param); + } // namespace caffe #endif // CAFFE_UTIL_UPGRADE_PROTO_H_ diff --git a/matlab/+caffe/private/caffe_.cpp b/matlab/+caffe/private/caffe_.cpp index 7883f79e..1641e14b 100644 --- a/matlab/+caffe/private/caffe_.cpp +++ b/matlab/+caffe/private/caffe_.cpp @@ -188,7 +188,10 @@ static void get_solver(MEX_ARGS) { "Usage: caffe_('get_solver', solver_file)"); char* solver_file = mxArrayToString(prhs[0]); mxCHECK_FILE_EXIST(solver_file); - shared_ptr<Solver<float> > solver(new caffe::SGDSolver<float>(solver_file)); + SolverParameter solver_param; + ReadSolverParamsFromTextFileOrDie(solver_file, &solver_param); + shared_ptr<Solver<float> > solver( + SolverRegistry<float>::CreateSolver(solver_param)); solvers_.push_back(solver); plhs[0] = ptr_to_handle<Solver<float> >(solver.get()); mxFree(solver_file); diff --git a/python/caffe/_caffe.cpp b/python/caffe/_caffe.cpp index 0e38dee7..8687dd87 100644 --- a/python/caffe/_caffe.cpp +++ b/python/caffe/_caffe.cpp @@ -134,8 +134,8 @@ void Net_SetInputArrays(Net<Dtype>* net, bp::object data_obj, Solver<Dtype>* GetSolverFromFile(const string& filename) { SolverParameter param; - ReadProtoFromTextFileOrDie(filename, &param); - return GetSolver<Dtype>(param); + ReadSolverParamsFromTextFileOrDie(filename, &param); + return SolverRegistry<Dtype>::CreateSolver(param); } struct NdarrayConverterGenerator { diff --git a/src/caffe/solver.cpp b/src/caffe/solver.cpp index 016a0288..d3bc7361 100644 --- a/src/caffe/solver.cpp +++ b/src/caffe/solver.cpp @@ -36,7 +36,7 @@ Solver<Dtype>::Solver(const string& param_file, const Solver* root_solver) : net_(), callbacks_(), root_solver_(root_solver), requested_early_exit_(false) { SolverParameter param; - ReadProtoFromTextFileOrDie(param_file, &param); + ReadSolverParamsFromTextFileOrDie(param_file, &param); Init(param); } diff --git a/src/caffe/test/test_upgrade_proto.cpp b/src/caffe/test/test_upgrade_proto.cpp index 
ee05b151..df9aeb62 100644 --- a/src/caffe/test/test_upgrade_proto.cpp +++ b/src/caffe/test/test_upgrade_proto.cpp @@ -2928,4 +2928,65 @@ TEST_F(NetUpgradeTest, TestUpgradeV1LayerType) { } } #endif // USE_OPENCV + +class SolverTypeUpgradeTest : public ::testing::Test { + protected: + void RunSolverTypeUpgradeTest( + const string& input_param_string, const string& output_param_string) { + // Test upgrading old solver_type field (enum) to new type field (string) + SolverParameter input_param; + CHECK(google::protobuf::TextFormat::ParseFromString( + input_param_string, &input_param)); + SolverParameter expected_output_param; + CHECK(google::protobuf::TextFormat::ParseFromString( + output_param_string, &expected_output_param)); + SolverParameter actual_output_param = input_param; + UpgradeSolverType(&actual_output_param); + EXPECT_EQ(expected_output_param.DebugString(), + actual_output_param.DebugString()); + } +}; + +TEST_F(SolverTypeUpgradeTest, TestSimple) { + const char* old_type_vec[6] = { "SGD", "ADAGRAD", "NESTEROV", "RMSPROP", + "ADADELTA", "ADAM" }; + const char* new_type_vec[6] = { "SGD", "AdaGrad", "Nesterov", "RMSProp", + "AdaDelta", "Adam" }; + for (int i = 0; i < 6; ++i) { + const string& input_proto = + "net: 'examples/mnist/lenet_train_test.prototxt' " + "test_iter: 100 " + "test_interval: 500 " + "base_lr: 0.01 " + "momentum: 0.0 " + "weight_decay: 0.0005 " + "lr_policy: 'inv' " + "gamma: 0.0001 " + "power: 0.75 " + "display: 100 " + "max_iter: 10000 " + "snapshot: 5000 " + "snapshot_prefix: 'examples/mnist/lenet_rmsprop' " + "solver_mode: GPU " + "solver_type: " + std::string(old_type_vec[i]) + " "; + const string& expected_output_proto = + "net: 'examples/mnist/lenet_train_test.prototxt' " + "test_iter: 100 " + "test_interval: 500 " + "base_lr: 0.01 " + "momentum: 0.0 " + "weight_decay: 0.0005 " + "lr_policy: 'inv' " + "gamma: 0.0001 " + "power: 0.75 " + "display: 100 " + "max_iter: 10000 " + "snapshot: 5000 " + "snapshot_prefix: 
'examples/mnist/lenet_rmsprop' " + "solver_mode: GPU " + "type: '" + std::string(new_type_vec[i]) + "' "; + this->RunSolverTypeUpgradeTest(input_proto, expected_output_proto); + } +} + } // NOLINT(readability/fn_size) // namespace caffe diff --git a/src/caffe/util/upgrade_proto.cpp b/src/caffe/util/upgrade_proto.cpp index 6eae9fec..ff3f8ffc 100644 --- a/src/caffe/util/upgrade_proto.cpp +++ b/src/caffe/util/upgrade_proto.cpp @@ -937,4 +937,78 @@ const char* UpgradeV1LayerType(const V1LayerParameter_LayerType type) { } } +// Return true iff the solver contains any old solver_type specified as enums +bool SolverNeedsTypeUpgrade(const SolverParameter& solver_param) { + if (solver_param.has_solver_type()) { + return true; + } + return false; +} + +bool UpgradeSolverType(SolverParameter* solver_param) { + CHECK(!solver_param->has_solver_type() || !solver_param->has_type()) + << "Failed to upgrade solver: old solver_type field (enum) and new type " + << "field (string) cannot be both specified in solver proto text."; + if (solver_param->has_solver_type()) { + string type; + switch (solver_param->solver_type()) { + case SolverParameter_SolverType_SGD: + type = "SGD"; + break; + case SolverParameter_SolverType_NESTEROV: + type = "Nesterov"; + break; + case SolverParameter_SolverType_ADAGRAD: + type = "AdaGrad"; + break; + case SolverParameter_SolverType_RMSPROP: + type = "RMSProp"; + break; + case SolverParameter_SolverType_ADADELTA: + type = "AdaDelta"; + break; + case SolverParameter_SolverType_ADAM: + type = "Adam"; + break; + default: + LOG(FATAL) << "Unknown SolverParameter solver_type: " << type; + } + solver_param->set_type(type); + solver_param->clear_solver_type(); + } else { + LOG(ERROR) << "Warning: solver type already up to date. "; + return false; + } + return true; +} + +// Check for deprecations and upgrade the SolverParameter as needed. 
+bool UpgradeSolverAsNeeded(const string& param_file, SolverParameter* param) { + bool success = true; + // Try to upgrade old style solver_type enum fields into new string type + if (SolverNeedsTypeUpgrade(*param)) { + LOG(INFO) << "Attempting to upgrade input file specified using deprecated " + << "'solver_type' field (enum)': " << param_file; + if (!UpgradeSolverType(param)) { + success = false; + LOG(ERROR) << "Warning: had one or more problems upgrading " + << "SolverType (see above)."; + } else { + LOG(INFO) << "Successfully upgraded file specified using deprecated " + << "'solver_type' field (enum) to 'type' field (string)."; + LOG(WARNING) << "Note that future Caffe releases will only support " + << "'type' field (string) for a solver's type."; + } + } + return success; +} + +// Read parameters from a file into a SolverParameter proto message. +void ReadSolverParamsFromTextFileOrDie(const string& param_file, + SolverParameter* param) { + CHECK(ReadProtoFromTextFile(param_file, param)) + << "Failed to parse SolverParameter file: " << param_file; + UpgradeSolverAsNeeded(param_file, param); +} + } // namespace caffe diff --git a/tools/caffe.cpp b/tools/caffe.cpp index 1cb6ad89..305cfc36 100644 --- a/tools/caffe.cpp +++ b/tools/caffe.cpp @@ -157,7 +157,7 @@ int train() { "but not both."; caffe::SolverParameter solver_param; - caffe::ReadProtoFromTextFileOrDie(FLAGS_solver, &solver_param); + caffe::ReadSolverParamsFromTextFileOrDie(FLAGS_solver, &solver_param); // If the gpus flag is not provided, allow the mode and device to be set // in the solver prototxt. diff --git a/tools/upgrade_solver_proto_text.cpp b/tools/upgrade_solver_proto_text.cpp new file mode 100644 index 00000000..7130232a --- /dev/null +++ b/tools/upgrade_solver_proto_text.cpp @@ -0,0 +1,50 @@ +// This is a script to upgrade old solver prototxts to the new format. 
+// Usage: +// upgrade_solver_proto_text old_solver_proto_file_in solver_proto_file_out + +#include <cstring> +#include <fstream> // NOLINT(readability/streams) +#include <iostream> // NOLINT(readability/streams) +#include <string> + +#include "caffe/caffe.hpp" +#include "caffe/util/io.hpp" +#include "caffe/util/upgrade_proto.hpp" + +using std::ofstream; + +using namespace caffe; // NOLINT(build/namespaces) + +int main(int argc, char** argv) { + ::google::InitGoogleLogging(argv[0]); + if (argc != 3) { + LOG(ERROR) << "Usage: upgrade_solver_proto_text " + << "old_solver_proto_file_in solver_proto_file_out"; + return 1; + } + + SolverParameter solver_param; + string input_filename(argv[1]); + if (!ReadProtoFromTextFile(input_filename, &solver_param)) { + LOG(ERROR) << "Failed to parse input text file as SolverParameter: " + << input_filename; + return 2; + } + bool need_upgrade = SolverNeedsTypeUpgrade(solver_param); + bool success = true; + if (need_upgrade) { + success = UpgradeSolverAsNeeded(input_filename, &solver_param); + if (!success) { + LOG(ERROR) << "Encountered error(s) while upgrading prototxt; " + << "see details above."; + } + } else { + LOG(ERROR) << "File already in latest proto format: " << input_filename; + } + + // Save new format prototxt. + WriteProtoToTextFile(solver_param, argv[2]); + + LOG(ERROR) << "Wrote upgraded SolverParameter text proto to " << argv[2]; + return !success; +} |