summaryrefslogtreecommitdiff
path: root/examples
diff options
context:
space:
mode:
authorKevin James Matzen <kmatzen@cs.cornell.edu>2014-10-13 13:16:04 -0400
committerKevin James Matzen <kmatzen@cs.cornell.edu>2014-10-14 19:35:23 -0400
commit08b971feae8551ca5c7ce31a938a1f232ee56af2 (patch)
treeb8569b89b09583b4434b0d700ee2943f24b37075 /examples
parent0987c72c9d11b94524e1dc91150daf273c5b2537 (diff)
downloadcaffeonacl-08b971feae8551ca5c7ce31a938a1f232ee56af2.tar.gz
caffeonacl-08b971feae8551ca5c7ce31a938a1f232ee56af2.tar.bz2
caffeonacl-08b971feae8551ca5c7ce31a938a1f232ee56af2.zip
Templated the key and value types for the Database interface. The Database is now responsible for serialization. Refactored the tests so that they reuse the same code for each value type and backend configuration.
Diffstat (limited to 'examples')
-rw-r--r--examples/cifar10/convert_cifar_data.cpp25
1 files changed, 10 insertions, 15 deletions
diff --git a/examples/cifar10/convert_cifar_data.cpp b/examples/cifar10/convert_cifar_data.cpp
index 46ab1558..f86f3936 100644
--- a/examples/cifar10/convert_cifar_data.cpp
+++ b/examples/cifar10/convert_cifar_data.cpp
@@ -20,6 +20,7 @@ using std::string;
using caffe::Database;
using caffe::DatabaseFactory;
+using caffe::Datum;
using caffe::shared_ptr;
const int kCIFARSize = 32;
@@ -37,13 +38,14 @@ void read_image(std::ifstream* file, int* label, char* buffer) {
void convert_dataset(const string& input_folder, const string& output_folder,
const string& db_type) {
- shared_ptr<Database> train_database = DatabaseFactory(db_type);
+ shared_ptr<Database<string, Datum> > train_database =
+ DatabaseFactory<string, Datum>(db_type);
CHECK(train_database->open(output_folder + "/cifar10_train_" + db_type,
- Database::New));
+ Database<string, Datum>::New));
// Data buffer
int label;
char str_buffer[kCIFARImageNBytes];
- caffe::Datum datum;
+ Datum datum;
datum.set_channels(3);
datum.set_height(kCIFARSize);
datum.set_width(kCIFARSize);
@@ -60,22 +62,19 @@ void convert_dataset(const string& input_folder, const string& output_folder,
read_image(&data_file, &label, str_buffer);
datum.set_label(label);
datum.set_data(str_buffer, kCIFARImageNBytes);
- Database::value_type value(datum.ByteSize());
- datum.SerializeWithCachedSizesToArray(
- reinterpret_cast<unsigned char*>(value.data()));
int length = snprintf(str_buffer, kCIFARImageNBytes, "%05d",
fileid * kCIFARBatchSize + itemid);
- Database::key_type key(str_buffer, str_buffer + length);
- CHECK(train_database->put(key, value));
+ CHECK(train_database->put(string(str_buffer, length), datum));
}
}
CHECK(train_database->commit());
train_database->close();
LOG(INFO) << "Writing Testing data";
- shared_ptr<Database> test_database = DatabaseFactory(db_type);
+ shared_ptr<Database<string, Datum> > test_database =
+ DatabaseFactory<string, Datum>(db_type);
CHECK(test_database->open(output_folder + "/cifar10_test_" + db_type,
- Database::New));
+ Database<string, Datum>::New));
// Open files
std::ifstream data_file((input_folder + "/test_batch.bin").c_str(),
std::ios::in | std::ios::binary);
@@ -84,12 +83,8 @@ void convert_dataset(const string& input_folder, const string& output_folder,
read_image(&data_file, &label, str_buffer);
datum.set_label(label);
datum.set_data(str_buffer, kCIFARImageNBytes);
- Database::value_type value(datum.ByteSize());
- datum.SerializeWithCachedSizesToArray(
- reinterpret_cast<unsigned char*>(value.data()));
int length = snprintf(str_buffer, kCIFARImageNBytes, "%05d", itemid);
- Database::key_type key(str_buffer, str_buffer + length);
- CHECK(test_database->put(key, value));
+ CHECK(test_database->put(string(str_buffer, length), datum));
}
CHECK(test_database->commit());
test_database->close();