diff options
author | Kevin James Matzen <kmatzen@cs.cornell.edu> | 2014-10-13 13:16:04 -0400 |
---|---|---|
committer | Kevin James Matzen <kmatzen@cs.cornell.edu> | 2014-10-14 19:35:23 -0400 |
commit | 08b971feae8551ca5c7ce31a938a1f232ee56af2 (patch) | |
tree | b8569b89b09583b4434b0d700ee2943f24b37075 /tools | |
parent | 0987c72c9d11b94524e1dc91150daf273c5b2537 (diff) | |
download | caffeonacl-08b971feae8551ca5c7ce31a938a1f232ee56af2.tar.gz caffeonacl-08b971feae8551ca5c7ce31a938a1f232ee56af2.tar.bz2 caffeonacl-08b971feae8551ca5c7ce31a938a1f232ee56af2.zip |
Templated the key and value types for the Database interface. The Database is now responsible for serialization. Refactored the tests so that they reuse the same code for each value type and backend configuration.
Diffstat (limited to 'tools')
-rw-r--r-- | tools/compute_image_mean.cpp | 16 | ||||
-rw-r--r-- | tools/convert_imageset.cpp | 11 | ||||
-rw-r--r-- | tools/extract_features.cpp | 14 |
3 files changed, 17 insertions, 24 deletions
diff --git a/tools/compute_image_mean.cpp b/tools/compute_image_mean.cpp index d13c4a0f..f1a79679 100644 --- a/tools/compute_image_mean.cpp +++ b/tools/compute_image_mean.cpp @@ -26,18 +26,17 @@ int main(int argc, char** argv) { db_backend = std::string(argv[3]); } - caffe::shared_ptr<Database> database = caffe::DatabaseFactory(db_backend); + caffe::shared_ptr<Database<std::string, Datum> > database = + caffe::DatabaseFactory<std::string, Datum>(db_backend); // Open db - CHECK(database->open(argv[1], Database::ReadOnly)); + CHECK(database->open(argv[1], Database<std::string, Datum>::ReadOnly)); - Datum datum; BlobProto sum_blob; int count = 0; // load first datum - Database::const_iterator iter = database->begin(); - const Database::value_type& first_blob = iter->value; - datum.ParseFromArray(first_blob.data(), first_blob.size()); + Database<std::string, Datum>::const_iterator iter = database->begin(); + const Datum& datum = iter->value; sum_blob.set_num(1); sum_blob.set_channels(datum.channels()); @@ -50,11 +49,10 @@ int main(int argc, char** argv) { sum_blob.add_data(0.); } LOG(INFO) << "Starting Iteration"; - for (Database::const_iterator iter = database->begin(); + for (Database<std::string, Datum>::const_iterator iter = database->begin(); iter != database->end(); ++iter) { // just a dummy operation - const Database::value_type& blob = iter->value; - datum.ParseFromArray(blob.data(), blob.size()); + const Datum& datum = iter->value; const std::string& data = datum.data(); size_in_datum = std::max<int>(datum.data().size(), datum.float_data_size()); diff --git a/tools/convert_imageset.cpp b/tools/convert_imageset.cpp index 1cdca7e0..2ba3e3c7 100644 --- a/tools/convert_imageset.cpp +++ b/tools/convert_imageset.cpp @@ -78,10 +78,11 @@ int main(int argc, char** argv) { int resize_width = std::max<int>(0, FLAGS_resize_width); // Open new db - shared_ptr<Database> database = DatabaseFactory(db_backend); + shared_ptr<Database<string, Datum> > database = + DatabaseFactory<string, Datum>(db_backend); // Open db - CHECK(database->open(db_path, Database::New)); + CHECK(database->open(db_path, Database<string, Datum>::New)); // Storing to db std::string root_folder(argv[1]); @@ -110,13 +111,9 @@ int main(int argc, char** argv) { // sequential int length = snprintf(key_cstr, kMaxKeyLength, "%08d_%s", line_id, lines[line_id].first.c_str()); - Database::value_type value(datum.ByteSize()); - datum.SerializeWithCachedSizesToArray( - reinterpret_cast<unsigned char*>(value.data())); - Database::key_type keystr(key_cstr, key_cstr + length); // Put in db - CHECK(database->put(keystr, value)); + CHECK(database->put(string(key_cstr, length), datum)); if (++count % 1000 == 0) { // Commit txn diff --git a/tools/extract_features.cpp b/tools/extract_features.cpp index 1340192c..47565a83 100644 --- a/tools/extract_features.cpp +++ b/tools/extract_features.cpp @@ -121,11 +121,13 @@ int feature_extraction_pipeline(int argc, char** argv) { int num_mini_batches = atoi(argv[++arg_pos]); - std::vector<shared_ptr<Database> > feature_dbs; + std::vector<shared_ptr<Database<std::string, Datum> > > feature_dbs; for (size_t i = 0; i < num_features; ++i) { LOG(INFO)<< "Opening database " << database_names[i]; - shared_ptr<Database> database = DatabaseFactory(argv[++arg_pos]); - CHECK(database->open(database_names.at(i), Database::New)); + shared_ptr<Database<std::string, Datum> > database = + DatabaseFactory<std::string, Datum>(argv[++arg_pos]); + CHECK(database->open(database_names.at(i), + Database<std::string, Datum>::New)); feature_dbs.push_back(database); } @@ -155,13 +157,9 @@ int feature_extraction_pipeline(int argc, char** argv) { for (int d = 0; d < dim_features; ++d) { datum.add_float_data(feature_blob_data[d]); } - Database::value_type value(datum.ByteSize()); - datum.SerializeWithCachedSizesToArray( - reinterpret_cast<unsigned char*>(value.data())); int length = snprintf(key_str, kMaxKeyStrLength, "%d", image_indices[i]); - Database::key_type key(key_str, key_str + length); - CHECK(feature_dbs.at(i)->put(key, value)); + CHECK(feature_dbs.at(i)->put(std::string(key_str, length), datum)); ++image_indices[i]; if (image_indices[i] % 1000 == 0) { CHECK(feature_dbs.at(i)->commit()); |