diff options
author | Evan Shelhamer <shelhamer@imaginarynumber.net> | 2014-06-26 19:13:28 -0700 |
---|---|---|
committer | Evan Shelhamer <shelhamer@imaginarynumber.net> | 2014-06-26 19:13:28 -0700 |
commit | 2e23f28f622c5f2346f10c9d6ae21e1b8f44e04b (patch) | |
tree | c27e7ff958a87d001cf6f5c803cf4c641396909b /tools | |
parent | 3ab8fe7250d7b8ab99f61ab2f4e7d5ca0078f2ad (diff) | |
parent | dda092c84e35943ced8be41fd545aea4ab3c6409 (diff) | |
download | caffeonacl-2e23f28f622c5f2346f10c9d6ae21e1b8f44e04b.tar.gz caffeonacl-2e23f28f622c5f2346f10c9d6ae21e1b8f44e04b.tar.bz2 caffeonacl-2e23f28f622c5f2346f10c9d6ae21e1b8f44e04b.zip |
Merge pull request #511 from kloudkl/extract_multiple_features
Extract multiple features in a single Forward pass
Diffstat (limited to 'tools')
-rw-r--r-- | tools/extract_features.cpp | 129 |
1 file changed, 78 insertions, 51 deletions
diff --git a/tools/extract_features.cpp b/tools/extract_features.cpp index 3a670d96..22e74908 100644 --- a/tools/extract_features.cpp +++ b/tools/extract_features.cpp @@ -5,6 +5,7 @@ #include <google/protobuf/text_format.h> #include <leveldb/db.h> #include <leveldb/write_batch.h> +#include <boost/algorithm/string.hpp> #include <string> #include <vector> @@ -27,14 +28,20 @@ int main(int argc, char** argv) { template<typename Dtype> int feature_extraction_pipeline(int argc, char** argv) { + ::google::InitGoogleLogging(argv[0]); const int num_required_args = 6; if (argc < num_required_args) { LOG(ERROR)<< "This program takes in a trained network and an input data layer, and then" " extract features of the input data produced by the net.\n" - "Usage: demo_extract_features pretrained_net_param" - " feature_extraction_proto_file extract_feature_blob_name" - " save_feature_leveldb_name num_mini_batches [CPU/GPU] [DEVICE_ID=0]"; + "Usage: extract_features pretrained_net_param" + " feature_extraction_proto_file extract_feature_blob_name1[,name2,...]" + " save_feature_leveldb_name1[,name2,...] num_mini_batches [CPU/GPU]" + " [DEVICE_ID=0]\n" + "Note: you can extract multiple features in one pass by specifying" + " multiple feature blob names and leveldb names seperated by ','." 
+ " The names cannot contain white space characters and the number of blobs" + " and leveldbs must be equal."; return 1; } int arg_pos = num_required_args; @@ -91,74 +98,94 @@ int feature_extraction_pipeline(int argc, char** argv) { new Net<Dtype>(feature_extraction_proto)); feature_extraction_net->CopyTrainedLayersFrom(pretrained_binary_proto); - string extract_feature_blob_name(argv[++arg_pos]); - CHECK(feature_extraction_net->has_blob(extract_feature_blob_name)) - << "Unknown feature blob name " << extract_feature_blob_name - << " in the network " << feature_extraction_proto; + string extract_feature_blob_names(argv[++arg_pos]); + vector<string> blob_names; + boost::split(blob_names, extract_feature_blob_names, boost::is_any_of(",")); + + string save_feature_leveldb_names(argv[++arg_pos]); + vector<string> leveldb_names; + boost::split(leveldb_names, save_feature_leveldb_names, + boost::is_any_of(",")); + CHECK_EQ(blob_names.size(), leveldb_names.size()) << + " the number of blob names and leveldb names must be equal"; + size_t num_features = blob_names.size(); + + for (size_t i = 0; i < num_features; i++) { + CHECK(feature_extraction_net->has_blob(blob_names[i])) + << "Unknown feature blob name " << blob_names[i] + << " in the network " << feature_extraction_proto; + } - string save_feature_leveldb_name(argv[++arg_pos]); - leveldb::DB* db; leveldb::Options options; options.error_if_exists = true; options.create_if_missing = true; options.write_buffer_size = 268435456; - LOG(INFO)<< "Opening leveldb " << save_feature_leveldb_name; - leveldb::Status status = leveldb::DB::Open(options, - save_feature_leveldb_name.c_str(), - &db); - CHECK(status.ok()) << "Failed to open leveldb " << save_feature_leveldb_name; + vector<shared_ptr<leveldb::DB> > feature_dbs; + for (size_t i = 0; i < num_features; ++i) { + LOG(INFO)<< "Opening leveldb " << leveldb_names[i]; + leveldb::DB* db; + leveldb::Status status = leveldb::DB::Open(options, + leveldb_names[i].c_str(), + &db); + 
CHECK(status.ok()) << "Failed to open leveldb " << leveldb_names[i]; + feature_dbs.push_back(shared_ptr<leveldb::DB>(db)); + } int num_mini_batches = atoi(argv[++arg_pos]); LOG(ERROR)<< "Extacting Features"; Datum datum; - leveldb::WriteBatch* batch = new leveldb::WriteBatch(); + vector<shared_ptr<leveldb::WriteBatch> > feature_batches( + num_features, + shared_ptr<leveldb::WriteBatch>(new leveldb::WriteBatch())); const int kMaxKeyStrLength = 100; char key_str[kMaxKeyStrLength]; vector<Blob<float>*> input_vec; - int image_index = 0; + vector<int> image_indices(num_features, 0); for (int batch_index = 0; batch_index < num_mini_batches; ++batch_index) { feature_extraction_net->Forward(input_vec); - const shared_ptr<Blob<Dtype> > feature_blob = feature_extraction_net - ->blob_by_name(extract_feature_blob_name); - int num_features = feature_blob->num(); - int dim_features = feature_blob->count() / num_features; - Dtype* feature_blob_data; - for (int n = 0; n < num_features; ++n) { - datum.set_height(dim_features); - datum.set_width(1); - datum.set_channels(1); - datum.clear_data(); - datum.clear_float_data(); - feature_blob_data = feature_blob->mutable_cpu_data() + - feature_blob->offset(n); - for (int d = 0; d < dim_features; ++d) { - datum.add_float_data(feature_blob_data[d]); - } - string value; - datum.SerializeToString(&value); - snprintf(key_str, kMaxKeyStrLength, "%d", image_index); - batch->Put(string(key_str), value); - ++image_index; - if (image_index % 1000 == 0) { - db->Write(leveldb::WriteOptions(), batch); - LOG(ERROR)<< "Extracted features of " << image_index << - " query images."; - delete batch; - batch = new leveldb::WriteBatch(); - } - } // for (int n = 0; n < num_features; ++n) + for (int i = 0; i < num_features; ++i) { + const shared_ptr<Blob<Dtype> > feature_blob = feature_extraction_net + ->blob_by_name(blob_names[i]); + int batch_size = feature_blob->num(); + int dim_features = feature_blob->count() / batch_size; + Dtype* feature_blob_data; + 
for (int n = 0; n < batch_size; ++n) { + datum.set_height(dim_features); + datum.set_width(1); + datum.set_channels(1); + datum.clear_data(); + datum.clear_float_data(); + feature_blob_data = feature_blob->mutable_cpu_data() + + feature_blob->offset(n); + for (int d = 0; d < dim_features; ++d) { + datum.add_float_data(feature_blob_data[d]); + } + string value; + datum.SerializeToString(&value); + snprintf(key_str, kMaxKeyStrLength, "%d", image_indices[i]); + feature_batches[i]->Put(string(key_str), value); + ++image_indices[i]; + if (image_indices[i] % 1000 == 0) { + feature_dbs[i]->Write(leveldb::WriteOptions(), + feature_batches[i].get()); + LOG(ERROR)<< "Extracted features of " << image_indices[i] << + " query images for feature blob " << blob_names[i]; + feature_batches[i].reset(new leveldb::WriteBatch()); + } + } // for (int n = 0; n < batch_size; ++n) + } // for (int i = 0; i < num_features; ++i) } // for (int batch_index = 0; batch_index < num_mini_batches; ++batch_index) // write the last batch - if (image_index % 1000 != 0) { - db->Write(leveldb::WriteOptions(), batch); - LOG(ERROR)<< "Extracted features of " << image_index << - " query images."; + for (int i = 0; i < num_features; ++i) { + if (image_indices[i] % 1000 != 0) { + feature_dbs[i]->Write(leveldb::WriteOptions(), feature_batches[i].get()); + } + LOG(ERROR)<< "Extracted features of " << image_indices[i] << + " query images for feature blob " << blob_names[i]; } - delete batch; - delete db; LOG(ERROR)<< "Successfully extracted the features!"; return 0; } |