diff options
author | Kevin James Matzen <kmatzen@cs.cornell.edu> | 2014-10-07 21:46:15 -0400 |
---|---|---|
committer | Kevin James Matzen <kmatzen@cs.cornell.edu> | 2014-10-14 19:29:11 -0400 |
commit | edff676ea8482d48f4b7bc794288696afce782ca (patch) | |
tree | f8b7bae22c1331d10cf36deee2a6b168671b813a /examples | |
parent | 7e504c0612b23449aa83fd7aac8d49c56b03fd62 (diff) | |
download | caffeonacl-edff676ea8482d48f4b7bc794288696afce782ca.tar.gz caffeonacl-edff676ea8482d48f4b7bc794288696afce782ca.tar.bz2 caffeonacl-edff676ea8482d48f4b7bc794288696afce782ca.zip |
Don't autocommit on close for the databases: an implicit commit against a read-only database could fail.
Diffstat (limited to 'examples')
-rw-r--r-- | examples/cifar10/convert_cifar_data.cpp | 59 |
1 file changed, 32 insertions, 27 deletions
diff --git a/examples/cifar10/convert_cifar_data.cpp b/examples/cifar10/convert_cifar_data.cpp index 90ecb6d9..c4930878 100644 --- a/examples/cifar10/convert_cifar_data.cpp +++ b/examples/cifar10/convert_cifar_data.cpp @@ -11,13 +11,17 @@ #include "glog/logging.h" #include "google/protobuf/text_format.h" -#include "leveldb/db.h" #include "stdint.h" +#include "caffe/database_factory.hpp" #include "caffe/proto/caffe.pb.h" using std::string; +using caffe::Database; +using caffe::DatabaseFactory; +using caffe::shared_ptr; + const int kCIFARSize = 32; const int kCIFARImageNBytes = 3072; const int kCIFARBatchSize = 10000; @@ -31,26 +35,20 @@ void read_image(std::ifstream* file, int* label, char* buffer) { return; } -void convert_dataset(const string& input_folder, const string& output_folder) { - // Leveldb options - leveldb::Options options; - options.create_if_missing = true; - options.error_if_exists = true; +void convert_dataset(const string& input_folder, const string& output_folder, + const string& db_type) { + shared_ptr<Database> train_database = DatabaseFactory(db_type); + train_database->open(output_folder + "/cifar10_train_" + db_type, + Database::New); // Data buffer int label; char str_buffer[kCIFARImageNBytes]; - string value; caffe::Datum datum; datum.set_channels(3); datum.set_height(kCIFARSize); datum.set_width(kCIFARSize); LOG(INFO) << "Writing Training data"; - leveldb::DB* train_db; - leveldb::Status status; - status = leveldb::DB::Open(options, output_folder + "/cifar10_train_leveldb", - &train_db); - CHECK(status.ok()) << "Failed to open leveldb."; for (int fileid = 0; fileid < kCIFARTrainBatches; ++fileid) { // Open files LOG(INFO) << "Training Batch " << fileid + 1; @@ -62,17 +60,22 @@ void convert_dataset(const string& input_folder, const string& output_folder) { read_image(&data_file, &label, str_buffer); datum.set_label(label); datum.set_data(str_buffer, kCIFARImageNBytes); - datum.SerializeToString(&value); - snprintf(str_buffer, 
kCIFARImageNBytes, "%05d", + Database::buffer_t value(datum.ByteSize()); + datum.SerializeWithCachedSizesToArray( + reinterpret_cast<unsigned char*>(value.data())); + int length = snprintf(str_buffer, kCIFARImageNBytes, "%05d", fileid * kCIFARBatchSize + itemid); - train_db->Put(leveldb::WriteOptions(), string(str_buffer), value); + Database::buffer_t key(str_buffer, str_buffer + length); + train_database->put(&key, &value); } } + train_database->commit(); + train_database->close(); LOG(INFO) << "Writing Testing data"; - leveldb::DB* test_db; - CHECK(leveldb::DB::Open(options, output_folder + "/cifar10_test_leveldb", - &test_db).ok()) << "Failed to open leveldb."; + shared_ptr<Database> test_database = DatabaseFactory(db_type); + test_database->open(output_folder + "/cifar10_test_" + db_type, + Database::New); // Open files std::ifstream data_file((input_folder + "/test_batch.bin").c_str(), std::ios::in | std::ios::binary); @@ -81,28 +84,30 @@ void convert_dataset(const string& input_folder, const string& output_folder) { read_image(&data_file, &label, str_buffer); datum.set_label(label); datum.set_data(str_buffer, kCIFARImageNBytes); - datum.SerializeToString(&value); - snprintf(str_buffer, kCIFARImageNBytes, "%05d", itemid); - test_db->Put(leveldb::WriteOptions(), string(str_buffer), value); + Database::buffer_t value(datum.ByteSize()); + datum.SerializeWithCachedSizesToArray( + reinterpret_cast<unsigned char*>(value.data())); + int length = snprintf(str_buffer, kCIFARImageNBytes, "%05d", itemid); + Database::buffer_t key(str_buffer, str_buffer + length); + test_database->put(&key, &value); } - - delete train_db; - delete test_db; + test_database->commit(); + test_database->close(); } int main(int argc, char** argv) { - if (argc != 3) { + if (argc != 4) { printf("This script converts the CIFAR dataset to the leveldb format used\n" "by caffe to perform classification.\n" "Usage:\n" - " convert_cifar_data input_folder output_folder\n" + " convert_cifar_data 
input_folder output_folder db_type\n" "Where the input folder should contain the binary batch files.\n" "The CIFAR dataset could be downloaded at\n" " http://www.cs.toronto.edu/~kriz/cifar.html\n" "You should gunzip them after downloading.\n"); } else { google::InitGoogleLogging(argv[0]); - convert_dataset(string(argv[1]), string(argv[2])); + convert_dataset(string(argv[1]), string(argv[2]), string(argv[3])); } return 0; } |