diff options
author | Kai Li <kaili_kloud@163.com> | 2014-07-08 22:24:20 +0800 |
---|---|---|
committer | Kai Li <kaili_kloud@163.com> | 2014-07-20 00:26:41 +0800 |
commit | dfe69b2f06d68633a37a99b56b8878e7fcad14e6 (patch) | |
tree | 643d63a2dda1bde414686bbb842ea7309c7bbc88 /src/caffe/layers/argmax_layer.cpp | |
parent | b9a9c588be6286ecc4bc899e20b5f040f8d14ae7 (diff) | |
download | caffeonacl-dfe69b2f06d68633a37a99b56b8878e7fcad14e6.tar.gz caffeonacl-dfe69b2f06d68633a37a99b56b8878e7fcad14e6.tar.bz2 caffeonacl-dfe69b2f06d68633a37a99b56b8878e7fcad14e6.zip |
Simplify the top-k argmax layer using std::sort
Diffstat (limited to 'src/caffe/layers/argmax_layer.cpp')
-rw-r--r-- | src/caffe/layers/argmax_layer.cpp | 38 |
1 files changed, 8 insertions, 30 deletions
diff --git a/src/caffe/layers/argmax_layer.cpp b/src/caffe/layers/argmax_layer.cpp index 1fe2a89e..00c17f1c 100644 --- a/src/caffe/layers/argmax_layer.cpp +++ b/src/caffe/layers/argmax_layer.cpp @@ -1,5 +1,6 @@ // Copyright 2014 BVLC and contributors. +#include <algorithm> #include <cfloat> #include <queue> #include <vector> @@ -28,50 +29,27 @@ void ArgMaxLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom, } template <typename Dtype> -class IDAndValueComparator { - public: - bool operator() (const std::pair<size_t, Dtype>& lhs, - const std::pair<size_t, Dtype>& rhs) const { - return lhs.second < rhs.second || (lhs.second == rhs.second && - lhs.first < rhs.first); - } -}; - -template <typename Dtype> Dtype ArgMaxLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom, vector<Blob<Dtype>*>* top) { const Dtype* bottom_data = bottom[0]->cpu_data(); Dtype* top_data = (*top)[0]->mutable_cpu_data(); int num = bottom[0]->num(); int dim = bottom[0]->count() / bottom[0]->num(); - Dtype value; for (int i = 0; i < num; ++i) { - std::priority_queue<std::pair<size_t, Dtype>, - std::vector<std::pair<size_t, Dtype> >, IDAndValueComparator<Dtype> > - top_k_results; + std::vector<std::pair<int, Dtype> > bottom_data_vector; for (int j = 0; j < dim; ++j) { - value = -(bottom_data[i * dim + j]); - if (top_k_results.size() >= top_k_) { - if (value < top_k_results.top().second) { - top_k_results.pop(); - top_k_results.push(std::make_pair(j, value)); - } - } else { - top_k_results.push(std::make_pair(j, value)); - } + bottom_data_vector.push_back(std::make_pair(j, bottom_data[i * dim + j])); } + std::sort(bottom_data_vector.begin(), bottom_data_vector.end(), + int_Dtype_pair_greater<Dtype>); if (out_max_val_) { for (int j = 0; j < top_k_; ++j) { - top_data[i * 2 * top_k_ + (top_k_ - 1 - j) * 2] = - top_k_results.top().first; - top_data[i * 2 * top_k_ + (top_k_ - 1 - j) * 2 + 1] = - -(top_k_results.top().second); - top_k_results.pop(); + top_data[(*top)[0]->offset(i, 0, j)] = bottom_data_vector[j].first; + top_data[(*top)[0]->offset(i, 1, j)] = bottom_data_vector[j].second; } } else { for (int j = 0; j < top_k_; ++j) { - top_data[i * top_k_ + (top_k_ - 1 - j)] = top_k_results.top().first; - top_k_results.pop(); + top_data[(*top)[0]->offset(i, 0, j)] = bottom_data_vector[j].first; } } } |