diff options
author | Kevin James Matzen <kmatzen@cs.cornell.edu> | 2014-10-13 13:16:04 -0400 |
---|---|---|
committer | Kevin James Matzen <kmatzen@cs.cornell.edu> | 2014-10-14 19:35:23 -0400 |
commit | 08b971feae8551ca5c7ce31a938a1f232ee56af2 (patch) | |
tree | b8569b89b09583b4434b0d700ee2943f24b37075 /examples | |
parent | 0987c72c9d11b94524e1dc91150daf273c5b2537 (diff) | |
download | caffeonacl-08b971feae8551ca5c7ce31a938a1f232ee56af2.tar.gz caffeonacl-08b971feae8551ca5c7ce31a938a1f232ee56af2.tar.bz2 caffeonacl-08b971feae8551ca5c7ce31a938a1f232ee56af2.zip |
Templated the key and value types for the Database interface. The Database is now responsible for serialization. Refactored the tests so that they reuse the same code for each value type and backend configuration.
Diffstat (limited to 'examples')
-rw-r--r-- | examples/cifar10/convert_cifar_data.cpp | 25 |
1 files changed, 10 insertions, 15 deletions
diff --git a/examples/cifar10/convert_cifar_data.cpp b/examples/cifar10/convert_cifar_data.cpp index 46ab1558..f86f3936 100644 --- a/examples/cifar10/convert_cifar_data.cpp +++ b/examples/cifar10/convert_cifar_data.cpp @@ -20,6 +20,7 @@ using std::string; using caffe::Database; using caffe::DatabaseFactory; +using caffe::Datum; using caffe::shared_ptr; const int kCIFARSize = 32; @@ -37,13 +38,14 @@ void read_image(std::ifstream* file, int* label, char* buffer) { void convert_dataset(const string& input_folder, const string& output_folder, const string& db_type) { - shared_ptr<Database> train_database = DatabaseFactory(db_type); + shared_ptr<Database<string, Datum> > train_database = + DatabaseFactory<string, Datum>(db_type); CHECK(train_database->open(output_folder + "/cifar10_train_" + db_type, - Database::New)); + Database<string, Datum>::New)); // Data buffer int label; char str_buffer[kCIFARImageNBytes]; - caffe::Datum datum; + Datum datum; datum.set_channels(3); datum.set_height(kCIFARSize); datum.set_width(kCIFARSize); @@ -60,22 +62,19 @@ void convert_dataset(const string& input_folder, const string& output_folder, read_image(&data_file, &label, str_buffer); datum.set_label(label); datum.set_data(str_buffer, kCIFARImageNBytes); - Database::value_type value(datum.ByteSize()); - datum.SerializeWithCachedSizesToArray( - reinterpret_cast<unsigned char*>(value.data())); int length = snprintf(str_buffer, kCIFARImageNBytes, "%05d", fileid * kCIFARBatchSize + itemid); - Database::key_type key(str_buffer, str_buffer + length); - CHECK(train_database->put(key, value)); + CHECK(train_database->put(string(str_buffer, length), datum)); } } CHECK(train_database->commit()); train_database->close(); LOG(INFO) << "Writing Testing data"; - shared_ptr<Database> test_database = DatabaseFactory(db_type); + shared_ptr<Database<string, Datum> > test_database = + DatabaseFactory<string, Datum>(db_type); CHECK(test_database->open(output_folder + "/cifar10_test_" + db_type, - Database::New)); + Database<string, Datum>::New)); // Open files std::ifstream data_file((input_folder + "/test_batch.bin").c_str(), std::ios::in | std::ios::binary); @@ -84,12 +83,8 @@ void convert_dataset(const string& input_folder, const string& output_folder, read_image(&data_file, &label, str_buffer); datum.set_label(label); datum.set_data(str_buffer, kCIFARImageNBytes); - Database::value_type value(datum.ByteSize()); - datum.SerializeWithCachedSizesToArray( - reinterpret_cast<unsigned char*>(value.data())); int length = snprintf(str_buffer, kCIFARImageNBytes, "%05d", itemid); - Database::key_type key(str_buffer, str_buffer + length); - CHECK(test_database->put(key, value)); + CHECK(test_database->put(string(str_buffer, length), datum)); } CHECK(test_database->commit()); test_database->close(); |