From c98ed3b8b2da44182551ed8c996c7102b9fec113 Mon Sep 17 00:00:00 2001 From: Mohamed Omran Date: Fri, 29 Aug 2014 17:16:41 +0200 Subject: minor changes to variable names and error messages + set default backend in convert_mnist_data.cpp to lmdb --- examples/mnist/convert_mnist_data.cpp | 45 ++++++++++++++++++++--------------- 1 file changed, 26 insertions(+), 19 deletions(-) (limited to 'examples/mnist') diff --git a/examples/mnist/convert_mnist_data.cpp b/examples/mnist/convert_mnist_data.cpp index 6fbd2943..0d3e91ec 100644 --- a/examples/mnist/convert_mnist_data.cpp +++ b/examples/mnist/convert_mnist_data.cpp @@ -2,7 +2,8 @@ // This script converts the MNIST dataset to the leveldb format used // by caffe to perform classification. // Usage: -// convert_mnist_data input_image_file input_label_file output_db_file +// convert_mnist_data [FLAGS] input_image_file input_label_file +// output_db_file // The MNIST dataset could be downloaded at // http://yann.lecun.com/exdb/mnist/ @@ -20,7 +21,10 @@ #include "caffe/proto/caffe.pb.h" -DEFINE_string(backend, "leveldb", "The backend for storing the result"); +using namespace caffe; // NOLINT(build/namespaces) +using std::string; + +DEFINE_string(backend, "lmdb", "The backend for storing the result"); uint32_t swap_endian(uint32_t val) { val = ((val << 8) & 0xFF00FF00) | ((val >> 8) & 0xFF00FF); @@ -28,7 +32,7 @@ uint32_t swap_endian(uint32_t val) { } void convert_dataset(const char* image_filename, const char* label_filename, - const char* db_filename, const std::string& db_backend) { + const char* db_path, const string& db_backend) { // Open files std::ifstream image_file(image_filename, std::ios::in | std::ios::binary); std::ifstream label_file(label_filename, std::ios::in | std::ios::binary); @@ -65,55 +69,58 @@ void convert_dataset(const char* image_filename, const char* label_filename, // leveldb leveldb::DB* db; leveldb::Options options; - options.create_if_missing = true; options.error_if_exists = true; + 
options.create_if_missing = true; options.write_buffer_size = 268435456; leveldb::WriteBatch* batch = NULL; - // Open new db + // Open db if (db_backend == "leveldb") { // leveldb + LOG(INFO) << "Opening leveldb " << db_path; leveldb::Status status = leveldb::DB::Open( - options, db_filename, &db); - CHECK(status.ok()) << "Failed to open leveldb " << db_filename + options, db_path, &db); + CHECK(status.ok()) << "Failed to open leveldb " << db_path << ". Is it already existing?"; batch = new leveldb::WriteBatch(); } else if (db_backend == "lmdb") { // lmdb - CHECK_EQ(mkdir(db_filename, 0744), 0) - << "mkdir " << db_filename << "failed"; + LOG(INFO) << "Opening lmdb " << db_path; + CHECK_EQ(mkdir(db_path, 0744), 0) + << "mkdir " << db_path << "failed"; CHECK_EQ(mdb_env_create(&mdb_env), MDB_SUCCESS) << "mdb_env_create failed"; CHECK_EQ(mdb_env_set_mapsize(mdb_env, 1099511627776), MDB_SUCCESS) // 1TB << "mdb_env_set_mapsize failed"; - CHECK_EQ(mdb_env_open(mdb_env, db_filename, 0, 0664), MDB_SUCCESS) + CHECK_EQ(mdb_env_open(mdb_env, db_path, 0, 0664), MDB_SUCCESS) << "mdb_env_open failed"; CHECK_EQ(mdb_txn_begin(mdb_env, NULL, 0, &mdb_txn), MDB_SUCCESS) << "mdb_txn_begin failed"; CHECK_EQ(mdb_open(mdb_txn, NULL, 0, &mdb_dbi), MDB_SUCCESS) - << "mdb_open failed"; + << "mdb_open failed. Does the lmdb already exist? 
"; } else { LOG(FATAL) << "Unknown db backend " << db_backend; } + // Storing to db char label; char* pixels = new char[rows * cols]; - const int kMaxKeyLength = 10; - char key[kMaxKeyLength]; - std::string value; int count = 0; + const int kMaxKeyLength = 10; + char key_cstr[kMaxKeyLength]; + string value; - caffe::Datum datum; + Datum datum; datum.set_channels(1); datum.set_height(rows); datum.set_width(cols); LOG(INFO) << "A total of " << num_items << " items."; LOG(INFO) << "Rows: " << rows << " Cols: " << cols; - for (int itemid = 0; itemid < num_items; ++itemid) { + for (int item_id = 0; item_id < num_items; ++item_id) { image_file.read(pixels, rows * cols); label_file.read(&label, 1); datum.set_data(pixels, rows*cols); datum.set_label(label); + snprintf(key_cstr, kMaxKeyLength, "%08d", item_id); datum.SerializeToString(&value); - snprintf(key, kMaxKeyLength, "%08d", itemid); - std::string keystr(key); + string keystr(key_cstr); // Put in db if (db_backend == "leveldb") { // leveldb @@ -179,7 +186,7 @@ int main(int argc, char** argv) { "or directly use data/mnist/get_mnist.sh\n"); gflags::ParseCommandLineFlags(&argc, &argv, true); - const std::string& db_backend = FLAGS_backend; + const string& db_backend = FLAGS_backend; if (argc != 4) { gflags::ShowUsageWithFlagsRestrict(argv[0], -- cgit v1.2.3