author    Evan Shelhamer <shelhamer@imaginarynumber.net>  2014-06-26 19:13:28 -0700
committer Evan Shelhamer <shelhamer@imaginarynumber.net>  2014-06-26 19:13:28 -0700
commit    2e23f28f622c5f2346f10c9d6ae21e1b8f44e04b (patch)
tree      c27e7ff958a87d001cf6f5c803cf4c641396909b /tools
parent    3ab8fe7250d7b8ab99f61ab2f4e7d5ca0078f2ad (diff)
parent    dda092c84e35943ced8be41fd545aea4ab3c6409 (diff)
Merge pull request #511 from kloudkl/extract_multiple_features
Extract multiple features in a single Forward pass
Diffstat (limited to 'tools')
-rw-r--r--  tools/extract_features.cpp  129
1 file changed, 78 insertions(+), 51 deletions(-)
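As a usage note, the merged change makes extract_features take comma-separated lists of blob names and leveldb names, one output database per blob. A hypothetical invocation following the usage string below (the model, prototxt, and database names are placeholders, not part of this commit):

    extract_features lenet_iter_10000.caffemodel feature_extraction.prototxt \
        fc6,fc7 features_fc6_leveldb,features_fc7_leveldb 10 GPU 0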
diff --git a/tools/extract_features.cpp b/tools/extract_features.cpp
index 3a670d96..22e74908 100644
--- a/tools/extract_features.cpp
+++ b/tools/extract_features.cpp
@@ -5,6 +5,7 @@
#include <google/protobuf/text_format.h>
#include <leveldb/db.h>
#include <leveldb/write_batch.h>
+#include <boost/algorithm/string.hpp>
#include <string>
#include <vector>
@@ -27,14 +28,20 @@ int main(int argc, char** argv) {
template<typename Dtype>
int feature_extraction_pipeline(int argc, char** argv) {
+ ::google::InitGoogleLogging(argv[0]);
const int num_required_args = 6;
if (argc < num_required_args) {
LOG(ERROR)<<
"This program takes in a trained network and an input data layer, and then"
" extract features of the input data produced by the net.\n"
- "Usage: demo_extract_features pretrained_net_param"
- " feature_extraction_proto_file extract_feature_blob_name"
- " save_feature_leveldb_name num_mini_batches [CPU/GPU] [DEVICE_ID=0]";
+ "Usage: extract_features pretrained_net_param"
+ " feature_extraction_proto_file extract_feature_blob_name1[,name2,...]"
+ " save_feature_leveldb_name1[,name2,...] num_mini_batches [CPU/GPU]"
+ " [DEVICE_ID=0]\n"
+ "Note: you can extract multiple features in one pass by specifying"
+ " multiple feature blob names and leveldb names seperated by ','."
+ " The names cannot contain white space characters and the number of blobs"
+ " and leveldbs must be equal.";
return 1;
}
int arg_pos = num_required_args;
@@ -91,74 +98,94 @@ int feature_extraction_pipeline(int argc, char** argv) {
new Net<Dtype>(feature_extraction_proto));
feature_extraction_net->CopyTrainedLayersFrom(pretrained_binary_proto);
- string extract_feature_blob_name(argv[++arg_pos]);
- CHECK(feature_extraction_net->has_blob(extract_feature_blob_name))
- << "Unknown feature blob name " << extract_feature_blob_name
- << " in the network " << feature_extraction_proto;
+ string extract_feature_blob_names(argv[++arg_pos]);
+ vector<string> blob_names;
+ boost::split(blob_names, extract_feature_blob_names, boost::is_any_of(","));
+
+ string save_feature_leveldb_names(argv[++arg_pos]);
+ vector<string> leveldb_names;
+ boost::split(leveldb_names, save_feature_leveldb_names,
+ boost::is_any_of(","));
+ CHECK_EQ(blob_names.size(), leveldb_names.size()) <<
+ " the number of blob names and leveldb names must be equal";
+ size_t num_features = blob_names.size();
+
+ for (size_t i = 0; i < num_features; i++) {
+ CHECK(feature_extraction_net->has_blob(blob_names[i]))
+ << "Unknown feature blob name " << blob_names[i]
+ << " in the network " << feature_extraction_proto;
+ }
- string save_feature_leveldb_name(argv[++arg_pos]);
- leveldb::DB* db;
leveldb::Options options;
options.error_if_exists = true;
options.create_if_missing = true;
options.write_buffer_size = 268435456;
- LOG(INFO)<< "Opening leveldb " << save_feature_leveldb_name;
- leveldb::Status status = leveldb::DB::Open(options,
- save_feature_leveldb_name.c_str(),
- &db);
- CHECK(status.ok()) << "Failed to open leveldb " << save_feature_leveldb_name;
+ vector<shared_ptr<leveldb::DB> > feature_dbs;
+ for (size_t i = 0; i < num_features; ++i) {
+ LOG(INFO)<< "Opening leveldb " << leveldb_names[i];
+ leveldb::DB* db;
+ leveldb::Status status = leveldb::DB::Open(options,
+ leveldb_names[i].c_str(),
+ &db);
+ CHECK(status.ok()) << "Failed to open leveldb " << leveldb_names[i];
+ feature_dbs.push_back(shared_ptr<leveldb::DB>(db));
+ }
int num_mini_batches = atoi(argv[++arg_pos]);
LOG(ERROR)<< "Extacting Features";
Datum datum;
- leveldb::WriteBatch* batch = new leveldb::WriteBatch();
+ vector<shared_ptr<leveldb::WriteBatch> > feature_batches(
+ num_features,
+ shared_ptr<leveldb::WriteBatch>(new leveldb::WriteBatch()));
const int kMaxKeyStrLength = 100;
char key_str[kMaxKeyStrLength];
vector<Blob<float>*> input_vec;
- int image_index = 0;
+ vector<int> image_indices(num_features, 0);
for (int batch_index = 0; batch_index < num_mini_batches; ++batch_index) {
feature_extraction_net->Forward(input_vec);
- const shared_ptr<Blob<Dtype> > feature_blob = feature_extraction_net
- ->blob_by_name(extract_feature_blob_name);
- int num_features = feature_blob->num();
- int dim_features = feature_blob->count() / num_features;
- Dtype* feature_blob_data;
- for (int n = 0; n < num_features; ++n) {
- datum.set_height(dim_features);
- datum.set_width(1);
- datum.set_channels(1);
- datum.clear_data();
- datum.clear_float_data();
- feature_blob_data = feature_blob->mutable_cpu_data() +
- feature_blob->offset(n);
- for (int d = 0; d < dim_features; ++d) {
- datum.add_float_data(feature_blob_data[d]);
- }
- string value;
- datum.SerializeToString(&value);
- snprintf(key_str, kMaxKeyStrLength, "%d", image_index);
- batch->Put(string(key_str), value);
- ++image_index;
- if (image_index % 1000 == 0) {
- db->Write(leveldb::WriteOptions(), batch);
- LOG(ERROR)<< "Extracted features of " << image_index <<
- " query images.";
- delete batch;
- batch = new leveldb::WriteBatch();
- }
- } // for (int n = 0; n < num_features; ++n)
+ for (int i = 0; i < num_features; ++i) {
+ const shared_ptr<Blob<Dtype> > feature_blob = feature_extraction_net
+ ->blob_by_name(blob_names[i]);
+ int batch_size = feature_blob->num();
+ int dim_features = feature_blob->count() / batch_size;
+ Dtype* feature_blob_data;
+ for (int n = 0; n < batch_size; ++n) {
+ datum.set_height(dim_features);
+ datum.set_width(1);
+ datum.set_channels(1);
+ datum.clear_data();
+ datum.clear_float_data();
+ feature_blob_data = feature_blob->mutable_cpu_data() +
+ feature_blob->offset(n);
+ for (int d = 0; d < dim_features; ++d) {
+ datum.add_float_data(feature_blob_data[d]);
+ }
+ string value;
+ datum.SerializeToString(&value);
+ snprintf(key_str, kMaxKeyStrLength, "%d", image_indices[i]);
+ feature_batches[i]->Put(string(key_str), value);
+ ++image_indices[i];
+ if (image_indices[i] % 1000 == 0) {
+ feature_dbs[i]->Write(leveldb::WriteOptions(),
+ feature_batches[i].get());
+ LOG(ERROR)<< "Extracted features of " << image_indices[i] <<
+ " query images for feature blob " << blob_names[i];
+ feature_batches[i].reset(new leveldb::WriteBatch());
+ }
+ } // for (int n = 0; n < batch_size; ++n)
+ } // for (int i = 0; i < num_features; ++i)
} // for (int batch_index = 0; batch_index < num_mini_batches; ++batch_index)
// write the last batch
- if (image_index % 1000 != 0) {
- db->Write(leveldb::WriteOptions(), batch);
- LOG(ERROR)<< "Extracted features of " << image_index <<
- " query images.";
+ for (int i = 0; i < num_features; ++i) {
+ if (image_indices[i] % 1000 != 0) {
+ feature_dbs[i]->Write(leveldb::WriteOptions(), feature_batches[i].get());
+ }
+ LOG(ERROR)<< "Extracted features of " << image_indices[i] <<
+ " query images for feature blob " << blob_names[i];
}
- delete batch;
- delete db;
LOG(ERROR)<< "Successfully extracted the features!";
return 0;
}
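For reference, here is a minimal standalone sketch of the comma-separated name parsing that this patch introduces. The blob and leveldb names below are hypothetical; the sketch only relies on Boost string algorithms and glog, both of which the tool already uses.

#include <boost/algorithm/string.hpp>
#include <glog/logging.h>
#include <string>
#include <vector>

int main(int argc, char** argv) {
  ::google::InitGoogleLogging(argv[0]);

  // Hypothetical argv values: comma-separated blob names and leveldb names.
  std::string blob_arg = "fc6,fc7";
  std::string db_arg = "features_fc6_leveldb,features_fc7_leveldb";

  // Split each argument on ',' as the patched tool does.
  std::vector<std::string> blob_names;
  boost::split(blob_names, blob_arg, boost::is_any_of(","));

  std::vector<std::string> leveldb_names;
  boost::split(leveldb_names, db_arg, boost::is_any_of(","));

  // One output leveldb is required per extracted blob.
  CHECK_EQ(blob_names.size(), leveldb_names.size())
      << "the number of blob names and leveldb names must be equal";

  for (size_t i = 0; i < blob_names.size(); ++i) {
    LOG(INFO) << "blob " << blob_names[i]
              << " -> leveldb " << leveldb_names[i];
  }
  return 0;
}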