author    Kevin James Matzen <kmatzen@cs.cornell.edu>  2014-10-07 21:46:15 -0400
committer Kevin James Matzen <kmatzen@cs.cornell.edu>  2014-10-14 19:29:11 -0400
commit    edff676ea8482d48f4b7bc794288696afce782ca (patch)
tree      f8b7bae22c1331d10cf36deee2a6b168671b813a /examples
parent    7e504c0612b23449aa83fd7aac8d49c56b03fd62 (diff)
Don't autocommit on close for the databases. If they were opened read-only, the commit might fail.
Diffstat (limited to 'examples')
-rw-r--r--  examples/cifar10/convert_cifar_data.cpp  59
1 file changed, 32 insertions(+), 27 deletions(-)
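For orientation before the hunks: the converter moves from raw leveldb calls to the caffe::Database abstraction. Below is a minimal sketch of the write path, assuming only the interface visible in this diff (DatabaseFactory(db_type), open() with Database::New, buffer_t keys and values, an explicit commit() before close()); the real header may differ.

#include <string>

#include "caffe/database_factory.hpp"

using caffe::Database;
using caffe::DatabaseFactory;
using caffe::shared_ptr;

void write_one(const std::string& folder, const std::string& db_type) {
  // The factory selects the backend from the db_type string.
  shared_ptr<Database> db = DatabaseFactory(db_type);
  db->open(folder + "/example_" + db_type, Database::New);

  // Keys and values are plain byte buffers (buffer_t).
  std::string k = "00000";
  std::string v = "payload";
  Database::buffer_t key(k.begin(), k.end());
  Database::buffer_t value(v.begin(), v.end());
  db->put(&key, &value);

  // Writes only become durable on an explicit commit(); close() no longer
  // commits on its own, which is the point of this change.
  db->commit();
  db->close();
}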
diff --git a/examples/cifar10/convert_cifar_data.cpp b/examples/cifar10/convert_cifar_data.cpp
index 90ecb6d9..c4930878 100644
--- a/examples/cifar10/convert_cifar_data.cpp
+++ b/examples/cifar10/convert_cifar_data.cpp
@@ -11,13 +11,17 @@
#include "glog/logging.h"
#include "google/protobuf/text_format.h"
-#include "leveldb/db.h"
#include "stdint.h"
+#include "caffe/database_factory.hpp"
#include "caffe/proto/caffe.pb.h"
using std::string;
+using caffe::Database;
+using caffe::DatabaseFactory;
+using caffe::shared_ptr;
+
const int kCIFARSize = 32;
const int kCIFARImageNBytes = 3072;
const int kCIFARBatchSize = 10000;
@@ -31,26 +35,20 @@ void read_image(std::ifstream* file, int* label, char* buffer) {
return;
}
-void convert_dataset(const string& input_folder, const string& output_folder) {
- // Leveldb options
- leveldb::Options options;
- options.create_if_missing = true;
- options.error_if_exists = true;
+void convert_dataset(const string& input_folder, const string& output_folder,
+ const string& db_type) {
+ shared_ptr<Database> train_database = DatabaseFactory(db_type);
+ train_database->open(output_folder + "/cifar10_train_" + db_type,
+ Database::New);
// Data buffer
int label;
char str_buffer[kCIFARImageNBytes];
- string value;
caffe::Datum datum;
datum.set_channels(3);
datum.set_height(kCIFARSize);
datum.set_width(kCIFARSize);
LOG(INFO) << "Writing Training data";
- leveldb::DB* train_db;
- leveldb::Status status;
- status = leveldb::DB::Open(options, output_folder + "/cifar10_train_leveldb",
- &train_db);
- CHECK(status.ok()) << "Failed to open leveldb.";
for (int fileid = 0; fileid < kCIFARTrainBatches; ++fileid) {
// Open files
LOG(INFO) << "Training Batch " << fileid + 1;
@@ -62,17 +60,22 @@ void convert_dataset(const string& input_folder, const string& output_folder) {
read_image(&data_file, &label, str_buffer);
datum.set_label(label);
datum.set_data(str_buffer, kCIFARImageNBytes);
- datum.SerializeToString(&value);
- snprintf(str_buffer, kCIFARImageNBytes, "%05d",
+ Database::buffer_t value(datum.ByteSize());
+ datum.SerializeWithCachedSizesToArray(
+ reinterpret_cast<unsigned char*>(value.data()));
+ int length = snprintf(str_buffer, kCIFARImageNBytes, "%05d",
fileid * kCIFARBatchSize + itemid);
- train_db->Put(leveldb::WriteOptions(), string(str_buffer), value);
+ Database::buffer_t key(str_buffer, str_buffer + length);
+ train_database->put(&key, &value);
}
}
+ train_database->commit();
+ train_database->close();
LOG(INFO) << "Writing Testing data";
- leveldb::DB* test_db;
- CHECK(leveldb::DB::Open(options, output_folder + "/cifar10_test_leveldb",
- &test_db).ok()) << "Failed to open leveldb.";
+ shared_ptr<Database> test_database = DatabaseFactory(db_type);
+ test_database->open(output_folder + "/cifar10_test_" + db_type,
+ Database::New);
// Open files
std::ifstream data_file((input_folder + "/test_batch.bin").c_str(),
std::ios::in | std::ios::binary);
@@ -81,28 +84,30 @@ void convert_dataset(const string& input_folder, const string& output_folder) {
read_image(&data_file, &label, str_buffer);
datum.set_label(label);
datum.set_data(str_buffer, kCIFARImageNBytes);
- datum.SerializeToString(&value);
- snprintf(str_buffer, kCIFARImageNBytes, "%05d", itemid);
- test_db->Put(leveldb::WriteOptions(), string(str_buffer), value);
+ Database::buffer_t value(datum.ByteSize());
+ datum.SerializeWithCachedSizesToArray(
+ reinterpret_cast<unsigned char*>(value.data()));
+ int length = snprintf(str_buffer, kCIFARImageNBytes, "%05d", itemid);
+ Database::buffer_t key(str_buffer, str_buffer + length);
+ test_database->put(&key, &value);
}
-
- delete train_db;
- delete test_db;
+ test_database->commit();
+ test_database->close();
}
int main(int argc, char** argv) {
- if (argc != 3) {
+ if (argc != 4) {
printf("This script converts the CIFAR dataset to the leveldb format used\n"
"by caffe to perform classification.\n"
"Usage:\n"
- " convert_cifar_data input_folder output_folder\n"
+ " convert_cifar_data input_folder output_folder db_type\n"
"Where the input folder should contain the binary batch files.\n"
"The CIFAR dataset could be downloaded at\n"
" http://www.cs.toronto.edu/~kriz/cifar.html\n"
"You should gunzip them after downloading.\n");
} else {
google::InitGoogleLogging(argv[0]);
- convert_dataset(string(argv[1]), string(argv[2]));
+ convert_dataset(string(argv[1]), string(argv[2]), string(argv[3]));
}
return 0;
}
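With the extra command-line argument, the converter would be invoked as, for example, convert_cifar_data ./cifar-10-batches-bin ./examples/cifar10 leveldb. The set of db_type strings DatabaseFactory accepts is not shown in this diff, so leveldb here is only an assumed example.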