author    Kevin James Matzen <kmatzen@cs.cornell.edu>  2014-10-13 13:16:04 -0400
committer Kevin James Matzen <kmatzen@cs.cornell.edu>  2014-10-14 19:35:23 -0400
commit    08b971feae8551ca5c7ce31a938a1f232ee56af2 (patch)
tree      b8569b89b09583b4434b0d700ee2943f24b37075 /tools
parent    0987c72c9d11b94524e1dc91150daf273c5b2537 (diff)
Templated the key and value types for the Database interface.

The Database is now responsible for serialization. Refactored the tests so
that they reuse the same code for each value type and backend configuration.
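For reference, a minimal sketch of the interface shape this commit implies, reconstructed only from the call sites in the diff below. The names open/put/commit/begin/end, the Mode values ReadOnly and New, and the iterator's ->value member all appear in the diff; everything else here (the KV record name, virtual dispatch, exact signatures) is an assumption, not the actual Caffe header:

#include <string>

// Sketch only -- reconstructed from call sites, not the real declaration.
// caffe::shared_ptr is assumed to come from caffe/common.hpp.
template <typename K, typename V>
class Database {
 public:
  typedef K key_type;
  typedef V value_type;

  // Only these two modes appear in this diff; others may exist.
  enum Mode { New, ReadOnly };

  // Iterators hand back a key/value record; callers read iter->value
  // as an already-deserialized V (here, a Datum).
  struct KV {
    K key;
    V value;
  };
  class const_iterator;  // forward iteration over KV records

  virtual ~Database() {}
  virtual bool open(const std::string& filename, Mode mode) = 0;
  // put() serializes the value internally; callers no longer touch bytes.
  virtual bool put(const K& key, const V& value) = 0;
  virtual bool commit() = 0;
  virtual const_iterator begin() const = 0;
  virtual const_iterator end() const = 0;
};

// Factory selecting a backend (e.g. "leveldb" or "lmdb") by name.
template <typename K, typename V>
caffe::shared_ptr<Database<K, V> > DatabaseFactory(const std::string& backend);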
Diffstat (limited to 'tools')
-rw-r--r--  tools/compute_image_mean.cpp | 16
-rw-r--r--  tools/convert_imageset.cpp   | 11
-rw-r--r--  tools/extract_features.cpp   | 14
3 files changed, 17 insertions(+), 24 deletions(-)
diff --git a/tools/compute_image_mean.cpp b/tools/compute_image_mean.cpp
index d13c4a0f..f1a79679 100644
--- a/tools/compute_image_mean.cpp
+++ b/tools/compute_image_mean.cpp
@@ -26,18 +26,17 @@ int main(int argc, char** argv) {
db_backend = std::string(argv[3]);
}
- caffe::shared_ptr<Database> database = caffe::DatabaseFactory(db_backend);
+ caffe::shared_ptr<Database<std::string, Datum> > database =
+ caffe::DatabaseFactory<std::string, Datum>(db_backend);
// Open db
- CHECK(database->open(argv[1], Database::ReadOnly));
+ CHECK(database->open(argv[1], Database<std::string, Datum>::ReadOnly));
- Datum datum;
BlobProto sum_blob;
int count = 0;
// load first datum
- Database::const_iterator iter = database->begin();
- const Database::value_type& first_blob = iter->value;
- datum.ParseFromArray(first_blob.data(), first_blob.size());
+ Database<std::string, Datum>::const_iterator iter = database->begin();
+ const Datum& datum = iter->value;
sum_blob.set_num(1);
sum_blob.set_channels(datum.channels());
@@ -50,11 +49,10 @@ int main(int argc, char** argv) {
sum_blob.add_data(0.);
}
LOG(INFO) << "Starting Iteration";
- for (Database::const_iterator iter = database->begin();
+ for (Database<std::string, Datum>::const_iterator iter = database->begin();
iter != database->end(); ++iter) {
// just a dummy operation
- const Database::value_type& blob = iter->value;
- datum.ParseFromArray(blob.data(), blob.size());
+ const Datum& datum = iter->value;
const std::string& data = datum.data();
size_in_datum = std::max<int>(datum.data().size(),
datum.float_data_size());
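With the typed iterator, the manual ParseFromArray step above disappears: iter->value is already a parsed Datum. A hedged read-path sketch using only names that appear in this file's hunks (the backend string and database path are placeholders):

caffe::shared_ptr<Database<std::string, Datum> > database =
    caffe::DatabaseFactory<std::string, Datum>("leveldb");  // placeholder backend
CHECK(database->open("/path/to/db", Database<std::string, Datum>::ReadOnly));
for (Database<std::string, Datum>::const_iterator iter = database->begin();
     iter != database->end(); ++iter) {
  const Datum& datum = iter->value;  // deserialized inside the backend
  // ... consume datum.data() or datum.float_data() here ...
}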
diff --git a/tools/convert_imageset.cpp b/tools/convert_imageset.cpp
index 1cdca7e0..2ba3e3c7 100644
--- a/tools/convert_imageset.cpp
+++ b/tools/convert_imageset.cpp
@@ -78,10 +78,11 @@ int main(int argc, char** argv) {
int resize_width = std::max<int>(0, FLAGS_resize_width);
// Open new db
- shared_ptr<Database> database = DatabaseFactory(db_backend);
+ shared_ptr<Database<string, Datum> > database =
+ DatabaseFactory<string, Datum>(db_backend);
// Open db
- CHECK(database->open(db_path, Database::New));
+ CHECK(database->open(db_path, Database<string, Datum>::New));
// Storing to db
std::string root_folder(argv[1]);
@@ -110,13 +111,9 @@ int main(int argc, char** argv) {
// sequential
int length = snprintf(key_cstr, kMaxKeyLength, "%08d_%s", line_id,
lines[line_id].first.c_str());
- Database::value_type value(datum.ByteSize());
- datum.SerializeWithCachedSizesToArray(
- reinterpret_cast<unsigned char*>(value.data()));
- Database::key_type keystr(key_cstr, key_cstr + length);
// Put in db
- CHECK(database->put(keystr, value));
+ CHECK(database->put(string(key_cstr, length), datum));
if (++count % 1000 == 0) {
// Commit txn
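The put() above hands the Datum over unserialized; turning it into bytes is now the database's job, per the commit message. A toy, self-contained illustration of where that responsibility lands (MapDatabase is a hypothetical in-memory stand-in, only for illustration; SerializeToString and ParseFromString are standard protobuf Message calls, and caffe::Datum comes from caffe/proto/caffe.pb.h):

#include <map>
#include <string>
#include "caffe/proto/caffe.pb.h"

// Hypothetical toy backend: keeps protobuf-serialized values in a std::map.
class MapDatabase {
 public:
  bool put(const std::string& key, const caffe::Datum& value) {
    std::string serialized;
    // Serialization happens inside the database, not at the call site.
    if (!value.SerializeToString(&serialized)) return false;
    store_[key] = serialized;
    return true;
  }

  bool get(const std::string& key, caffe::Datum* value) const {
    std::map<std::string, std::string>::const_iterator it = store_.find(key);
    if (it == store_.end()) return false;
    return value->ParseFromString(it->second);  // deserialize on the way out
  }

 private:
  std::map<std::string, std::string> store_;
};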
diff --git a/tools/extract_features.cpp b/tools/extract_features.cpp
index 1340192c..47565a83 100644
--- a/tools/extract_features.cpp
+++ b/tools/extract_features.cpp
@@ -121,11 +121,13 @@ int feature_extraction_pipeline(int argc, char** argv) {
int num_mini_batches = atoi(argv[++arg_pos]);
- std::vector<shared_ptr<Database> > feature_dbs;
+ std::vector<shared_ptr<Database<std::string, Datum> > > feature_dbs;
for (size_t i = 0; i < num_features; ++i) {
LOG(INFO)<< "Opening database " << database_names[i];
- shared_ptr<Database> database = DatabaseFactory(argv[++arg_pos]);
- CHECK(database->open(database_names.at(i), Database::New));
+ shared_ptr<Database<std::string, Datum> > database =
+ DatabaseFactory<std::string, Datum>(argv[++arg_pos]);
+ CHECK(database->open(database_names.at(i),
+ Database<std::string, Datum>::New));
feature_dbs.push_back(database);
}
@@ -155,13 +157,9 @@ int feature_extraction_pipeline(int argc, char** argv) {
for (int d = 0; d < dim_features; ++d) {
datum.add_float_data(feature_blob_data[d]);
}
- Database::value_type value(datum.ByteSize());
- datum.SerializeWithCachedSizesToArray(
- reinterpret_cast<unsigned char*>(value.data()));
int length = snprintf(key_str, kMaxKeyStrLength, "%d",
image_indices[i]);
- Database::key_type key(key_str, key_str + length);
- CHECK(feature_dbs.at(i)->put(key, value));
+ CHECK(feature_dbs.at(i)->put(std::string(key_str, length), datum));
++image_indices[i];
if (image_indices[i] % 1000 == 0) {
CHECK(feature_dbs.at(i)->commit());
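The loop around this hunk keeps its existing batching: commit() runs after every 1000 puts. A condensed sketch of the resulting write pattern with the typed put (the key scheme and the 1000-record batch mirror the diff; num_items and the trailing commit for a partial batch are assumptions about how the tool finishes):

const int kMaxKeyStrLength = 100;
char key_str[kMaxKeyStrLength];
for (int i = 0; i < num_items; ++i) {  // num_items: placeholder bound
  Datum datum;
  // ... fill datum.float_data() from the i-th feature blob ...
  int length = snprintf(key_str, kMaxKeyStrLength, "%d", i);
  CHECK(database->put(std::string(key_str, length), datum));
  if ((i + 1) % 1000 == 0) {
    CHECK(database->commit());  // flush this batch of 1000 records
  }
}
CHECK(database->commit());  // assumed: flush any partial final batch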