summaryrefslogtreecommitdiff
path: root/src/caffe/layers/argmax_layer.cpp
diff options
context:
space:
mode:
authorKai Li <kaili_kloud@163.com>2014-07-08 22:24:20 +0800
committerKai Li <kaili_kloud@163.com>2014-07-20 00:26:41 +0800
commitdfe69b2f06d68633a37a99b56b8878e7fcad14e6 (patch)
tree643d63a2dda1bde414686bbb842ea7309c7bbc88 /src/caffe/layers/argmax_layer.cpp
parentb9a9c588be6286ecc4bc899e20b5f040f8d14ae7 (diff)
downloadcaffeonacl-dfe69b2f06d68633a37a99b56b8878e7fcad14e6.tar.gz
caffeonacl-dfe69b2f06d68633a37a99b56b8878e7fcad14e6.tar.bz2
caffeonacl-dfe69b2f06d68633a37a99b56b8878e7fcad14e6.zip
Simplify the top-k argmax layer using std::sort
Diffstat (limited to 'src/caffe/layers/argmax_layer.cpp')
-rw-r--r--src/caffe/layers/argmax_layer.cpp38
1 files changed, 8 insertions, 30 deletions
diff --git a/src/caffe/layers/argmax_layer.cpp b/src/caffe/layers/argmax_layer.cpp
index 1fe2a89e..00c17f1c 100644
--- a/src/caffe/layers/argmax_layer.cpp
+++ b/src/caffe/layers/argmax_layer.cpp
@@ -1,5 +1,6 @@
// Copyright 2014 BVLC and contributors.
+#include <algorithm>
#include <cfloat>
#include <queue>
#include <vector>
@@ -28,50 +29,27 @@ void ArgMaxLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
}
template <typename Dtype>
-class IDAndValueComparator {
- public:
- bool operator() (const std::pair<size_t, Dtype>& lhs,
- const std::pair<size_t, Dtype>& rhs) const {
- return lhs.second < rhs.second || (lhs.second == rhs.second &&
- lhs.first < rhs.first);
- }
-};
-
-template <typename Dtype>
Dtype ArgMaxLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
const Dtype* bottom_data = bottom[0]->cpu_data();
Dtype* top_data = (*top)[0]->mutable_cpu_data();
int num = bottom[0]->num();
int dim = bottom[0]->count() / bottom[0]->num();
- Dtype value;
for (int i = 0; i < num; ++i) {
- std::priority_queue<std::pair<size_t, Dtype>,
- std::vector<std::pair<size_t, Dtype> >, IDAndValueComparator<Dtype> >
- top_k_results;
+ std::vector<std::pair<int, Dtype> > bottom_data_vector;
for (int j = 0; j < dim; ++j) {
- value = -(bottom_data[i * dim + j]);
- if (top_k_results.size() >= top_k_) {
- if (value < top_k_results.top().second) {
- top_k_results.pop();
- top_k_results.push(std::make_pair(j, value));
- }
- } else {
- top_k_results.push(std::make_pair(j, value));
- }
+ bottom_data_vector.push_back(std::make_pair(j, bottom_data[i * dim + j]));
}
+ std::sort(bottom_data_vector.begin(), bottom_data_vector.end(),
+ int_Dtype_pair_greater<Dtype>);
if (out_max_val_) {
for (int j = 0; j < top_k_; ++j) {
- top_data[i * 2 * top_k_ + (top_k_ - 1 - j) * 2] =
- top_k_results.top().first;
- top_data[i * 2 * top_k_ + (top_k_ - 1 - j) * 2 + 1] =
- -(top_k_results.top().second);
- top_k_results.pop();
+ top_data[(*top)[0]->offset(i, 0, j)] = bottom_data_vector[j].first;
+ top_data[(*top)[0]->offset(i, 1, j)] = bottom_data_vector[j].second;
}
} else {
for (int j = 0; j < top_k_; ++j) {
- top_data[i * top_k_ + (top_k_ - 1 - j)] = top_k_results.top().first;
- top_k_results.pop();
+ top_data[(*top)[0]->offset(i, 0, j)] = bottom_data_vector[j].first;
}
}
}