summaryrefslogtreecommitdiff
path: root/src/caffe/layers/argmax_layer.cpp
diff options
context:
space:
mode:
authorKai Li <kaili_kloud@163.com>2014-07-04 17:19:24 +0800
committerKai Li <kaili_kloud@163.com>2014-07-20 00:26:40 +0800
commit7722514cddc127ba782ad97732532e6afea0db6e (patch)
treedc87011ad5497e7ed5a856a2d7981518ae4df5a0 /src/caffe/layers/argmax_layer.cpp
parent61bd7ea32470684201b0a8a9d3edea14edb302f2 (diff)
downloadcaffeonacl-7722514cddc127ba782ad97732532e6afea0db6e.tar.gz
caffeonacl-7722514cddc127ba782ad97732532e6afea0db6e.tar.bz2
caffeonacl-7722514cddc127ba782ad97732532e6afea0db6e.zip
Extend the ArgMaxLayer to output top k results
Diffstat (limited to 'src/caffe/layers/argmax_layer.cpp')
-rw-r--r--src/caffe/layers/argmax_layer.cpp54
1 files changed, 41 insertions, 13 deletions
diff --git a/src/caffe/layers/argmax_layer.cpp b/src/caffe/layers/argmax_layer.cpp
index cc31c0f5..1fe2a89e 100644
--- a/src/caffe/layers/argmax_layer.cpp
+++ b/src/caffe/layers/argmax_layer.cpp
@@ -1,12 +1,12 @@
// Copyright 2014 BVLC and contributors.
-#include <vector>
#include <cfloat>
+#include <queue>
+#include <vector>
#include "caffe/layer.hpp"
#include "caffe/vision_layers.hpp"
-
namespace caffe {
template <typename Dtype>
@@ -14,36 +14,65 @@ void ArgMaxLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
Layer<Dtype>::SetUp(bottom, top);
out_max_val_ = this->layer_param_.argmax_param().out_max_val();
+ top_k_ = this->layer_param_.argmax_param().top_k();
+ CHECK_GE(top_k_, 1) << " top k must not be less than 1.";
+ CHECK_LE(top_k_, bottom[0]->count() / bottom[0]->num())
+ << "top_k must be less than or equal to the number of classes.";
if (out_max_val_) {
// Produces max_ind and max_val
- (*top)[0]->Reshape(bottom[0]->num(), 2, 1, 1);
+ (*top)[0]->Reshape(bottom[0]->num(), 2, top_k_, 1);
} else {
// Produces only max_ind
- (*top)[0]->Reshape(bottom[0]->num(), 1, 1, 1);
+ (*top)[0]->Reshape(bottom[0]->num(), 1, top_k_, 1);
}
}
template <typename Dtype>
+class IDAndValueComparator {
+ public:
+ bool operator() (const std::pair<size_t, Dtype>& lhs,
+ const std::pair<size_t, Dtype>& rhs) const {
+ return lhs.second < rhs.second || (lhs.second == rhs.second &&
+ lhs.first < rhs.first);
+ }
+};
+
+template <typename Dtype>
Dtype ArgMaxLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
const Dtype* bottom_data = bottom[0]->cpu_data();
Dtype* top_data = (*top)[0]->mutable_cpu_data();
int num = bottom[0]->num();
int dim = bottom[0]->count() / bottom[0]->num();
+ Dtype value;
for (int i = 0; i < num; ++i) {
- Dtype max_val = -FLT_MAX;
- int max_ind = 0;
+ std::priority_queue<std::pair<size_t, Dtype>,
+ std::vector<std::pair<size_t, Dtype> >, IDAndValueComparator<Dtype> >
+ top_k_results;
for (int j = 0; j < dim; ++j) {
- if (bottom_data[i * dim + j] > max_val) {
- max_val = bottom_data[i * dim + j];
- max_ind = j;
+ value = -(bottom_data[i * dim + j]);
+ if (top_k_results.size() >= top_k_) {
+ if (value < top_k_results.top().second) {
+ top_k_results.pop();
+ top_k_results.push(std::make_pair(j, value));
+ }
+ } else {
+ top_k_results.push(std::make_pair(j, value));
}
}
if (out_max_val_) {
- top_data[i * 2] = max_ind;
- top_data[i * 2 + 1] = max_val;
+ for (int j = 0; j < top_k_; ++j) {
+ top_data[i * 2 * top_k_ + (top_k_ - 1 - j) * 2] =
+ top_k_results.top().first;
+ top_data[i * 2 * top_k_ + (top_k_ - 1 - j) * 2 + 1] =
+ -(top_k_results.top().second);
+ top_k_results.pop();
+ }
} else {
- top_data[i] = max_ind;
+ for (int j = 0; j < top_k_; ++j) {
+ top_data[i * top_k_ + (top_k_ - 1 - j)] = top_k_results.top().first;
+ top_k_results.pop();
+ }
}
}
return Dtype(0);
@@ -51,5 +80,4 @@ Dtype ArgMaxLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
INSTANTIATE_CLASS(ArgMaxLayer);
-
} // namespace caffe