summaryrefslogtreecommitdiff
path: root/examples
diff options
context:
space:
mode:
authorKai Li <kaili_kloud@163.com>2014-02-26 03:47:32 +0800
committerKai Li <kaili_kloud@163.com>2014-03-19 23:04:42 +0800
commitcfb2f915b9efdab3d8a484ed767a0c2ecfd2af7b (patch)
treee8619ff429dfa7be216892dc63c3c34d570d719c /examples
parent01bb481702243eaa8a07d27df48d9ce1d109ebfa (diff)
downloadcaffe-cfb2f915b9efdab3d8a484ed767a0c2ecfd2af7b.tar.gz
caffe-cfb2f915b9efdab3d8a484ed767a0c2ecfd2af7b.tar.bz2
caffe-cfb2f915b9efdab3d8a484ed767a0c2ecfd2af7b.zip
Fix bugs of the feature binarization example
Diffstat (limited to 'examples')
-rw-r--r--examples/demo_binarize_features.cpp83
1 files changed, 37 insertions, 46 deletions
diff --git a/examples/demo_binarize_features.cpp b/examples/demo_binarize_features.cpp
index 5a13bc2d..9433d2fb 100644
--- a/examples/demo_binarize_features.cpp
+++ b/examples/demo_binarize_features.cpp
@@ -12,6 +12,7 @@
using namespace caffe;
+// TODO: Replace this with caffe_sign after the PR #159 is merged
template<typename Dtype>
inline int sign(const Dtype val) {
return (Dtype(0) < val) - (val < Dtype(0));
@@ -35,12 +36,12 @@ int main(int argc, char** argv) {
template<typename Dtype>
int features_binarization_pipeline(int argc, char** argv) {
- const int num_required_args = 4;
+ const int num_required_args = 5;
if (argc < num_required_args) {
LOG(ERROR)<<
- "This program compresses real valued features into compact binary codes."
- "Usage: demo_binarize_features data_prototxt data_layer_name"
- " save_binarized_feature_binaryproto_file [CPU/GPU] [DEVICE_ID=0]";
+ "This program compresses real valued features into compact binary codes.\n"
+ "Usage: demo_binarize_features real_valued_feature_prototxt feature_blob_name"
+ " save_binarized_feature_binaryproto_file num_mini_batches [CPU/GPU] [DEVICE_ID=0]";
return 1;
}
int arg_pos = num_required_args;
@@ -78,49 +79,38 @@ int features_binarization_pipeline(int argc, char** argv) {
top: "label"
}
*/
- string data_prototxt(argv[++arg_pos]);
- string data_layer_name(argv[++arg_pos]);
- NetParameter data_net_param;
- ReadProtoFromTextFile(data_prototxt.c_str(), &data_net_param);
- LayerParameter data_layer_param;
- int num_layer;
- for (num_layer = 0; num_layer < data_net_param.layers_size(); ++num_layer) {
- if (data_layer_name == data_net_param.layers(num_layer).layer().name()) {
- data_layer_param = data_net_param.layers(num_layer).layer();
- break;
- }
- }
- if (num_layer = data_net_param.layers_size()) {
- LOG(ERROR) << "Unknow data layer name " << data_layer_name <<
- " in prototxt " << data_prototxt;
- }
+ string real_valued_feature_prototxt(argv[++arg_pos]);
+ NetParameter real_valued_feature_net_param;
+ ReadProtoFromTextFile(real_valued_feature_prototxt,
+ &real_valued_feature_net_param);
+ shared_ptr<Net<Dtype> > real_valued_feature_net(
+ new Net<Dtype>(real_valued_feature_net_param));
+
+ string feature_blob_name(argv[++arg_pos]);
+ CHECK(real_valued_feature_net->HasBlob(feature_blob_name))
+ << "Unknown feature blob name " << feature_blob_name << " in the network "
+ << real_valued_feature_prototxt;
string save_binarized_feature_binaryproto_file(argv[++arg_pos]);
+ int num_mini_batches = atoi(argv[++arg_pos]);
+
LOG(ERROR)<< "Binarizing features";
- DataLayer<Dtype> data_layer(data_layer_param);
- vector<Blob<Dtype>*> bottom_vec_that_data_layer_does_not_need_;
- vector<Blob<Dtype>*> top_vec;
- data_layer.Forward(bottom_vec_that_data_layer_does_not_need_, &top_vec);
- shared_ptr<Blob<Dtype> > feature_binary_codes;
+ vector<Blob<Dtype>*> input_vec;
+ shared_ptr<Blob<Dtype> > feature_binary_codes(new Blob<Dtype>());
BlobProtoVector blob_proto_vector;
- int batch_index = 0;
- // TODO: DataLayer seem to rotate from the last record to the first
- // how to judge that all the data record have been enumerated?
- while (top_vec.size()) { // data_layer still outputs data
- LOG(ERROR)<< "Batch " << batch_index << " feature binarization";
- const shared_ptr<Blob<Dtype> > feature_blob(top_vec[0]);
+ int num_features = 0;
+ for (int batch_index = 0; batch_index < num_mini_batches; ++batch_index) {
+ real_valued_feature_net->Forward(input_vec);
+ const shared_ptr<Blob<Dtype> > feature_blob = real_valued_feature_net
+ ->GetBlob(feature_blob_name);
binarize<Dtype>(feature_blob, feature_binary_codes);
-
- LOG(ERROR) << "Batch " << batch_index << " save binarized features";
+ num_features += feature_binary_codes->num();
feature_binary_codes->ToProto(blob_proto_vector.add_blobs());
-
- data_layer.Forward(bottom_vec_that_data_layer_does_not_need_, &top_vec);
- ++batch_index;
- } // while (top_vec.size()) {
-
- WriteProtoToBinaryFile(blob_proto_vector, save_binarized_feature_binaryproto_file);
- LOG(ERROR)<< "Successfully ended!";
+ } // for (int batch_index = 0; batch_index < num_mini_batches; ++batch_index)
+ WriteProtoToBinaryFile(blob_proto_vector,
+ save_binarized_feature_binaryproto_file);
+ LOG(ERROR)<< "Successfully binarized " << num_features << " features!";
return 0;
}
@@ -133,17 +123,19 @@ void binarize(const int n, const Dtype* real_valued_feature,
// In IEEE International Conference on Computer Vision and Pattern Recognition (CVPR), 2013.
// http://www.unc.edu/~yunchao/bpbc.htm
int size_of_code = sizeof(Dtype) * 8;
- CHECK_EQ(n % size_of_code, 0);
- int num_binary_codes = n / size_of_code;
+ int num_binary_codes = (n + size_of_code - 1) / size_of_code;
uint64_t code;
int offset;
+ int count = 0;
for (int i = 0; i < num_binary_codes; ++i) {
- code = 0;
offset = i * size_of_code;
- for (int j = 0; j < size_of_code; ++j) {
+ int j = 0;
+ code = 0;
+ for (; j < size_of_code && count++ < n; ++j) {
code |= sign(real_valued_feature[offset + j]);
code << 1;
}
+ code << (size_of_code - j);
binary_codes[i] = static_cast<Dtype>(code);
}
}
@@ -154,8 +146,7 @@ void binarize(const shared_ptr<Blob<Dtype> > real_valued_features,
int num = real_valued_features->num();
int dim = real_valued_features->count() / num;
int size_of_code = sizeof(Dtype) * 8;
- CHECK_EQ(dim % size_of_code, 0);
- binary_codes->Reshape(num, dim / size_of_code, 1, 1);
+ binary_codes->Reshape(num, (dim + size_of_code - 1) / size_of_code, 1, 1);
const Dtype* real_valued_features_data = real_valued_features->cpu_data();
Dtype* binary_codes_data = binary_codes->mutable_cpu_data();
for (int n = 0; n < num; ++n) {